Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

[None][feat] KV Cache Connector API (#7228)

Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
Signed-off-by: richardhuo-nv <rihuo@nvidia.com>
Co-authored-by: jthomson04 <jwillthomson19@gmail.com>
Co-authored-by: Iman Tabrizian <10105175+Tabrizian@users.noreply.github.com>
Co-authored-by: Sharan Chetlur <116769508+schetlur-nv@users.noreply.github.com>

Parent: 085dc19bfa
Commit: ce580ce4f5

cpp/include/tensorrt_llm/batch_manager/kvCacheConnector.h (new file, 46 lines)
@@ -0,0 +1,46 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/runtime/common.h"

#include <utility>
#include <vector>

using SizeType32 = tensorrt_llm::runtime::SizeType32;
using RequestIdType = tensorrt_llm::batch_manager::LlmRequest::RequestIdType;

/// See tensorrt_llm/_torch/pyexecutor/connector.py for details on the Connector API.

namespace tensorrt_llm::batch_manager::kv_connector
{

/// @brief The KV connector manager. This is passed into the C++ KV Cache Manager when adding sequences.
class KvCacheConnectorManager
{
public:
    KvCacheConnectorManager() = default;
    virtual ~KvCacheConnectorManager() = default;

    /// @brief Handle the getNumNewMatchedTokens call inside the C++ KV Cache Manager.
    /// @return The number of tokens that can be loaded from remote KV cache.
    virtual SizeType32 getNumNewMatchedTokens(LlmRequest const& request, SizeType32 numComputedTokens) = 0;
};

} // namespace tensorrt_llm::batch_manager::kv_connector
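The pybind/nanobind trampolines added later in this commit allow this pure-virtual hook to be implemented in Python. As a minimal sketch (not part of the diff), a do-nothing manager could look like the following; the import path mirrors the one used in tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py below, and the class name NoOpConnectorManager is hypothetical.

from tensorrt_llm.bindings.internal.batch_manager import \
    KvCacheConnectorManager as KvCacheConnectorManagerCpp


class NoOpConnectorManager(KvCacheConnectorManagerCpp):
    """Hypothetical manager that never reports extra reusable tokens."""

    def get_num_new_matched_tokens(self, request, num_computed_tokens: int) -> int:
        # Called by the C++ KV cache manager when a sequence is added;
        # returning 0 means nothing extra can be loaded from remote KV cache.
        return 0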
@@ -16,6 +16,7 @@

#pragma once

#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
#include "tensorrt_llm/batch_manager/kvCacheType.h"
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
@@ -538,7 +539,8 @@ public:
        SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool,
        SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
        bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
        std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse);
        std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
        std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager);

    ~WindowBlockManager();

@@ -835,6 +837,8 @@ private:
    bool mEnablePartialReuse;
    // Whether partially matched blocks that are already in use should be copied and reused.
    bool mCopyOnPartialReuse;
    // The kv cache connector manager
    std::shared_ptr<kv_connector::KvCacheConnectorManager> mKvCacheConnectorManager;
};

class BlockManager
@@ -852,7 +856,8 @@ public:
        SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF,
        std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
        std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
        bool copyOnPartialReuse = true);
        bool copyOnPartialReuse = true,
        std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

    BlockManager(BlockManager const&) = delete;
    BlockManager& operator=(BlockManager const&) = delete;
@@ -1287,6 +1292,7 @@ public:
        LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
        = 0;

    [[nodiscard]] virtual runtime::ITensor::SharedPtr getUniquePrimaryPool() const = 0;
    [[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const = 0;
    [[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;

@@ -1373,7 +1379,8 @@ public:
        bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
        std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
        std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
        bool copyOnpartialReuse = true);
        bool copyOnpartialReuse = true,
        std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

    KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
        BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1383,7 +1390,8 @@ public:
        bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
        std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
        std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
        bool copyOnpartialReuse = true);
        bool copyOnpartialReuse = true,
        std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

    KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
        BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1393,7 +1401,8 @@ public:
        bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF,
        std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
        std::shared_ptr<KVCacheEventManager> eventManager = nullptr, bool enablePartialReuse = true,
        bool copyOnpartialReuse = true);
        bool copyOnpartialReuse = true,
        std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager = nullptr);

    KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
        BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
@@ -1624,6 +1633,7 @@ public:
    std::vector<SizeType32> getNewlyAllocatedBlockIds(
        LlmRequest::RequestIdType requestId, SizeType32 windowSize) const override;

    runtime::ITensor::SharedPtr getUniquePrimaryPool() const override;
    runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override;

    SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override
@@ -504,7 +504,8 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
    std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
    SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
    std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
    : mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
    , mTokensPerBlock{tokensPerBlock}
    , mEventManager{std::move(eventManager)}
@@ -513,6 +514,10 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
{
    auto const uniqueWindowSizeToLayers
        = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers);

    TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1,
        "KV Cache Connector is not supported with multiple window sizes");

    auto const numUniqueWindowSizes = static_cast<SizeType32>(uniqueWindowSizeToLayers.size());

    mIsVariableWindow = numUniqueWindowSizes > 1;
@@ -530,7 +535,7 @@ BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, Si
        mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
            sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream,
            onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse,
            copyOnPartialReuse);
            copyOnPartialReuse, kvCacheConnectorManager);
    }

    auto const numAllPools = getNumPools();
@@ -572,7 +577,8 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
    SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
    SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType,
    std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
    : mDataType{dtype}
    , mWindowSize{windowSize}
    , mNumPrimaryBlocks{blocksInPrimaryPool}
@@ -596,6 +602,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
    , mTotalInputTokens{0.0}
    , mEnablePartialReuse{enablePartialReuse}
    , mCopyOnPartialReuse{copyOnPartialReuse}
    , mKvCacheConnectorManager{std::move(kvCacheConnectorManager)}
{
    std::map<SizeType32, SizeType32> numLayersPerPool;

@@ -1188,9 +1195,18 @@ void WindowBlockManager::addSequence(
    auto const prepopulatedPromptLen = loadOrAllocateBlocks(blockKeys, numContextBlocks, sequence, perBlockRetentions);
    mReusedTokens += static_cast<double>(prepopulatedPromptLen);
    mTotalInputTokens += static_cast<double>(uniqueTokens.size());
    llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen, getTokensPerBlock());
    TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d", llmRequest.mRequestId,
        inputLength, prepopulatedPromptLen);

    SizeType32 numConnectorMatchedTokens = 0;

    // If we're using a KV cache connector, check if any additional blocks can be loaded.
    if (mKvCacheConnectorManager && !llmRequest.isDummyRequest())
    {
        numConnectorMatchedTokens = mKvCacheConnectorManager->getNumNewMatchedTokens(llmRequest, prepopulatedPromptLen);
    }

    llmRequest.setPrepopulatedPromptLen(prepopulatedPromptLen + numConnectorMatchedTokens, getTokensPerBlock());
    TLLM_LOG_DEBUG("addSequence: Request %lu, inputLength %d, prepopulatedPromptLen %d, numConnectorMatchedTokens %d",
        llmRequest.mRequestId, inputLength, prepopulatedPromptLen, numConnectorMatchedTokens);
}

// There are two versions of BlockManager::addSequence function.
@@ -1206,6 +1222,13 @@ void BlockManager::addSequence(
void WindowBlockManager::addSequence(
    GenerationRequest& sequence, SizeType32 numContextBlocks, bool isShareLastContextBlock)
{
    if (mKvCacheConnectorManager)
    {
        TLLM_LOG_WARNING(
            "KV Cache Connector specified when block reuse is disabled. The KV Cache Connector will be "
            "ignored.");
    }

    auto const requestId = sequence.getRequestId();
    auto const [seqIt, emplaceDone] = mAllocatedBlocksPerSeq.emplace(requestId, std::vector<BlockPtr>{});
    TLLM_CHECK(emplaceDone);
@@ -1618,12 +1641,13 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
    SizeType32 sinkTokenLength, int64_t stream, std::optional<runtime::SizeType32> maxSequenceLength,
    bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
    std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
    : KVCacheManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth,
        maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
        std::make_shared<runtime::CudaStream>(reinterpret_cast<cudaStream_t>(stream)), maxSequenceLength,
        enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, eventManager, enablePartialReuse,
        copyOnPartialReuse)
        copyOnPartialReuse, kvCacheConnectorManager)
{
}

@@ -1634,7 +1658,8 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
    SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
    bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
    std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
    : mMaxBeamWidth(maxBeamWidth)
    , mDataType(dtype)
    , mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end()))
@@ -1644,7 +1669,7 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
    , mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
        std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype,
        mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager),
        enablePartialReuse, copyOnPartialReuse)
        enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
    // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
    , mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
{
@@ -1668,11 +1693,12 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size
    SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<runtime::SizeType32> maxSequenceLength,
    bool enableBlockReuse, bool onboardBlocks, CacheType cacheType,
    std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse)
    std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
    std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager)
    : KVCacheManager(std::vector<SizeType32>(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
        maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength,
        std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority,
        std::move(eventManager), enablePartialReuse, copyOnPartialReuse)
        std::move(eventManager), enablePartialReuse, copyOnPartialReuse, std::move(kvCacheConnectorManager))
{
}

@@ -2383,6 +2409,13 @@ std::vector<SizeType32> KVCacheManager::getNewlyAllocatedBlockIds(
    return mBlockManager.getNewlyAllocatedBlockIds(getSequence(requestId), windowSize);
}

runtime::ITensor::SharedPtr KVCacheManager::getUniquePrimaryPool() const
{
    TLLM_CHECK_WITH_INFO(mBlockManager.getWindowSizesMetadata().size() == 1,
        "getUniquePrimaryPool is only supported for a single window size");
    return mBlockManager.getPrimaryPool(0);
}

runtime::ITensor::SharedPtr KVCacheManager::getPrimaryPool(SizeType32 layer_idx) const
{
    return mBlockManager.getPrimaryPool(mBlockManager.getLayerPoolIdx(layer_idx));
@@ -2462,4 +2495,5 @@ SizeType32 KVCacheManager::calculateMaxBlockRequirements(SizeType32 inputLength,
    auto const leftoverBlockCapacity = blockCapacity - outputBlockRequirements;
    return std::min(outputLength + leftoverBlockCapacity * tokensPerBlock, inputLength + outputLength);
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager
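As a rough illustration of the bookkeeping in WindowBlockManager::addSequence above (the numbers below are made up, not from the diff), tokens reported by the connector are counted on top of the tokens already matched in the on-device block pool:

tokens_per_block = 32
prepopulated_prompt_len = 64         # matched on device (2 full blocks)
num_connector_matched_tokens = 96    # value returned by getNumNewMatchedTokens()

# The request's prepopulated prompt length becomes the sum of both sources:
# 160 tokens, i.e. 5 blocks of 32 tokens.
effective_prepopulated_len = prepopulated_prompt_len + num_connector_matched_tokens
print(effective_prepopulated_len, effective_prepopulated_len // tokens_per_block)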
@@ -7,6 +7,7 @@ set(SRCS
    batch_manager/algorithms.cpp
    batch_manager/bindings.cpp
    batch_manager/cacheTransceiver.cpp
    batch_manager/kvCacheConnector.cpp
    batch_manager/kvCacheManager.cpp
    batch_manager/llmRequest.cpp
    executor/bindings.cpp
cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.cpp (new file, 48 lines)
@@ -0,0 +1,48 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h"

#include <nanobind/trampoline.h>
#include <torch/extension.h>

namespace
{
using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager;

namespace tb = tensorrt_llm::batch_manager;

class PyKvCacheConnectorManager : KvCacheConnectorManager
{
public:
    NB_TRAMPOLINE(KvCacheConnectorManager, 1);

    SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override
    {
        NB_OVERRIDE_PURE_NAME("get_num_new_matched_tokens", getNumNewMatchedTokens, request, numComputedTokens);
    }
};

} // namespace

void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(nb::module_& m)
{
    nb::class_<tb::kv_connector::KvCacheConnectorManager, PyKvCacheConnectorManager>(m, "KvCacheConnectorManager")
        .def(nb::init<>())
        .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens,
            nb::arg("request"), nb::arg("num_computed_tokens"));
}
cpp/tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h (new file, 39 lines)
@@ -0,0 +1,39 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include <nanobind/nanobind.h>

namespace nb = nanobind;

namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManagerConnectorBindings
{
public:
    static void initBindings(nb::module_& m);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

namespace tensorrt_llm::pybind::batch_manager::kv_connector
{

using namespace tensorrt_llm::batch_manager::kv_connector;

} // namespace tensorrt_llm::pybind::batch_manager::kv_connector
@@ -39,6 +39,7 @@
#include <torch/extension.h>

namespace tb = tensorrt_llm::batch_manager;
namespace tbc = tensorrt_llm::batch_manager::kv_connector;
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
namespace tr = tensorrt_llm::runtime;
namespace nb = nanobind;
@@ -381,6 +382,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
                auto pool_layer_idx = self.getPoolLayerIdx(layer_idx);
                return pool.index({torch::indexing::Slice(), pool_layer_idx});
            })
        .def("get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); })
        .def("get_block_offsets_of_batch",
            [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize,
                SizeType32 beamWidth)
@@ -446,12 +448,13 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
        .value("SELFKONLY", tbk::CacheType::kSELFKONLY);

    nb::class_<tbk::KVCacheManager, tbk::BaseKVCacheManager>(m, "KVCacheManager")
        .def(nb::init<std::vector<SizeType32> const&, SizeType32, SizeType32,
                 std::map<SizeType32, std::tuple<SizeType32, SizeType32>> const&, SizeType32, SizeType32,
                 std::vector<SizeType32> const&, std::optional<tbk::TempAttentionWindowInputs> const&,
                 nvinfer1::DataType, SizeType32, int64_t, std::optional<runtime::SizeType32>, bool, bool,
                 tbk::CacheType, std::optional<tensorrt_llm::executor::RetentionPriority>,
                 std::shared_ptr<tbk::KVCacheEventManager>, bool, bool>(),
        .def(
            nb::init<std::vector<SizeType32> const&, SizeType32, SizeType32,
                std::map<SizeType32, std::tuple<SizeType32, SizeType32>> const&, SizeType32, SizeType32,
                std::vector<SizeType32> const&, std::optional<tbk::TempAttentionWindowInputs> const&,
                nvinfer1::DataType, SizeType32, int64_t, std::optional<runtime::SizeType32>, bool, bool, tbk::CacheType,
                std::optional<tensorrt_llm::executor::RetentionPriority>, std::shared_ptr<tbk::KVCacheEventManager>,
                bool, bool, std::shared_ptr<tbc::KvCacheConnectorManager>>(),
            nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), nb::arg("tokens_per_block"),
            nb::arg("blocks_per_window"), nb::arg("max_num_sequences"), nb::arg("max_beam_width"),
            nb::arg("max_attention_window_vec"), nb::arg("temp_attention_window_inputs").none(), nb::arg("dtype"),
@@ -459,7 +462,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
            nb::arg("enable_block_reuse") = false, nb::arg("onboard_blocks") = true,
            nb::arg("cache_type") = tbk::CacheType::kSELF, nb::arg("secondary_offload_min_priority") = std::nullopt,
            nb::arg("event_manager") = nullptr, nb::arg("enable_partial_reuse") = true,
            nb::arg("copy_on_partial_reuse") = true);
            nb::arg("copy_on_partial_reuse") = true, nb::arg("kv_connector_manager") = nullptr);
}

void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m)
@@ -34,6 +34,7 @@
#include "tensorrt_llm/nanobind/batch_manager/algorithms.h"
#include "tensorrt_llm/nanobind/batch_manager/bindings.h"
#include "tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/nanobind/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/nanobind/batch_manager/llmRequest.h"
#include "tensorrt_llm/nanobind/executor/bindings.h"
@@ -480,6 +481,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
    tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime);
    tensorrt_llm::nanobind::testing::initBindings(mInternalTesting);
    tpb::initBindings(mInternalBatchManager);

    tb::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(mInternalBatchManager);
    tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager);
    tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager);
    tb::CacheTransceiverBindings::initBindings(mInternalBatchManager);
@@ -7,6 +7,7 @@ set(SRCS
    batch_manager/algorithms.cpp
    batch_manager/bindings.cpp
    batch_manager/cacheTransceiver.cpp
    batch_manager/kvCacheConnector.cpp
    batch_manager/kvCacheManager.cpp
    batch_manager/llmRequest.cpp
    executor/bindings.cpp
@@ -19,6 +19,8 @@

#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/batch_manager/rnnStateManager.h"
cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.cpp (new file, 47 lines)
@@ -0,0 +1,47 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"

namespace
{
using KvCacheConnectorManager = tensorrt_llm::batch_manager::kv_connector::KvCacheConnectorManager;

namespace tb = tensorrt_llm::batch_manager;

class PyKvCacheConnectorManager : public KvCacheConnectorManager, py::trampoline_self_life_support
{
public:
    using KvCacheConnectorManager::KvCacheConnectorManager;

    SizeType32 getNumNewMatchedTokens(tb::LlmRequest const& request, SizeType32 numComputedTokens) override
    {
        PYBIND11_OVERRIDE_PURE_NAME(SizeType32, KvCacheConnectorManager, "get_num_new_matched_tokens",
            getNumNewMatchedTokens, request, numComputedTokens);
    }
};

} // namespace

void tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(py::module_& m)
{
    py::class_<tb::kv_connector::KvCacheConnectorManager, PyKvCacheConnectorManager, py::smart_holder>(
        m, "KvCacheConnectorManager")
        .def(py::init<>())
        .def("get_num_new_matched_tokens", &tb::kv_connector::KvCacheConnectorManager::getNumNewMatchedTokens,
            py::arg("request"), py::arg("num_computed_tokens"));
}
cpp/tensorrt_llm/pybind/batch_manager/kvCacheConnector.h (new file, 39 lines)
@@ -0,0 +1,39 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/kvCacheConnector.h"
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManagerConnectorBindings
{
public:
    static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

namespace tensorrt_llm::pybind::batch_manager::kv_connector
{

using namespace tensorrt_llm::batch_manager::kv_connector;

} // namespace tensorrt_llm::pybind::batch_manager::kv_connector
@@ -30,6 +30,7 @@
#include <torch/extension.h>

namespace tb = tensorrt_llm::batch_manager;
namespace tbc = tensorrt_llm::batch_manager::kv_connector;
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
namespace tr = tensorrt_llm::runtime;
namespace py = pybind11;
@@ -214,10 +215,16 @@ public:
            std::deque<tensorrt_llm::executor::KVCacheEvent>, tbk::BaseKVCacheManager, getLatestEvents, timeout);
    }

    tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const override
    tensorrt_llm::runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 poolIdx) const override
    {
        PYBIND11_OVERLOAD_PURE(
            tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getPrimaryPool, layer_idx);
            tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getPrimaryPool, poolIdx);
    }

    tensorrt_llm::runtime::ITensor::SharedPtr getUniquePrimaryPool() const override
    {
        PYBIND11_OVERLOAD_PURE(
            tensorrt_llm::runtime::ITensor::SharedPtr, tbk::BaseKVCacheManager, getUniquePrimaryPool);
    }

    SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const override
@@ -377,6 +384,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
                auto pool_layer_idx = self.getPoolLayerIdx(layer_idx);
                return pool.index({torch::indexing::Slice(), pool_layer_idx});
            })
        .def("get_unique_primary_pool", [](tbk::BaseKVCacheManager& self) { return self.getUniquePrimaryPool(); })
        .def("get_block_offsets_of_batch",
            [](tbk::BaseKVCacheManager& self, at::Tensor output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize,
                SizeType32 beamWidth)
@@ -437,7 +445,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
                std::vector<SizeType32> const&, std::optional<tbk::TempAttentionWindowInputs> const&,
                nvinfer1::DataType, SizeType32, bool, int64_t, bool, bool, tbk::CacheType,
                std::optional<tensorrt_llm::executor::RetentionPriority>, std::shared_ptr<tbk::KVCacheEventManager>,
                bool, bool>(),
                bool, bool, std::shared_ptr<tbc::KvCacheConnectorManager>>(),
            py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"), py::arg("tokens_per_block"),
            py::arg("blocks_per_window"), py::arg("max_num_sequences"), py::arg("max_beam_width"),
            py::arg("max_attention_window_vec"), py::arg("temp_attention_window_inputs"), py::arg("dtype"),
@@ -445,7 +453,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
            py::arg("enable_block_reuse") = false, py::arg("onboard_blocks") = true,
            py::arg_v("cache_type", tbk::CacheType::kSELF, "bindings.internal.batch_manager.CacheType.SELF"),
            py::arg("secondary_offload_min_priority") = std::nullopt, py::arg("event_manager") = nullptr,
            py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true);
            py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true,
            py::arg("kv_connector_manager") = nullptr);
}

void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
@@ -28,6 +28,7 @@
#include "tensorrt_llm/pybind/batch_manager/algorithms.h"
#include "tensorrt_llm/pybind/batch_manager/bindings.h"
#include "tensorrt_llm/pybind/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/executor/bindings.h"
@@ -468,6 +469,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
    tensorrt_llm::pybind::runtime::initBindings(mInternalRuntime);
    tensorrt_llm::pybind::testing::initBindings(mInternalTesting);
    tpb::initBindings(mInternalBatchManager);
    tb::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(mInternalBatchManager);
    tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager);
    tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager);
    tb::CacheTransceiverBindings::initBindings(mInternalBatchManager);
@@ -43,15 +43,11 @@ public:
        .deleter(
            [ptr = std::move(tensor)](void* data) mutable
            {
                try
                if (data != ptr->data())
                {
                    TLLM_CHECK(data == ptr->data());
                    ptr.reset();
                }
                catch (std::exception const& e)
                {
                    TLLM_LOG_EXCEPTION(e);
                    TLLM_LOG_WARNING("Torch tensor refers to deallocated memory.");
                }
                ptr.reset();
            })
        .make_tensor();
}
examples/llm-api/llm_kv_cache_connector.py (new file, 248 lines)
@@ -0,0 +1,248 @@
### :title KV Cache Connector
### :order 6
### :section Customization

import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory

import click
import torch

from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.executor import ExecutorConfig
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig

# This is a simple example of the use of the KV cache connector.
# It persists KV cache contents into a folder, and can load them back on subsequent runs.
# See tensorrt_llm/_torch/pyexecutor/connector.py for details about the KV cache connector interface.
# NOTE: This example connector implementation is NOT suitable for production use.

CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"


@dataclass
class PersistentKvCacheConnectorMetadata:
    load: list[tuple[str, int]] = field(default_factory=list)
    save: list[tuple[str, int]] = field(default_factory=list)


class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):

    def __init__(self, executor_config: ExecutorConfig):
        super().__init__(executor_config)

        self.kv_cache_tensor = None

    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
        self.kv_cache_tensor = kv_cache_tensor

    def start_load_kv(self, stream: torch.cuda.Stream):
        # Do all loads synchronously, and blockwise.
        for path, block_id in self._metadata.load:
            cpu_tensor = torch.load(path, map_location="cpu")

            # Copy into the device block.
            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)

    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):

        # Make sure the forward pass is complete before beginning our save.
        stream.synchronize()

        for path, block_id in self._metadata.save:
            cpu_tensor = self.kv_cache_tensor[block_id].cpu()

            # Don't write anything if this specific block already exists.
            if Path(path).exists():
                continue

            # Do a blocking save to the file. This way, we only return once all saves are complete.
            torch.save(cpu_tensor, path)

    def get_finished(
            self, finished_gen_req_ids: list[int],
            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:

        return [], []


class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):

    def __init__(self, executor_config: ExecutorConfig):
        super().__init__(executor_config)

        self.block_size = self._config.tokens_per_block
        self.pending_loads = {}

        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
                                           "./connector_cache")

        os.makedirs(self.cache_folder, exist_ok=True)

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        # NOTE: This is a simplified implementation, and does not work with chunked prefill.

        metadata = PersistentKvCacheConnectorMetadata()

        for req in scheduler_output.new_requests:
            # If we don't have any pending loads for this request, we can skip it.
            if req.request_id not in self.pending_loads:
                continue

            num_computed_blocks = req.computed_position // self.block_size
            block_ids = req.new_block_ids

            pending_load = self.pending_loads[req.request_id]

            for file_path, block_pos in zip(
                    pending_load, range(num_computed_blocks, len(block_ids))):
                metadata.load.append((file_path, block_ids[block_pos]))

            # Break up the remainder of the token sequence into chunks.
            chunks = self._chunk_tokens(req.new_tokens)

            # For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.
            for block_pos in range(num_computed_blocks + len(pending_load),
                                   len(block_ids)):
                if len(chunks[block_pos]) == self.block_size:
                    hashed_tokens = self._hash_tokens(chunks[block_pos])

                    file_path = self._file_path(hashed_tokens)

                    metadata.save.append((file_path, block_ids[block_pos]))

        self.pending_loads = {}

        return metadata

    def _hash_tokens(self, tokens: list[int]) -> int:
        return abs(hash(tuple(tokens)))

    def _file_path(self, hash_value: int) -> Path:
        return Path(self.cache_folder) / f"{hash_value}.pt"

    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
        return [
            tokens[i:i + self.block_size]
            for i in range(0, len(tokens), self.block_size)
        ]

    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> tuple[int, bool]:
        self.pending_loads[request.request_id] = []

        # Don't bother with sequences with partial matches.
        if (num_computed_tokens % self.block_size) != 0:
            return 0, False

        computed_blocks = num_computed_tokens // self.block_size

        # Get all the tokens that don't have a cache hit on device.
        remaining_tokens = request.get_tokens(0)[computed_blocks *
                                                 self.block_size:]

        remaining_chunks = self._chunk_tokens(remaining_tokens)

        # For each chunk, check if it exists in our cache.
        for chunk in remaining_chunks:
            # Only do full blocks.
            if len(chunk) == self.block_size:
                hashed_tokens = self._hash_tokens(chunk)

                file_path = self._file_path(hashed_tokens)

                # If we get a cache hit, we want to load it into device.
                # Otherwise, we can stop looking.
                if file_path.exists():
                    self.pending_loads[request.request_id].append(file_path)
                else:
                    break

        logger.info(
            f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
        )

        return len(
            self.pending_loads[request.request_id]) * self.block_size, False

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: list[int]) -> bool:
        # We don't do any asynchronous saving, so always return False.
        return False

    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: list[int]):
        pass


@click.command()
@click.argument("model", type=str)
def main(model: str):
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

    this_module = __file__[__file__.rfind("/") + 1:__file__.rfind(".py")]

    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

    connector_cache_dir = TemporaryDirectory()
    os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir.name

    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    test_text = (
        "Nvidia Corporation is an American technology company headquartered in Santa Clara, California."
        "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
        "system on a chips (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
        "and mobile and automotive applications. Tell me about the company.")

    sampling_params = SamplingParams(max_tokens=32)

    output = llm.generate([test_text], sampling_params)
    text0 = output[0].outputs[0].text

    print("First output: ", text0)
    print("Loading new LLM instance...")

    del llm

    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    output = llm.generate([test_text], sampling_params)
    text1 = output[0].outputs[0].text

    print("Second output (using connector cache): ", text1)

    assert text0 == text1

    connector_cache_dir.cleanup()


if __name__ == "__main__":
    main()
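The cache-key scheme used by PersistentKvCacheConnectorLeader above can be shown in isolation. This sketch is illustrative only (the block size and token values are made up): token sequences are split into tokens_per_block-sized chunks, and only full chunks are hashed into file names.

block_size = 4  # assumed tokens_per_block, for illustration only
tokens = [11, 12, 13, 14, 21, 22, 23, 24, 31, 32]

chunks = [tokens[i:i + block_size] for i in range(0, len(tokens), block_size)]
full_chunks = [c for c in chunks if len(c) == block_size]
file_names = [f"{abs(hash(tuple(c)))}.pt" for c in full_chunks]

# Two full chunks are persisted; the trailing partial chunk is skipped,
# mirroring the "only do full blocks" checks in the example connector.
print(len(chunks), len(full_chunks), file_names)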
@@ -22,6 +22,7 @@ from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
from .config import PyTorchConfig
from .config_utils import is_mla, is_nemotron_hybrid
from .guided_decoder import GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
from .llm_request import ExecutorResponse
from .mamba_cache_manager import MambaHybridCacheManager
@@ -44,7 +45,8 @@ class KvCacheCreator:
    def __init__(self, *, executor_config: ExecutorConfig,
                 model_engine: PyTorchModelEngine,
                 draft_model_engine: Optional[PyTorchModelEngine],
                 mapping: Mapping, net_max_seq_len: int):
                 mapping: Mapping, net_max_seq_len: int,
                 kv_connector_manager: Optional[KvCacheConnectorManager]):
        self._executor_config = executor_config
        self._model_engine = model_engine
        self._draft_model_engine = draft_model_engine
@@ -52,6 +54,7 @@ class KvCacheCreator:
        self._max_kv_tokens_in = self._executor_config.kv_cache_config.max_tokens
        self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len -
                                                               1)
        self._kv_connector_manager = kv_connector_manager

    @staticmethod
    def _get_cache_size_per_token(model_config: ModelConfig,
@@ -335,7 +338,9 @@ class KvCacheCreator:
        # ---------------------------handle max_gpu_total_bytes---------------------------------

    def _create_kv_cache_manager(
            self, model_engine: PyTorchModelEngine) -> KVCacheManager:
            self,
            model_engine: PyTorchModelEngine,
            estimating_kv_cache: bool = False) -> KVCacheManager:
        executor_config = self._executor_config
        mapping = self._mapping
        assert model_engine.model.model_config.is_generation, "Only construct KV cache for generation models."
@@ -377,12 +382,20 @@ class KvCacheCreator:
                spec_config=spec_config,
                max_beam_width=executor_config.max_beam_width,
                is_draft=model_engine.is_draft_model,
                kv_connector_manager=self._kv_connector_manager
                if not estimating_kv_cache else None,
            )
        elif is_nemotron_hybrid(config):
            if executor_config.max_beam_width > 1:
                raise ValueError(
                    "MambaHybridCacheManager + beam search is not supported yet."
                )

            if not estimating_kv_cache and self._kv_connector_manager is not None:
                raise NotImplementedError(
                    "Connector manager is not supported for MambaHybridCacheManager."
                )

            config = model_engine.model.model_config.pretrained_config
            num_layers = config.hybrid_override_pattern.count("*")
            layer_mask = [
@@ -443,6 +456,8 @@ class KvCacheCreator:
                model_config=binding_model_config,
                max_beam_width=executor_config.max_beam_width,
                is_draft=model_engine.is_draft_model,
                kv_connector_manager=self._kv_connector_manager
                if not estimating_kv_cache else None,
            )
        # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config
        if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
@@ -450,12 +465,21 @@ class KvCacheCreator:

        return kv_cache_manager

    def build_managers(self, resources: Dict) -> None:
    def build_managers(self,
                       resources: Dict,
                       estimating_kv_cache: bool = False) -> None:
        """Construct KV caches for model and draft model (if applicable)."""
        kv_cache_manager = self._create_kv_cache_manager(self._model_engine)
        kv_cache_manager = self._create_kv_cache_manager(
            self._model_engine, estimating_kv_cache)

        if not estimating_kv_cache and self._kv_connector_manager is not None and self._draft_model_engine is not None:
            raise NotImplementedError(
                "Connector manager is not supported for draft model.")

        draft_kv_cache_manager = self._create_kv_cache_manager(
            self._draft_model_engine
            self._draft_model_engine, estimating_kv_cache
        ) if self._draft_model_engine is not None else None

        resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
        resources[
            ResourceManagerType.DRAFT_KV_CACHE_MANAGER] = draft_kv_cache_manager
@@ -472,20 +496,22 @@ class KvCacheCreator:


def create_py_executor_instance(
        *,
        dist,
        resources,
        mapping,
        pytorch_backend_config,
        executor_config,
        ctx_chunk_config,
        model_engine,
        start_worker,
        sampler,
        drafter,
        guided_decoder: Optional[GuidedDecoder] = None,
        lora_config: Optional[LoraConfig] = None,
        garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
        *,
        dist,
        resources,
        mapping,
        pytorch_backend_config,
        executor_config,
        ctx_chunk_config,
        model_engine,
        start_worker,
        sampler,
        drafter,
        guided_decoder: Optional[GuidedDecoder] = None,
        lora_config: Optional[LoraConfig] = None,
        garbage_collection_gen0_threshold: Optional[int] = None,
        kv_connector_manager: Optional[KvCacheConnectorManager] = None
) -> PyExecutor:
    kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

    spec_config = model_engine.spec_config
@@ -632,7 +658,8 @@ def create_py_executor_instance(
        kv_cache_transceiver=kv_cache_transceiver,
        guided_decoder=guided_decoder,
        start_worker=start_worker,
        garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
        garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
        kv_connector_manager=kv_connector_manager)


def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
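Before the interface module below, a condensed sketch of how a connector implementation is attached to the LLM API, following the pattern from the example above. The module and class names here are placeholders; in the example they point at llm_kv_cache_connector.py's PersistentKvCacheConnector* classes.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig

# Placeholder module/class names for illustration only.
kv_connector_config = KvCacheConnectorConfig(
    connector_module="my_connector_module",
    connector_scheduler_class="MyConnectorScheduler",
    connector_worker_class="MyConnectorWorker",
)

llm = LLM(model="/path/to/model",
          backend="pytorch",
          kv_connector_config=kv_connector_config)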
549
tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
Normal file
549
tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
Normal file
@ -0,0 +1,549 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file contains the primary interface for the KV Cache Connector.
|
||||
|
||||
The KV Cache Connector is a component that allows for remote KV cache access.
|
||||
It is responsible for:
|
||||
- Orchestrating the loading and saving of KV cache blocks.
|
||||
- Managing asynchronous block tx/rx.
|
||||
|
||||
It can be used to provide functionalities such as:
|
||||
1. Disagg
|
||||
2. KV offload/onboard
|
||||
3. KV cache sharing
|
||||
4. P2P KV cache transfer
|
||||
etc.
|
||||
|
||||
The Connector API is split into two parts:
|
||||
1. The scheduler, which is responsible for orchestration, and building metadata for the workers.
|
||||
2. The worker, which performs and monitors transfers indicated by the scheduler's metadata.
|
||||
|
||||
To implement a custom KV connector, you need to implement both the scheduler and worker-side interfaces.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
|
||||
Tuple)
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._utils import mpi_allgather, mpi_broadcast, mpi_rank
|
||||
from tensorrt_llm.bindings import LlmRequestState
|
||||
from tensorrt_llm.bindings.executor import ExecutorConfig
|
||||
from tensorrt_llm.bindings.internal.batch_manager import \
|
||||
KvCacheConnectorManager as KvCacheConnectorManagerCpp
|
||||
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
|
||||
|
||||
from .scheduler import ScheduledRequests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .resource_manager import KVCacheManager
|
||||
|
||||
|
||||
# Used to store data for a single inflight request.
|
||||
@dataclass
|
||||
class RequestData:
|
||||
# The request ID.
|
||||
request_id: int
|
||||
# The new tokens that were generated in the prior forward pass.
|
||||
new_tokens: List[int]
|
||||
# The new block IDs allocated in the prior forward pass.
|
||||
new_block_ids: List[int]
|
||||
# The position of the latest token with computed (valid) kv cache values.
|
||||
computed_position: int
|
||||
|
||||
|
||||
# A class to store some basic data regarding all inflight requests.
|
||||
# This is used when calling `build_connector_meta` on the scheduler.
|
||||
@dataclass
|
||||
class SchedulerOutput:
|
||||
# Requests being scheduled for the first time. Requests will show up in `new_request` exactly once.
|
||||
new_requests: List[RequestData] = field(default_factory=list)
|
||||
|
||||
# Requests being scheduled, that have already shown up in `new_requests`.
|
||||
cached_requests: List[RequestData] = field(default_factory=list)
|
||||
|
||||
|
||||
class KvCacheConnectorWorker(ABC):

    def __init__(self, config: ExecutorConfig):
        self._config = config
        self._metadata = None
        super().__init__()

    def bind_connector_meta(self, metadata: object):
        self._metadata = metadata

    def get_connector_meta(self) -> object:
        return self._metadata

    def _clear_connector_meta(self):
        self._metadata = None

    @abstractmethod
    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        """
        Register the KV cache tensors to the worker.
        This can be used for something like NIXL registration.

        Args:
            kv_cache_tensor: The contiguous KV cache tensor.
        """

    @abstractmethod
    def start_load_kv(self, stream: torch.cuda.Stream):
        """
        Begin loading the KV cache in preparation for the next forward pass.
        Specific blocks to transfer are indicated by the scheduler's metadata.
        """

    @abstractmethod
    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        """
        Wait for a layer to finish being loaded before proceeding with the forward pass on the layer.
        Note: This function is called immediately before the layer's work is enqueued into the stream.

        Args:
            layer_idx: The index of the layer to wait for.
            stream: The stream the forward pass is being executed on.
        """

    @abstractmethod
    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        """
        Begin saving the KV cache for a layer.
        Note: This function is called immediately after the layer's work is enqueued into the stream.

        Args:
            layer_idx: The index of the layer to save.
            stream: The stream the forward pass is being executed on.
        """

    @abstractmethod
    def wait_for_save(self, stream: torch.cuda.Stream):
        """
        Block until all synchronous saving operations are complete. Called at the end of the forward pass.
        """

    @abstractmethod
    def get_finished(
            self, finished_gen_req_ids: List[int],
            started_loading_req_ids: List[int]) -> Tuple[List[int], List[int]]:
        """
        Get the requests that have finished loading and saving.

        Args:
            finished_gen_req_ids: The IDs of the requests that have finished generating tokens, and are now asynchronously saving.
            started_loading_req_ids: The IDs of the requests that have started asynchronously loading.

        Returns:
            The IDs of the requests that have finished saving.
            The IDs of the requests that have finished loading.

        Note: IDs may only be returned from this call after they've been provided in the `finished_gen_req_ids` and `started_loading_req_ids` arguments.
        Additionally, the runtime will only take action based on these returned IDs once they've been returned by ALL workers. This allows some workers to take longer than others to complete the operations.
        """

class KvCacheConnectorScheduler(ABC):

    def __init__(self, executor_config: ExecutorConfig):
        self._config = executor_config
        super().__init__()

    @abstractmethod
    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        """
        Build the metadata for the worker.
        This is called by the KV Cache Manager when adding a sequence.

        Args:
            scheduler_output: The data for all inflight requests.

        Returns:
            The metadata for the workers.
        """

    @abstractmethod
    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> Tuple[int, bool]:
        """
        Get the number of tokens that can be loaded from remote KV cache.
        This does not include the tokens already matched on device (indicated by `num_computed_tokens`).

        Args:
            request: The request to get the number of tokens for.
            num_computed_tokens: The number of tokens already matched on device.

        Returns:
            The number of tokens that can be loaded from remote KV cache.
            Whether the tokens will be loaded asynchronously.
        """

    @abstractmethod
    def request_finished(self, request: LlmRequest,
                         cache_block_ids: List[int]) -> bool:
        """
        Called when a request is finished generating tokens.

        Args:
            request: The request that finished generating tokens.
            cache_block_ids: The KV cache block IDs used by the request.

        Returns:
            Whether the request is performing asynchronous saving operations.
            If true, this indicates that the kv cache manager should wait to deallocate the blocks until the saving has completed (determined by `get_finished` on the workers).
        """

    @abstractmethod
    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: List[int]):
        """
        Called after get_num_new_matched_tokens is called to provide the block ids to the scheduler.

        Args:
            request: The request that was allocated resources.
            block_ids: The KV cache block IDs that were allocated.
        """

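# Illustrative sketch (not part of this commit): a minimal pass-through connector with no
# external KV storage could implement the two interfaces above as follows. The class names
# are placeholders chosen for the example.
class NoOpConnectorWorker(KvCacheConnectorWorker):

    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        # Nothing to register; a real connector might export or pin this tensor here.
        pass

    def start_load_kv(self, stream: torch.cuda.Stream):
        pass

    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):
        pass

    def get_finished(
            self, finished_gen_req_ids: List[int],
            started_loading_req_ids: List[int]) -> Tuple[List[int], List[int]]:
        # Everything completes synchronously, so echo the IDs straight back.
        return finished_gen_req_ids, started_loading_req_ids


class NoOpConnectorScheduler(KvCacheConnectorScheduler):

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        # No metadata is needed when nothing is transferred.
        return None

    def get_num_new_matched_tokens(self, request: LlmRequest,
                                   num_computed_tokens: int) -> Tuple[int, bool]:
        # No remote cache, so no extra tokens can be loaded, and nothing is asynchronous.
        return 0, False

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: List[int]) -> bool:
        # No asynchronous saving; blocks can be freed immediately.
        return False

    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: List[int]):
        pass
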
# An internal dataclass to handle async saving/loading requests.
@dataclass
class AsyncRequests:
    saving: Dict[int, LlmRequest]
    loading: Dict[int, LlmRequest]

    def add_from(self, other: 'AsyncRequests'):
        """
        Remove requests from the other `AsyncRequests` object, and add them to this one.
        """
        self.saving.update(other.saving)
        self.loading.update(other.loading)

        other.saving = dict()
        other.loading = dict()

    def extract_by_id(self, saving_ids: List[int],
                      loading_ids: List[int]) -> 'AsyncRequests':
        """
        Extract the requests with the given IDs from this `AsyncRequests` object.

        Args:
            saving_ids: The IDs of the saving requests to extract.
            loading_ids: The IDs of the loading requests to extract.
        """
        new_async_requests = AsyncRequests(dict(), dict())

        for req_id in saving_ids:
            new_async_requests.saving[req_id] = self.saving[req_id]
            del self.saving[req_id]
        for req_id in loading_ids:
            new_async_requests.loading[req_id] = self.loading[req_id]
            del self.loading[req_id]

        return new_async_requests

    @property
    def saving_ids(self) -> Set[int]:
        """
        Get the IDs of the requests that are being saved asynchronously.
        """
        return set(self.saving.keys())

    @property
    def loading_ids(self) -> Set[int]:
        """
        Get the IDs of the requests that are being loaded asynchronously.
        """
        return set(self.loading.keys())

class KvCacheConnectorSchedulerOutputRequest:

    def __init__(self):
        self.block_ids = []
        self.tokens = []

    def update_and_build_data(self, req: LlmRequest,
                              kv_cache_manager: "KVCacheManager"):
        block_ids = kv_cache_manager.get_cache_indices(req)
        tokens = req.get_tokens(0)

        new_block_ids = block_ids[len(self.block_ids):]
        new_tokens = tokens[len(self.tokens):]

        self.block_ids.extend(new_block_ids)
        self.tokens.extend(new_tokens)

        computed_position = len(
            tokens
        ) - 1 if req.state != LlmRequestState.CONTEXT_INIT and req.state != LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS else req.context_current_position

        return RequestData(req.request_id, new_tokens, new_block_ids,
                           computed_position)


class KvCacheConnectorSchedulerOutputManager:

    def __init__(self):
        self.requests = defaultdict(KvCacheConnectorSchedulerOutputRequest)
        self.external_loads = dict()

    def build_scheduler_output(self, scheduled_batch: ScheduledRequests,
                               new_async_requests: AsyncRequests,
                               kv_cache_manager: "KVCacheManager"):
        scheduler_output = SchedulerOutput()

        for req in scheduled_batch.context_requests:
            if req.request_id in new_async_requests.loading_ids:
                continue

            is_new = req.request_id not in self.requests

            request_data = self.requests[req.request_id].update_and_build_data(
                req, kv_cache_manager)

            # Don't include the connector matched tokens in the initial scheduler output.
            if req.request_id in self.external_loads:
                request_data.computed_position -= self.external_loads[
                    req.request_id]

            if is_new:
                scheduler_output.new_requests.append(request_data)
            else:
                scheduler_output.cached_requests.append(request_data)

        for req in scheduled_batch.generation_requests:
            request_data = self.requests[req.request_id].update_and_build_data(
                req, kv_cache_manager)

            scheduler_output.cached_requests.append(request_data)

        self.external_loads = dict()

        return scheduler_output

    def record_new_matched_tokens(self, request: LlmRequest,
                                  num_new_matched_tokens: int):
        self.external_loads[request.request_id] = num_new_matched_tokens

class KvCacheConnectorManager(KvCacheConnectorManagerCpp):
    """
    The KvCacheConnectorManager is used to manage connector-related state.

    It has the following responsibilities:
    1. Managing the state of async requests (both offload and onboard).
    2. Handling MPI communication. We only run the leader on one rank, but need the results of the leader API on all ranks.

    Note: This class is solely an implementation detail, and is not part of the connector interface itself.
    When implementing a connector API, you do not need to implement this class.
    """

    def __init__(self, worker: KvCacheConnectorWorker,
                 scheduler: Optional[KvCacheConnectorScheduler]):
        assert (scheduler is not None) == (
            mpi_rank() == 0), "The scheduler may only exist on rank 0!"

        super().__init__()

        self.worker = worker
        self.scheduler = scheduler

        # Requests that haven't yet been passed into get_finished.
        self.new_async_requests = AsyncRequests(dict(), dict())

        # Requests that have been passed into get_finished, but haven't yet been returned.
        self.pending_async_requests = AsyncRequests(dict(), dict())

        # Requests that have been returned from get_finished locally, but haven't yet been returned by all workers.
        self.local_finished_async_requests = AsyncRequests(dict(), dict())

        # Requests that have finished loading asynchronously.
        self.finished_async_loading_requests = dict()

        self._scheduler_output = None
        self.scheduler_output_manager = KvCacheConnectorSchedulerOutputManager()

    def _run_on_leader(self, f: Callable[[], Any]) -> Any:
        """
        Run a function on the leader rank, and broadcast the result to all other ranks.
        """
        if self.scheduler is not None:
            assert mpi_rank() == 0, "The scheduler may only exist on rank 0!"
            res = f()
        else:
            res = None
        return mpi_broadcast(res, root=0)

    def get_num_new_matched_tokens(self, request: LlmRequest,
                                   num_computed_tokens: int) -> int:
        num_tokens, load_kv_async = self._run_on_leader(
            lambda: self.scheduler.get_num_new_matched_tokens(
                request, num_computed_tokens))

        if num_tokens == 0 and load_kv_async:
            raise RuntimeError(
                "load_kv_async must be False when num_tokens is 0!")

        # TODO(jthomson04): This part is a bit ugly.
        # When the connector indicates that a request will be loaded asynchronously, we need to suspend its execution.
        # This is problematic, since at the point when this function is called, the request has already been scheduled!
        # Because of this, we need to remove it from our list of scheduled requests (see `take_scheduled_requests_pending_load`).
        if load_kv_async:
            self.new_async_requests.loading[request.request_id] = request

        self.scheduler_output_manager.record_new_matched_tokens(
            request, num_tokens)

        return num_tokens
    def should_add_sequence(self, request: LlmRequest) -> bool:
        req_id = request.request_id
        return req_id not in self.finished_async_loading_requests

    def build_scheduler_output(self, scheduled_batch: ScheduledRequests,
                               kv_cache_manager: "KVCacheManager"):
        self._scheduler_output = self.scheduler_output_manager.build_scheduler_output(
            scheduled_batch, self.new_async_requests, kv_cache_manager)

    def take_scheduled_requests_pending_load(
            self, scheduled_requests: ScheduledRequests):
        """
        Remove context requests from our list of scheduled requests that are being loaded asynchronously.
        This is done to prevent the runtime from attempting to load the KV cache for these requests.

        Args:
            scheduled_requests: The scheduled requests. They are updated in place, with the
                context requests that are being loaded asynchronously removed.
        """
        allowed_context_requests = []

        for req in scheduled_requests.context_requests:
            # If this request is being loaded asynchronously, in addition to removing it from the list of scheduled requests,
            # we also need to update its state.
            if req.request_id in self.new_async_requests.loading.keys():
                req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS

                # Replace the request with the canonical request.
                self.new_async_requests.loading[req.request_id] = req
            else:
                allowed_context_requests.append(req)

        # Update the list of scheduled requests.
        scheduled_requests.context_requests = allowed_context_requests

    def handle_metadata(self) -> object:
        metadata = self._run_on_leader(
            lambda: self.scheduler.build_connector_meta(self._scheduler_output))

        self._scheduler_output = None

        self.worker.bind_connector_meta(metadata)
    def request_finished(self, req: LlmRequest,
                         cache_block_ids: List[int]) -> bool:
        """
        Called when a request is finished generating tokens.

        Args:
            req: The request that finished generating tokens.
            cache_block_ids: The KV cache block IDs used by the request.

        Returns:
            Whether the request is performing asynchronous saving operations. If true, we do not immediately call free_resources on the request.
        """

        if req.request_id in self.finished_async_loading_requests:
            del self.finished_async_loading_requests[req.request_id]

        saving_async = self._run_on_leader(
            lambda: self.scheduler.request_finished(req, cache_block_ids))

        # This is similar to take_scheduled_requests_pending_load.
        # We need to update the request's state to indicate that it's still being used, but isn't schedulable.
        if saving_async:
            self.new_async_requests.saving[req.request_id] = req
            req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS

        return saving_async

    def get_finished(self) -> List[LlmRequest]:
        """
        Process requests that have finished loading and saving.

        Returns:
            The requests that have newly finished saving.
        """
        started_loading_req_ids = list(self.new_async_requests.loading_ids)
        finished_gen_req_ids = list(self.new_async_requests.saving_ids)

        # Add the requests to our list of outstanding (still in progress) requests.
        self.pending_async_requests.add_from(self.new_async_requests)

        # Pass these newly started async requests into get_finished, and get the lists of requests that have finished saving and loading.
        (finished_saving,
         finished_loading) = self.worker.get_finished(finished_gen_req_ids,
                                                      started_loading_req_ids)

        # Remove the requests from our pending list that have finished locally.
        new_local_finished_async_requests = self.pending_async_requests.extract_by_id(
            finished_saving, finished_loading)

        # Add these requests to our list of locally finished requests.
        self.local_finished_async_requests.add_from(
            new_local_finished_async_requests)

        # Allgather this whole list across all workers.
        finished_saving = list(self.local_finished_async_requests.saving_ids)
        finished_loading = list(self.local_finished_async_requests.loading_ids)

        all_results = mpi_allgather((finished_saving, finished_loading))

        # Find only the requests that have been reported complete by all workers.
        intersect_finished_saving = set.intersection(
            *[set(res[0]) for res in all_results])
        intersect_finished_loading = set.intersection(
            *[set(res[1]) for res in all_results])

        # Remove these requests from our list of locally finished requests.
        all_finished = self.local_finished_async_requests.extract_by_id(
            intersect_finished_saving, intersect_finished_loading)

        # For requests that have finished loading, move them back to the context state.
        for id, req in all_finished.loading.items():
            req.state = LlmRequestState.CONTEXT_INIT
            self.finished_async_loading_requests[id] = req

        # Return the requests that have finished saving.
        # The execution loop will call _terminate_request on these requests.
        return list(all_finished.saving.values())

    def update_state_after_alloc(self, req: LlmRequest, block_ids: List[int]):
        if self.scheduler is not None:
            self.scheduler.update_state_after_alloc(req, block_ids)

    def set_scheduler_output(self, scheduler_output: SchedulerOutput):
        self._scheduler_output = scheduler_output

    def layer_pre_hook(self, module, *args):
        self.worker.wait_for_layer_load(module.layer_idx,
                                        torch.cuda.current_stream())

    def layer_post_hook(self, module, *args):
        self.worker.save_kv_layer(module.layer_idx, torch.cuda.current_stream())

@ -36,10 +36,12 @@ from tensorrt_llm.runtime.generation import CUASSERT
|
||||
|
||||
from ..distributed import Distributed
|
||||
from ..models.modeling_utils import DecoderModelForCausalLM
|
||||
from ..modules.decoder_layer import DecoderLayer
|
||||
from ..speculative.drafter import Drafter
|
||||
from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
|
||||
from .guided_decoder import GuidedDecoder
|
||||
from .handle_logits import HandleLogits
|
||||
from .kv_cache_connector import KvCacheConnectorManager
|
||||
from .kv_cache_transceiver import KvCacheTransceiver
|
||||
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
|
||||
LlmResponse, get_draft_token_length)
|
||||
@ -137,23 +139,25 @@ class BatchStatePP(BatchState):
|
||||
|
||||
class PyExecutor:
|
||||
|
||||
def __init__(self,
|
||||
resource_manager,
|
||||
scheduler: RequestScheduler,
|
||||
model_engine: ModelEngine,
|
||||
sampler: Sampler,
|
||||
dist: Distributed,
|
||||
max_num_sequences: int,
|
||||
drafter: Optional[Drafter] = None,
|
||||
disable_overlap_scheduler: bool = False,
|
||||
max_input_len: int = 2048,
|
||||
max_batch_size: int = 8,
|
||||
max_beam_width: int = 1,
|
||||
max_draft_len: int = 0,
|
||||
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
|
||||
guided_decoder: Optional[GuidedDecoder] = None,
|
||||
garbage_collection_gen0_threshold: Optional[int] = None,
|
||||
start_worker: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
resource_manager,
|
||||
scheduler: RequestScheduler,
|
||||
model_engine: ModelEngine,
|
||||
sampler: Sampler,
|
||||
dist: Distributed,
|
||||
max_num_sequences: int,
|
||||
drafter: Optional[Drafter] = None,
|
||||
disable_overlap_scheduler: bool = False,
|
||||
max_input_len: int = 2048,
|
||||
max_batch_size: int = 8,
|
||||
max_beam_width: int = 1,
|
||||
max_draft_len: int = 0,
|
||||
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
|
||||
guided_decoder: Optional[GuidedDecoder] = None,
|
||||
garbage_collection_gen0_threshold: Optional[int] = None,
|
||||
start_worker: bool = True,
|
||||
kv_connector_manager: Optional[KvCacheConnectorManager] = None):
|
||||
super(PyExecutor, self).__init__()
|
||||
self.device_id = torch.cuda.current_device()
|
||||
self.global_rank = global_mpi_rank()
|
||||
@ -270,9 +274,42 @@ class PyExecutor:

        self.worker_started = False
        self.worker_lock = threading.Lock()

        self.kv_connector_manager = kv_connector_manager

        self._maybe_init_kv_connector_manager()

        if start_worker:
            self.start_worker()

    def _maybe_init_kv_connector_manager(self):
        if self.kv_connector_manager is not None:
            if self.kv_cache_transceiver is not None:
                raise NotImplementedError(
                    "KV Cache Connector is not supported with KvCacheTransceiver."
                )

            if self.dist.pp_size > 1:
                raise NotImplementedError(
                    "KV Cache Connector is not supported with pipeline parallelism."
                )

            if self.kv_cache_manager is None:
                raise ValueError(
                    "KV Cache Connector requires a KV Cache Manager.")

            kv_tensor = self.kv_cache_manager.get_unique_primary_pool()
            self.kv_connector_manager.worker.register_kv_caches(kv_tensor)

            # For each of our layers, we need to register the pre/post hooks.
            # These are used for methods like `wait_for_layer_load` and `save_kv_layer`.
            for _name, module in self.model_engine.model.named_modules():
                if isinstance(module, DecoderLayer):
                    module.register_forward_pre_hook(
                        self.kv_connector_manager.layer_pre_hook)
                    module.register_forward_hook(
                        self.kv_connector_manager.layer_post_hook)

def _event_loop_wrapper(self):
|
||||
try:
|
||||
with customized_gc_thresholds(
|
||||
@ -920,6 +957,25 @@ class PyExecutor:
|
||||
self.guided_decoder.build(scheduled_batch)
|
||||
self.guided_decoder.execute(scheduled_batch, logits)
|
||||
|
||||
def _kv_connector_start_batch(self, scheduled_batch):
|
||||
if self.kv_connector_manager:
|
||||
self.kv_connector_manager.take_scheduled_requests_pending_load(
|
||||
scheduled_batch)
|
||||
self.kv_connector_manager.handle_metadata()
|
||||
self.kv_connector_manager.worker.start_load_kv(
|
||||
torch.cuda.current_stream())
|
||||
|
||||
def _kv_connector_terminate_requests(self):
|
||||
if self.kv_connector_manager:
|
||||
reqs_to_terminate = self.kv_connector_manager.get_finished()
|
||||
for req in reqs_to_terminate:
|
||||
self.resource_manager.free_resources(req)
|
||||
|
||||
def _kv_connector_wait_for_save(self):
|
||||
if self.kv_connector_manager is not None:
|
||||
self.kv_connector_manager.worker.wait_for_save(
|
||||
torch.cuda.current_stream())
|
||||
|
||||
def _executor_loop(self):
|
||||
torch.cuda.set_device(self.device_id)
|
||||
# ensure the context is created, otherwise, some MPI calls will fail.
|
||||
@ -950,12 +1006,17 @@ class PyExecutor:
|
||||
|
||||
# Return the first token to the client
|
||||
self._handle_first_token_response(scheduled_batch)
|
||||
|
||||
self.resource_manager.prepare_resources(scheduled_batch)
|
||||
|
||||
if self.kv_cache_transceiver and self.guided_decoder:
|
||||
self.guided_decoder.init_disagg_gen_requests(
|
||||
scheduled_batch)
|
||||
|
||||
self._kv_connector_start_batch(scheduled_batch)
|
||||
|
||||
if scheduled_batch.batch_size > 0 or (
|
||||
self.enable_attention_dp and self.dist.tp_size > 1):
|
||||
|
||||
if self.drafter is not None and self.use_spec_decode:
|
||||
with request_context(
|
||||
is_draft=True,
|
||||
@ -1001,6 +1062,8 @@ class PyExecutor:
|
||||
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
|
||||
self._terminate_ctx_finished_requests()
|
||||
|
||||
self._kv_connector_terminate_requests()
|
||||
|
||||
if self.enable_iter_perf_stats:
|
||||
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
|
||||
'num_ctx_tokens']
|
||||
@ -1067,9 +1130,12 @@ class PyExecutor:
|
||||
# For generation requests which have completed KV cache transfer
|
||||
self._prepare_disagg_gen_transmission_complete(
|
||||
scheduled_batch)
|
||||
|
||||
self.resource_manager.prepare_resources(scheduled_batch)
|
||||
|
||||
self._kv_connector_start_batch(scheduled_batch)
|
||||
|
||||
if scheduled_batch.batch_size > 0:
|
||||
|
||||
# The generation requests that do not have batch_idx,
|
||||
# needs to be in front of the batch due to the assumptions
|
||||
# made in model_engine.py::_forward_step. This is only important
|
||||
@ -1125,6 +1191,8 @@ class PyExecutor:
|
||||
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
|
||||
self._terminate_ctx_finished_requests()
|
||||
|
||||
self._kv_connector_terminate_requests()
|
||||
|
||||
def _process_previous_batch(self):
|
||||
if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
|
||||
for req in self.previous_batch.ctx_transmission_reqs:
|
||||
@ -1456,6 +1524,9 @@ class PyExecutor:
|
||||
outputs = forward(scheduled_requests, self.resource_manager,
|
||||
new_tensors_device, gather_context_logits,
|
||||
cache_indirection_buffer)
|
||||
|
||||
self._kv_connector_wait_for_save()
|
||||
|
||||
return outputs
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
@ -1579,7 +1650,21 @@ class PyExecutor:
|
||||
self._enqueue_responses(error_responses.items())
|
||||
|
||||
def _terminate_request(self, request: LlmRequest):
|
||||
self.resource_manager.free_resources(request)
|
||||
if self.kv_connector_manager is None:
|
||||
self.resource_manager.free_resources(request)
|
||||
else:
|
||||
# Only call request_finished on the connector if the request has already been added to the kv cache manager.
|
||||
try:
|
||||
cache_block_ids = self.kv_cache_manager.get_cache_indices(
|
||||
request)
|
||||
except IndexError:
|
||||
# If the request has not yet been added to the kv cache manager,
|
||||
# we still need to free resources corresponding to other resource managers.
|
||||
self.resource_manager.free_resources(request)
|
||||
else:
|
||||
if not self.kv_connector_manager.request_finished(
|
||||
request, cache_block_ids):
|
||||
self.resource_manager.free_resources(request)
|
||||
|
||||
@nvtx_range("_handle_canceled_requests")
|
||||
def _handle_canceled_requests(self):
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import copy
|
||||
import enum
|
||||
import importlib
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
@ -10,8 +12,11 @@ import torch
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
|
||||
from tensorrt_llm._utils import get_sm_version
|
||||
from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
|
||||
from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
|
||||
ContextChunkingPolicy,
|
||||
ExecutorConfig)
|
||||
from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
|
||||
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.lora_helper import LoraConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
@ -26,6 +31,7 @@ from ._util import (KvCacheCreator, _adjust_torch_mem_fraction,
|
||||
from .config import LoadFormat, PyTorchConfig
|
||||
from .config_utils import is_mla
|
||||
from .guided_decoder import GuidedDecoder
|
||||
from .kv_cache_connector import KvCacheConnectorManager
|
||||
from .model_engine import PyTorchModelEngine
|
||||
from .py_executor import PyExecutor
|
||||
|
||||
@ -206,7 +212,9 @@ def create_py_executor(
|
||||
executor_config: ExecutorConfig,
|
||||
checkpoint_dir: str = None,
|
||||
lora_config: Optional[LoraConfig] = None,
|
||||
garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
|
||||
garbage_collection_gen0_threshold: Optional[int] = None,
|
||||
kv_connector_config: Optional[KvCacheConnectorConfig] = None
|
||||
) -> PyExecutor:
|
||||
_mangle_executor_config(executor_config)
|
||||
pytorch_backend_config = executor_config.pytorch_backend_config
|
||||
|
||||
@ -375,21 +383,70 @@ def create_py_executor(
|
||||
pytorch_backend_config, mapping)
|
||||
logger.info(f"Using Sampler: {type(sampler).__name__}")
|
||||
|
||||
if kv_connector_config is not None:
|
||||
logger.info(
|
||||
f"Initializing kv connector with config: {kv_connector_config}")
|
||||
|
||||
if pytorch_backend_config.use_cuda_graph:
|
||||
raise NotImplementedError(
|
||||
"CUDA graphs are not supported with KV connector hooks.")
|
||||
|
||||
if executor_config.scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
|
||||
raise NotImplementedError(
|
||||
"KV connector is only supported with guaranteed no evict scheduler policy."
|
||||
)
|
||||
|
||||
try:
|
||||
module = importlib.import_module(
|
||||
kv_connector_config.connector_module)
|
||||
worker_cls = getattr(module,
|
||||
kv_connector_config.connector_worker_class)
|
||||
scheduler_cls = getattr(
|
||||
module, kv_connector_config.connector_scheduler_class)
|
||||
|
||||
rank = tensorrt_llm.mpi_rank()
|
||||
# Some connector API implementations may need to establish out-of-band communication between the scheduler and workers.
|
||||
# In this case, the worker may be dependent on the scheduler, or vice-versa.
|
||||
# To deal with cases like this, we instantiate them both concurrently.
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
connector_worker_task = executor.submit(worker_cls,
|
||||
executor_config)
|
||||
|
||||
if scheduler_cls is not None and rank == 0:
|
||||
connector_scheduler_task = executor.submit(
|
||||
scheduler_cls, executor_config)
|
||||
connector_scheduler = connector_scheduler_task.result()
|
||||
else:
|
||||
connector_scheduler = None
|
||||
|
||||
connector_worker = connector_worker_task.result()
|
||||
|
||||
kv_connector_manager = KvCacheConnectorManager(
|
||||
connector_worker, connector_scheduler)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error instantiating connector: {e}")
|
||||
raise e
|
||||
else:
|
||||
kv_connector_manager = None
|
||||
|
||||
resources = {}
|
||||
estimating_kv_cache = False
|
||||
kv_cache_creator = None
|
||||
if model_engine.model.model_config.is_generation:
|
||||
#NOTE: non-generation models do not have kv cache
|
||||
kv_cache_creator = KvCacheCreator(executor_config=executor_config,
|
||||
model_engine=model_engine,
|
||||
draft_model_engine=draft_model_engine,
|
||||
mapping=mapping,
|
||||
net_max_seq_len=net_max_seq_len)
|
||||
kv_cache_creator = KvCacheCreator(
|
||||
executor_config=executor_config,
|
||||
model_engine=model_engine,
|
||||
draft_model_engine=draft_model_engine,
|
||||
mapping=mapping,
|
||||
net_max_seq_len=net_max_seq_len,
|
||||
kv_connector_manager=kv_connector_manager)
|
||||
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
|
||||
with mem_monitor.observe_creation_stage(
|
||||
_ExecutorCreationStage.INIT_KV_CACHE
|
||||
if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
|
||||
kv_cache_creator.build_managers(resources)
|
||||
kv_cache_creator.build_managers(resources, estimating_kv_cache)
|
||||
|
||||
# Resource managers for speculative decoding
|
||||
# For user-specified drafters, use extra_resource_managers in PyTorchBackend config
|
||||
@ -425,6 +482,8 @@ def create_py_executor(
|
||||
guided_decoder=guided_decoder,
|
||||
lora_config=lora_config,
|
||||
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
|
||||
kv_connector_manager=kv_connector_manager
|
||||
if not estimating_kv_cache else None,
|
||||
)
|
||||
|
||||
if estimating_kv_cache:
|
||||
@ -441,7 +500,7 @@ def create_py_executor(
|
||||
# create_kv_cache_manager above, which caps executor_config.max_seq_len. Restoring
|
||||
# the original value before creating the final KV cache.
|
||||
executor_config.max_seq_len = max_seq_len
|
||||
kv_cache_creator.build_managers(resources)
|
||||
kv_cache_creator.build_managers(resources, False)
|
||||
|
||||
for eng in [model_engine, draft_model_engine]:
|
||||
if eng is None:
|
||||
@ -468,6 +527,7 @@ def create_py_executor(
|
||||
lora_config=lora_config,
|
||||
garbage_collection_gen0_threshold=
|
||||
garbage_collection_gen0_threshold,
|
||||
kv_connector_manager=kv_connector_manager,
|
||||
)
|
||||
|
||||
_adjust_torch_mem_fraction(executor_config.pytorch_backend_config)
|
||||
|
||||
@ -17,6 +17,7 @@ from tensorrt_llm.sampling_params import SamplingParams
|
||||
from ..._utils import binding_dtype_size, binding_to_str_dtype, nvtx_range
|
||||
from ...logger import logger
|
||||
from ...mapping import CpType, Mapping
|
||||
from .kv_cache_connector import KvCacheConnectorManager
|
||||
from .llm_request import (LlmRequest, LlmRequestState, SamplingConfig,
|
||||
get_draft_token_length)
|
||||
from .scheduler import ScheduledRequests
|
||||
@ -161,6 +162,7 @@ class KVCacheManager(BaseResourceManager):
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
max_beam_width: int = 1,
|
||||
is_draft: bool = False,
|
||||
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
|
||||
) -> None:
|
||||
self.mapping = mapping
|
||||
self.dtype = dtype
|
||||
@ -178,6 +180,8 @@ class KVCacheManager(BaseResourceManager):
|
||||
for offset, idx in enumerate(self.pp_layers)
|
||||
}
|
||||
|
||||
self.kv_connector_manager = kv_connector_manager
|
||||
|
||||
tp_size = mapping.tp_size
|
||||
if mapping.enable_attention_dp:
|
||||
tp_size = 1
|
||||
@ -329,6 +333,7 @@ class KVCacheManager(BaseResourceManager):
|
||||
'cache_type': kv_cache_type,
|
||||
'enable_partial_reuse': kv_cache_config.enable_partial_reuse,
|
||||
'copy_on_partial_reuse': kv_cache_config.copy_on_partial_reuse,
|
||||
'kv_connector_manager': self.kv_connector_manager,
|
||||
}
|
||||
if self.event_buffer_max_size > 0:
|
||||
if mapping.enable_attention_dp:
|
||||
@ -413,7 +418,8 @@ class KVCacheManager(BaseResourceManager):
|
||||
== self.mapping.cp_size - 1 else 0),
|
||||
req_beam_width, req)
|
||||
else:
|
||||
if req.is_first_context_chunk:
|
||||
if req.is_first_context_chunk and self._kv_connector_should_add_sequence(
|
||||
req):
|
||||
self.impl.add_sequence(req.py_request_id,
|
||||
req.prompt_len, req_beam_width,
|
||||
req)
|
||||
@ -422,11 +428,24 @@ class KVCacheManager(BaseResourceManager):
|
||||
for _ in range(get_draft_token_length(req)):
|
||||
self.impl.add_token(req.py_request_id)
|
||||
|
||||
if self.kv_connector_manager is not None:
|
||||
block_ids = self.get_cache_indices(req)
|
||||
self.kv_connector_manager.update_state_after_alloc(
|
||||
req, block_ids)
|
||||
|
||||
for req in generation_batch:
|
||||
self.impl.add_token(req.py_request_id)
|
||||
for _ in range(get_draft_token_length(req)):
|
||||
self.impl.add_token(req.py_request_id)
|
||||
|
||||
if self.kv_connector_manager is not None:
|
||||
self.kv_connector_manager.build_scheduler_output(
|
||||
scheduled_batch, self)
|
||||
|
||||
def _kv_connector_should_add_sequence(self, request: LlmRequest) -> bool:
|
||||
return self.kv_connector_manager is None or self.kv_connector_manager.should_add_sequence(
|
||||
request)
|
||||
|
||||
def add_dummy_requests(
|
||||
self,
|
||||
request_ids: List[int],
|
||||
@ -626,6 +645,9 @@ class KVCacheManager(BaseResourceManager):
|
||||
self.head_dim,
|
||||
)
|
||||
|
||||
def get_unique_primary_pool(self) -> torch.Tensor:
|
||||
return self.impl.get_unique_primary_pool()
|
||||
|
||||
def get_block_ids_per_seq(self, request_ids: List[int]) -> torch.Tensor:
|
||||
block_ids_per_seq = self.get_batch_cache_indices(request_ids)
|
||||
block_ids_per_seq_tensors = [
|
||||
|
||||
@ -21,7 +21,7 @@ from .._utils import mpi_world_size
|
||||
from ..bindings import executor as tllm
|
||||
from ..builder import Engine
|
||||
from ..disaggregated_params import DisaggregatedParams
|
||||
from ..llmapi.llm_args import TorchLlmArgs
|
||||
from ..llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
|
||||
from ..llmapi.llm_utils import KvCacheRetentionConfig
|
||||
from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
|
||||
need_spawn_mpi_workers)
|
||||
@ -356,6 +356,7 @@ class GenerationExecutor(ABC):
|
||||
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
|
||||
is_llm_executor: Optional[bool] = None,
|
||||
lora_config: Optional[LoraConfig] = None,
|
||||
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
|
||||
hf_model_dir: Optional[Path] = None,
|
||||
tokenizer: Optional[TokenizerBase] = None,
|
||||
llm_args: Optional[TorchLlmArgs] = None,
|
||||
@ -405,7 +406,8 @@ class GenerationExecutor(ABC):
|
||||
model_world_size=model_world_size,
|
||||
mpi_session=mpi_session,
|
||||
postproc_worker_config=postproc_worker_config,
|
||||
is_llm_executor=is_llm_executor)
|
||||
is_llm_executor=is_llm_executor,
|
||||
kv_connector_config=kv_connector_config)
|
||||
|
||||
# WAR: For the performance of gathering logits, we use single process worker
|
||||
# for TP1 to avoid the large overhead of IPC.
|
||||
@ -415,8 +417,10 @@ class GenerationExecutor(ABC):
|
||||
logger.warning(
|
||||
"Using single process worker for TP1, this may hurt streaming generation performance."
|
||||
)
|
||||
return GenerationExecutorWorker(**worker_kwargs,
|
||||
is_llm_executor=is_llm_executor)
|
||||
return GenerationExecutorWorker(
|
||||
**worker_kwargs,
|
||||
is_llm_executor=is_llm_executor,
|
||||
kv_connector_config=kv_connector_config)
|
||||
|
||||
# For single-gpu case:
|
||||
# Partition the workload to multiple process for streaming performance.
|
||||
@ -428,7 +432,8 @@ class GenerationExecutor(ABC):
|
||||
model_world_size=model_world_size,
|
||||
mpi_session=None, # use mpi4py
|
||||
postproc_worker_config=postproc_worker_config,
|
||||
is_llm_executor=is_llm_executor)
|
||||
is_llm_executor=is_llm_executor,
|
||||
kv_connector_config=kv_connector_config)
|
||||
else:
|
||||
ctx = multiprocessing.get_context("spawn")
|
||||
# The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
|
||||
@ -439,7 +444,8 @@ class GenerationExecutor(ABC):
|
||||
model_world_size=model_world_size,
|
||||
mpi_session=mpi_session,
|
||||
postproc_worker_config=postproc_worker_config,
|
||||
is_llm_executor=is_llm_executor)
|
||||
is_llm_executor=is_llm_executor,
|
||||
kv_connector_config=kv_connector_config)
|
||||
|
||||
def wait_first_completed(
|
||||
self, futures: List[GenerationResult]
|
||||
|
||||
@ -12,6 +12,7 @@ import zmq.asyncio
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
from .._utils import customized_gc_thresholds, mpi_rank, nvtx_range_debug
|
||||
from ..llmapi.llm_args import KvCacheConnectorConfig
|
||||
from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
|
||||
RemoteMpiCommSessionClient)
|
||||
from ..llmapi.tracer import enable_llm_tracer, get_tracer, global_tracer
|
||||
@ -45,6 +46,7 @@ class GenerationExecutorProxy(GenerationExecutor):
|
||||
worker_cls: type = GenerationExecutorWorker,
|
||||
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
|
||||
is_llm_executor: Optional[bool] = None,
|
||||
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
|
||||
) -> None:
|
||||
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
|
||||
)
|
||||
@ -93,7 +95,8 @@ class GenerationExecutorProxy(GenerationExecutor):
|
||||
worker_kwargs = dict(**worker_kwargs,
|
||||
worker_queues=self._setup_queues(),
|
||||
postproc_worker_config=postproc_worker_config,
|
||||
is_llm_executor=False)
|
||||
is_llm_executor=False,
|
||||
kv_connector_config=kv_connector_config)
|
||||
|
||||
if "log_level" not in worker_kwargs:
|
||||
worker_kwargs["log_level"] = logger.level
|
||||
|
||||
@ -18,7 +18,7 @@ from .._utils import (KVCacheEventSerializer, global_mpi_rank, global_mpi_size,
|
||||
mpi_comm, mpi_rank, nvtx_range_debug)
|
||||
from ..bindings import executor as tllm
|
||||
from ..builder import ConfigEncoder, Engine, EngineConfig
|
||||
from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
|
||||
from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror, TorchLlmArgs
|
||||
from ..llmapi.mpi_session import set_mpi_session_cpp
|
||||
from ..llmapi.tokenizer import TokenizerBase
|
||||
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
|
||||
@ -61,6 +61,7 @@ class GenerationExecutorWorker(GenerationExecutor):
|
||||
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
|
||||
is_llm_executor: Optional[bool] = None,
|
||||
lora_config: Optional[LoraConfig] = None,
|
||||
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
|
||||
hf_model_dir: Optional[Path] = None,
|
||||
tokenizer: Optional[TokenizerBase] = None,
|
||||
llm_args: Optional[TorchLlmArgs] = None,
|
||||
@ -87,6 +88,10 @@ class GenerationExecutorWorker(GenerationExecutor):
|
||||
self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"
|
||||
self.llm_args = llm_args
|
||||
|
||||
if not self._is_pytorch_backend and kv_connector_config is not None:
|
||||
raise ValueError(
|
||||
"KV connector config is only supported for PyTorch backend")
|
||||
|
||||
if global_mpi_size() > 1:
|
||||
logger.set_rank(self.global_rank)
|
||||
|
||||
@ -127,6 +132,7 @@ class GenerationExecutorWorker(GenerationExecutor):
|
||||
args["lora_config"] = lora_config
|
||||
args[
|
||||
"garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
|
||||
args["kv_connector_config"] = kv_connector_config
|
||||
elif executor_config.backend == "_autodeploy":
|
||||
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
|
||||
create_autodeploy_executor
|
||||
@ -680,6 +686,7 @@ def worker_main(
|
||||
is_llm_executor: Optional[
|
||||
bool] = True, # whether it's the main executor instance
|
||||
lora_config: Optional[LoraConfig] = None,
|
||||
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
|
||||
hf_model_dir: Optional[Path] = None,
|
||||
tokenizer: Optional[TokenizerBase] = None,
|
||||
llm_args: Optional[TorchLlmArgs] = None,
|
||||
@ -809,6 +816,7 @@ def worker_main(
|
||||
postproc_worker_config=postproc_worker_config,
|
||||
is_llm_executor=is_llm_executor,
|
||||
lora_config=lora_config,
|
||||
kv_connector_config=kv_connector_config,
|
||||
hf_model_dir=hf_model_dir,
|
||||
tokenizer=tokenizer,
|
||||
llm_args=llm_args)
|
||||
|
||||
@ -983,6 +983,8 @@ class _TorchLLM(BaseLLM):
|
||||
),
|
||||
is_llm_executor=True,
|
||||
lora_config=self.args.lora_config,
|
||||
# Autodeploy does not support kv_connector_config
|
||||
kv_connector_config=getattr(self.args, "kv_connector_config", None),
|
||||
hf_model_dir=self._hf_model_dir,
|
||||
tokenizer=self.tokenizer,
|
||||
llm_args=self.args)
|
||||
|
||||
@ -404,6 +404,21 @@ class DecodingBaseConfig(StrictBaseModel):
|
||||
self.decoding_type.upper())
|
||||
|
||||
|
||||
class KvCacheConnectorConfig(StrictBaseModel):
    """
    Configuration for the KV Cache Connector.
    """
    connector_module: str = Field(
        ...,
        description=
        "The import path to the connector module. It will be imported with `importlib.import_module`."
    )
    connector_scheduler_class: str = Field(
        ..., description="The class name of the scheduler within the module.")
    connector_worker_class: str = Field(
        ..., description="The class name of the worker within the module.")

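# Illustrative usage sketch (not part of this file): the connector is enabled by pointing
# this config at a module that defines the worker/scheduler classes. The module path and
# class names below are placeholders; see the connector integration tests for a mocked
# end-to-end example.
#
#   from tensorrt_llm import LLM
#   from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
#
#   kv_connector_config = KvCacheConnectorConfig(
#       connector_module="my_project.kv_connector",
#       connector_scheduler_class="MyConnectorScheduler",
#       connector_worker_class="MyConnectorWorker",
#   )
#   llm = LLM(model="/path/to/model",
#             backend="pytorch",
#             kv_connector_config=kv_connector_config)
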
class MedusaDecodingConfig(DecodingBaseConfig):
|
||||
medusa_choices: Optional[List[List[int]]] = None
|
||||
num_medusa_heads: Optional[int] = None
|
||||
@ -2302,6 +2317,11 @@ class TorchLlmArgs(BaseLlmArgs):
|
||||
status="prototype",
|
||||
)
|
||||
|
||||
kv_connector_config: Optional[KvCacheConnectorConfig] = Field(
|
||||
default=None,
|
||||
description="The config for KV cache connector.",
|
||||
)
|
||||
|
||||
mm_encoder_only: bool = Field(
|
||||
default=False,
|
||||
description=
|
||||
|
||||
348
tests/integration/defs/llmapi/test_llm_api_connector.py
Normal file
@ -0,0 +1,348 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm.llmapi.llm_args import KvCacheConfig, KvCacheConnectorConfig
|
||||
|
||||
from ..conftest import llm_models_root
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def model_with_connector():
|
||||
with patch("tensorrt_llm._torch.pyexecutor.py_executor_creator.importlib"
|
||||
) as importlib_mock:
|
||||
mock_scheduler = MagicMock()
|
||||
mock_worker = MagicMock()
|
||||
|
||||
importlib_mock.import_module.return_value.KvConnectorScheduler.return_value = mock_scheduler
|
||||
importlib_mock.import_module.return_value.KvConnectorWorker.return_value = mock_worker
|
||||
|
||||
kv_connector_config = KvCacheConnectorConfig(
|
||||
connector_module="",
|
||||
connector_scheduler_class="KvConnectorScheduler",
|
||||
connector_worker_class="KvConnectorWorker",
|
||||
)
|
||||
|
||||
def model_fn(*args, **kwargs):
|
||||
return LLM(
|
||||
*args,
|
||||
**kwargs,
|
||||
model=f"{llm_models_root()}/Qwen2-0.5B",
|
||||
backend="pytorch",
|
||||
kv_connector_config=kv_connector_config,
|
||||
cuda_graph_config=None,
|
||||
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.1),
|
||||
)
|
||||
|
||||
yield model_fn, mock_scheduler, mock_worker
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def enforce_single_worker(monkeypatch):
|
||||
monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
|
||||
def test_connector_simple(enforce_single_worker, model_with_connector,
|
||||
use_overlap_scheduler):
|
||||
NUM_TOKENS = 8
|
||||
|
||||
model_fn, scheduler, worker = model_with_connector
|
||||
|
||||
model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, )
|
||||
|
||||
assert worker.register_kv_caches.call_count == 1
|
||||
|
||||
scheduler.get_num_new_matched_tokens.return_value = 0, False
|
||||
|
||||
worker.get_finished.return_value = [], []
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True)
|
||||
|
||||
model.generate(["Hello, world"], sampling_params)
|
||||
|
||||
assert scheduler.update_state_after_alloc.call_count == 1
|
||||
|
||||
# Allocate 1 block.
|
||||
assert len(scheduler.update_state_after_alloc.call_args.args[1]) == 1
|
||||
|
||||
# With the overlap scheduler, we generate one extra token.
|
||||
assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
|
||||
use_overlap_scheduler)
|
||||
|
||||
# We should have a single `SchedulerOutput` per forward pass.
|
||||
for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
|
||||
scheduler_output = call[0][0]
|
||||
if i == 0:
|
||||
assert len(scheduler_output.new_requests) == 1
|
||||
assert len(scheduler_output.cached_requests) == 0
|
||||
elif i == 1 and use_overlap_scheduler:
|
||||
assert len(scheduler_output.new_requests) == 0
|
||||
assert len(scheduler_output.cached_requests) == 1
|
||||
|
||||
assert len(scheduler_output.cached_requests[0].new_tokens) == 0
|
||||
else:
|
||||
assert len(scheduler_output.new_requests) == 0
|
||||
assert len(scheduler_output.cached_requests) == 1
|
||||
|
||||
assert len(scheduler_output.cached_requests[0].new_tokens) == 1
|
||||
|
||||
# We call `start_load_kv` once at the beginning of each forward pass.
|
||||
assert worker.start_load_kv.call_count == NUM_TOKENS + int(
|
||||
use_overlap_scheduler)
|
||||
|
||||
# Only called once when the request is received.
|
||||
assert scheduler.get_num_new_matched_tokens.call_count == 1
|
||||
|
||||
num_layers = max(call.args[0]
|
||||
for call in worker.wait_for_layer_load.call_args_list) + 1
|
||||
|
||||
# Called num_layers * num_forward_passes times.
|
||||
assert worker.wait_for_layer_load.call_count == num_layers * (
|
||||
NUM_TOKENS + int(use_overlap_scheduler))
|
||||
assert worker.save_kv_layer.call_count == num_layers * (
|
||||
NUM_TOKENS + int(use_overlap_scheduler))
|
||||
|
||||
for i, call in enumerate(worker.wait_for_layer_load.call_args_list):
|
||||
assert call.args[0] == i % num_layers
|
||||
|
||||
for i, call in enumerate(worker.save_kv_layer.call_args_list):
|
||||
assert call.args[0] == i % num_layers
|
||||
|
||||
assert worker.wait_for_save.call_count == NUM_TOKENS + int(
|
||||
use_overlap_scheduler)
|
||||
|
||||
assert scheduler.request_finished.call_count == 1
|
||||
|
||||
assert len(scheduler.request_finished.call_args.args[1]) == 1
|
||||
|
||||
assert worker.get_finished.call_count == NUM_TOKENS + int(
|
||||
use_overlap_scheduler)
|
||||
|
||||
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
|
||||
def test_connector_async_onboard(enforce_single_worker, model_with_connector,
|
||||
use_overlap_scheduler):
|
||||
NUM_TOKENS = 8
|
||||
|
||||
model_fn, scheduler, worker = model_with_connector
|
||||
|
||||
model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, )
|
||||
|
||||
assert worker.register_kv_caches.call_count == 1
|
||||
|
||||
scheduler.get_num_new_matched_tokens.return_value = 16, True
|
||||
|
||||
worker.get_finished.side_effect = lambda finished_gen, load_async: (
|
||||
finished_gen, load_async)
|
||||
|
||||
model.generate([
|
||||
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
|
||||
], SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True))
|
||||
|
||||
# Once for the initial poll, then once for each token. One extra token when using the overlap scheduler.
|
||||
assert worker.get_finished.call_count == NUM_TOKENS + 1 + int(
|
||||
use_overlap_scheduler)
|
||||
|
||||
# In the first iteration, there should be a single request id provided.
|
||||
assert len(worker.get_finished.call_args_list[0].args[1]) == 1
|
||||
|
||||
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
|
||||
def test_connector_async_save(enforce_single_worker, model_with_connector,
|
||||
use_overlap_scheduler):
|
||||
NUM_TOKENS = 8
|
||||
|
||||
model_fn, scheduler, worker = model_with_connector
|
||||
|
||||
model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, )
|
||||
|
||||
assert worker.register_kv_caches.call_count == 1
|
||||
|
||||
scheduler.get_num_new_matched_tokens.return_value = 0, False
|
||||
|
||||
scheduler.request_finished.return_value = True
|
||||
|
||||
worker.get_finished.side_effect = lambda finished_gen, load_async: (
|
||||
finished_gen, load_async)
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=NUM_TOKENS, ignore_eos=True)
|
||||
|
||||
model.generate(["Hello, world"], sampling_params)
|
||||
|
||||
assert scheduler.request_finished.call_count == 1
|
||||
|
||||
assert len(scheduler.request_finished.call_args.args[1]) == 1
|
||||
|
||||
# On the last call to get_finished, we should be providing the async saving request. One extra token when using the overlap scheduler.
|
||||
assert worker.get_finished.call_count == NUM_TOKENS + int(
|
||||
use_overlap_scheduler)
|
||||
|
||||
for i, call in enumerate(worker.get_finished.call_args_list):
|
||||
args = call.args
|
||||
if i != len(worker.get_finished.call_args_list) - 1:
|
||||
assert args == ([], [])
|
||||
else:
|
||||
assert len(args[0]) == 1
|
||||
assert args[0][0] == scheduler.request_finished.call_args.args[
|
||||
0].request_id
|
||||
|
||||
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
|
||||
def test_connector_scheduler_output(enforce_single_worker, model_with_connector,
|
||||
use_overlap_scheduler):
|
||||
NUM_INPUT_TOKENS = 48
|
||||
NUM_TOKENS = 32
|
||||
BLOCK_SIZE = 32
|
||||
|
||||
model_fn, scheduler, worker = model_with_connector
|
||||
|
||||
model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler, )
|
||||
|
||||
assert worker.register_kv_caches.call_count == 1
|
||||
|
||||
scheduler.get_num_new_matched_tokens.return_value = 0, False
|
||||
|
||||
worker.get_finished.return_value = [], []
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=32, ignore_eos=True)
|
||||
|
||||
model.generate([0] * NUM_INPUT_TOKENS, sampling_params)
|
||||
|
||||
assert scheduler.update_state_after_alloc.call_count == 1
|
||||
assert len(
|
||||
scheduler.update_state_after_alloc.call_args.args[1]) == math.ceil(
|
||||
NUM_INPUT_TOKENS / BLOCK_SIZE)
|
||||
|
||||
# Additional token when using the overlap scheduler.
|
||||
assert scheduler.build_connector_meta.call_count == NUM_TOKENS + int(
|
||||
use_overlap_scheduler)
|
||||
|
||||
for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
|
||||
sched_output = call.args[0]
|
||||
|
||||
if i == 0:
|
||||
assert len(sched_output.new_requests) == 1
|
||||
assert len(sched_output.cached_requests) == 0
|
||||
request = sched_output.new_requests[0]
|
||||
|
||||
assert len(request.new_tokens) == NUM_INPUT_TOKENS
|
||||
assert len(request.new_block_ids) == math.ceil(NUM_INPUT_TOKENS /
|
||||
BLOCK_SIZE)
|
||||
assert request.computed_position == 0
|
||||
elif i == 1 and use_overlap_scheduler:
|
||||
assert len(sched_output.new_requests) == 0
|
||||
assert len(sched_output.cached_requests) == 1
|
||||
|
||||
assert len(sched_output.cached_requests[0].new_tokens) == 0
|
||||
else:
|
||||
assert len(sched_output.cached_requests) == 1
|
||||
assert len(sched_output.new_requests) == 0
|
||||
request = sched_output.cached_requests[0]
|
||||
|
||||
assert len(request.new_tokens) == 1
|
||||
|
||||
if (request.computed_position +
|
||||
int(use_overlap_scheduler)) % BLOCK_SIZE == 0:
|
||||
assert len(request.new_block_ids) == 1
|
||||
else:
|
||||
assert request.new_block_ids == []
|
||||
|
||||
scheduler.build_connector_meta.reset_mock()
|
||||
|
||||
scheduler.get_num_new_matched_tokens.return_value = 8, False
|
||||
|
||||
assert len(scheduler.request_finished.call_args.args[1]) == math.ceil(
|
||||
(NUM_INPUT_TOKENS + NUM_TOKENS) / BLOCK_SIZE)
|
||||
|
||||
model.generate([1] * NUM_INPUT_TOKENS, sampling_params)
|
||||
|
||||
# The initial computed position should be 0, since we haven't yet onboarded any blocks.
|
||||
assert scheduler.build_connector_meta.call_args_list[0].args[
|
||||
0].new_requests[0].computed_position == 0
|
||||
|
||||
|
||||
@pytest.mark.threadleak(enabled=False)
|
||||
@pytest.mark.parametrize("use_overlap_scheduler", [True, False])
|
||||
def test_connector_scheduler_output_chunked_context(enforce_single_worker,
|
||||
model_with_connector,
|
||||
use_overlap_scheduler):
|
||||
model_fn, scheduler, worker = model_with_connector
|
||||
|
||||
CHUNK_SIZE = 128
|
||||
BLOCK_SIZE = 32
|
||||
|
||||
model = model_fn(disable_overlap_scheduler=not use_overlap_scheduler,
|
||||
enable_chunked_prefill=True,
|
||||
max_num_tokens=CHUNK_SIZE)
|
||||
|
||||
assert worker.register_kv_caches.call_count == 1
|
||||
|
||||
scheduler.get_num_new_matched_tokens.return_value = 0, False
|
||||
|
||||
worker.get_finished.return_value = [], []
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=BLOCK_SIZE, ignore_eos=True)
|
||||
|
||||
model.generate([0] * (CHUNK_SIZE * 2), sampling_params)
|
||||
|
||||
assert scheduler.update_state_after_alloc.call_count == 1
|
||||
|
||||
assert len(
|
||||
scheduler.update_state_after_alloc.call_args.args[1]) == math.ceil(
|
||||
CHUNK_SIZE * 2 / BLOCK_SIZE)
|
||||
|
||||
for i, call in enumerate(scheduler.build_connector_meta.call_args_list):
|
||||
sched_output = call.args[0]
|
||||
|
||||
if i == 0:
|
||||
assert len(sched_output.new_requests) == 1
|
||||
assert len(sched_output.cached_requests) == 0
|
||||
req = sched_output.new_requests[0]
|
||||
else:
|
||||
assert len(sched_output.cached_requests) == 1
|
||||
assert len(sched_output.new_requests) == 0
|
||||
req = sched_output.cached_requests[0]
|
||||
|
||||
if i == 0:
|
||||
# The first prefill chunk.
|
||||
# All of the prefill tokens and all the blocks should be provided upfront.
|
||||
assert req.computed_position == 0
|
||||
assert len(req.new_tokens) == CHUNK_SIZE * 2
|
||||
assert len(req.new_block_ids) == math.ceil(CHUNK_SIZE * 2 /
|
||||
BLOCK_SIZE)
|
||||
elif i == 1:
|
||||
# The second prefill chunk.
|
||||
assert req.computed_position == CHUNK_SIZE
|
||||
assert len(req.new_tokens) == 0
|
||||
assert len(req.new_block_ids) == 0
|
||||
elif i == 2 and use_overlap_scheduler:
|
||||
assert len(req.new_tokens) == 0
|
||||
else:
|
||||
assert len(req.new_tokens) == 1
|
||||
|
||||
assert len(scheduler.request_finished.call_args.args[1]) == math.ceil(
|
||||
(CHUNK_SIZE * 2 + BLOCK_SIZE) / BLOCK_SIZE)
|
||||
@ -163,3 +163,12 @@ def test_llmapi_sampling(llm_root, engine_dir, llm_venv):
@pytest.mark.skip(reason="https://nvbugs/5365825")
def test_llmapi_runtime(llm_root, engine_dir, llm_venv):
    _run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_runtime.py")


@pytest.mark.parametrize("model", ["Qwen2-0.5B"])
def test_llmapi_kv_cache_connector(llm_root, llm_venv, model):
    script_path = Path(
        llm_root) / "examples" / "llm-api" / "llm_kv_cache_connector.py"
    model_path = f"{llm_models_root()}/{model}"

    venv_check_call(llm_venv, [str(script_path), model_path])
@ -81,12 +81,23 @@ l0_a10:
- unittest/trt/model/test_mistral.py
- unittest/trt/model/test_llama.py
- test_e2e.py::test_gpt3_175b_1layers_build_only # 6 mins
- llmapi/test_llm_api_connector.py::test_connector_simple[True]
- llmapi/test_llm_api_connector.py::test_connector_simple[False]
- llmapi/test_llm_api_connector.py::test_connector_async_onboard[True]
- llmapi/test_llm_api_connector.py::test_connector_async_onboard[False]
- llmapi/test_llm_api_connector.py::test_connector_async_save[True]
- llmapi/test_llm_api_connector.py::test_connector_async_save[False]
- llmapi/test_llm_api_connector.py::test_connector_scheduler_output[True]
- llmapi/test_llm_api_connector.py::test_connector_scheduler_output[False]
- llmapi/test_llm_api_connector.py::test_connector_scheduler_output_chunked_context[True]
- llmapi/test_llm_api_connector.py::test_connector_scheduler_output_chunked_context[False]
- llmapi/test_llm_e2e.py::test_llmapi_load_engine_from_build_command[llama-llama-models/llama-7b-hf] # 5min
- llmapi/test_llm_e2e.py::test_llmapi_build_command_parameters_align[llama-llama-models-v2/TinyLlama-1.1B-Chat-v1.0]
- llmapi/test_llm_e2e.py::test_llmapi_load_engine_from_build_command_with_lora[llama-llama-models-v2/llama-v2-7b-hf]
- llmapi/test_llm_examples.py::test_llmapi_chat_example
- llmapi/test_llm_e2e.py::test_llmapi_exit
- llmapi/test_llm_examples.py::test_llmapi_server_example
- llmapi/test_llm_examples.py::test_llmapi_kv_cache_connector[Qwen2-0.5B]
- test_e2e.py::test_trtllm_serve_example
- test_e2e.py::test_openai_misc_example[trt]
- test_e2e.py::test_openai_completions_example[trt]
170
tests/unittest/_torch/test_connector.py
Normal file
@ -0,0 +1,170 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import sys
from unittest.mock import MagicMock

import cloudpickle
import mpi4py
import pytest

from tensorrt_llm import mpi_rank
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import \
    KvCacheConnectorManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests

cloudpickle.register_pickle_by_value(sys.modules[__name__])
mpi4py.MPI.pickle.__init__(
    cloudpickle.dumps,
    cloudpickle.loads,
    pickle.HIGHEST_PROTOCOL,
)


def run_across_mpi(executor, fun, num_ranks):
    return list(executor.starmap(fun, [() for i in range(num_ranks)]))


@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
# TODO(jthomson04): I don't have the slightest idea why this test is leaking threads.
@pytest.mark.threadleak(enabled=False)
def test_connector_manager_get_finished_allgather(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()

            scheduler.request_finished.return_value = True
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        req = MagicMock()

        req.request_id = 42

        manager.request_finished(req, [])

        # To start, make both workers return nothing.
        worker.get_finished.return_value = ([], [])

        assert manager.get_finished() == []

        assert worker.get_finished.call_count == 1
        assert worker.get_finished.call_args[0] == ([42], [])

        worker.get_finished.reset_mock()

        # Now, only return the request id on one worker.
        if mpi_rank() == 0:
            worker.get_finished.return_value = ([42], [])
        else:
            worker.get_finished.return_value = ([], [])

        # It should still return nothing, since rank 1 is still saving.
        assert manager.get_finished() == []

        assert worker.get_finished.call_count == 1
        assert worker.get_finished.call_args[0] == ([], [])

        # Now, also return it on worker 1.
        if mpi_rank() == 0:
            worker.get_finished.return_value = ([], [])
        else:
            worker.get_finished.return_value = ([42], [])

        assert manager.get_finished() == [req]

    run_across_mpi(mpi_pool_executor, test, 2)
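The allgather test above pins down the cross-rank contract: a saved request only comes back from get_finished() once every rank has reported it, even if the reports arrive on different calls. A minimal sketch of that aggregation, assuming an mpi4py communicator and a persistent per-manager report counter (illustrative only, not the KvCacheConnectorManager implementation):

from collections import Counter

from mpi4py import MPI


def aggregate_finished(comm: MPI.Comm, local_finished_ids, report_counts: Counter):
    # Share the IDs this rank's worker reports as finished saving.
    all_reports = comm.allgather(list(local_finished_ids))

    # report_counts persists across calls, so a rank that reported a request
    # earlier still counts toward it later (rank 0 reports id 42 on one call,
    # rank 1 on the next, as in the test above).
    for per_rank in all_reports:
        report_counts.update(per_rank)

    # A request is only globally finished once every rank has reported it.
    finished = [rid for rid, n in report_counts.items() if n == comm.Get_size()]
    for rid in finished:
        del report_counts[rid]
    return finished

Whether the real manager allgathers or gathers to rank 0 and broadcasts is an implementation detail; the test only constrains the observable behavior.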
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_num_matched_tokens(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()
            scheduler.get_num_new_matched_tokens.return_value = (16, True)
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        req = MagicMock()

        req.request_id = 42

        assert manager.get_num_new_matched_tokens(req, 32) == 16

        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req,
                                                                         32)

    run_across_mpi(mpi_pool_executor, test, 2)
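Only rank 0 holds a scheduler in this test, yet every rank observes the same return value of 16, so the result must be shared from rank 0 to the other ranks. A small sketch of that pattern, assuming mpi4py and hypothetical helper names (the actual manager plumbing may differ):

from mpi4py import MPI


def broadcast_num_matched_tokens(comm: MPI.Comm, scheduler, request, num_computed_tokens):
    # Only the rank that owns the scheduler-side connector queries it.
    if comm.Get_rank() == 0:
        result = scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
    else:
        result = None

    # Every rank then agrees on the same (num_tokens, load_is_async) pair.
    num_tokens, load_is_async = comm.bcast(result, root=0)
    return num_tokens, load_is_async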
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_connector_manager_take_scheduled_requests(mpi_pool_executor):

    def test():
        worker = MagicMock()

        if mpi_rank() == 0:
            scheduler = MagicMock()
        else:
            scheduler = None

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        scheduled_requests = ScheduledRequests()

        req0 = MagicMock()
        req0.request_id = 0

        req1 = MagicMock()
        req1.request_id = 1

        if mpi_rank() == 0:
            scheduler.get_num_new_matched_tokens.return_value = (16, True)

        assert manager.get_num_new_matched_tokens(req0, 0) == 16
        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req0,
                                                                         0)

            scheduler.get_num_new_matched_tokens.reset_mock()
            scheduler.get_num_new_matched_tokens.return_value = (32, False)

        assert manager.get_num_new_matched_tokens(req1, 0) == 32
        if mpi_rank() == 0:
            assert scheduler.get_num_new_matched_tokens.call_count == 1
            assert scheduler.get_num_new_matched_tokens.call_args[0] == (req1,
                                                                         0)

        scheduled_requests.context_requests = [req0, req1]

        manager.take_scheduled_requests_pending_load(scheduled_requests)

        assert scheduled_requests.context_requests == [req1]

    run_across_mpi(mpi_pool_executor, test, 2)
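The last test fixes the semantics of take_scheduled_requests_pending_load: a request whose match was reported as asynchronous ((16, True) for req0) is held out of the context batch until its KV blocks have been onboarded, while a synchronous match ((32, False) for req1) stays schedulable. A minimal sketch of that filtering, assuming the manager keeps a set of request IDs with loads still in flight (names are illustrative):

def take_scheduled_requests_pending_load(scheduled_requests, pending_async_load_ids):
    # Pull out context requests that are still waiting on an async KV load.
    held_back = [
        req for req in scheduled_requests.context_requests
        if req.request_id in pending_async_load_ids
    ]
    # Everything else remains eligible for this iteration's context batch.
    scheduled_requests.context_requests = [
        req for req in scheduled_requests.context_requests
        if req.request_id not in pending_async_load_ids
    ]
    return held_back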
@ -167,6 +167,10 @@ methods:
        annotation: Optional[tensorrt_llm.llmapi.llm_args.DecodingConfig]
        default: null
        status: deprecated
      kv_connector_config:
        annotation: Optional[tensorrt_llm.llmapi.llm_args.KvCacheConnectorConfig]
        default: null
        status: prototype
    return_annotation: None
  generate:
    parameters: