Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-04 18:21:52 +08:00
[TRTLLM-10197][chore] Refactor to setup for RNN cache transceiver (#10957)
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
This commit is contained in:
parent f25a2c53bb
commit 6c1862fb33
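This commit extracts the generic transfer-buffer machinery out of CacheTransBufferManager (pre-allocated send/recv buffer pools, condition-variable-gated index assignment, and buffer slicing with a dynamic-allocation fallback) into a new BaseTransBufferManager base class in baseTransBuffer.{h,cpp}. CacheTransBufferManager becomes a thin KV-cache-specific subclass whose main job is computing the transfer buffer size; per the commit title, this sets things up so a future RNN cache transceiver can reuse the same pool machinery with its own sizing rule. A minimal sketch of what such a subclass could look like — RnnCacheTransBufferManager and its sizing rule are our illustration, not part of this commit:

// Hypothetical sketch: an RNN/state cache buffer manager reusing the new base
// class. The class name and the sizing rule are illustrative assumptions.
#include "tensorrt_llm/batch_manager/baseTransBuffer.h"

namespace tensorrt_llm::batch_manager
{

class RnnCacheTransBufferManager : public BaseTransBufferManager
{
public:
    RnnCacheTransBufferManager(size_t stateByteSizePerRequest, nvinfer1::DataType dataType,
        std::optional<size_t> maxNumTokens = std::nullopt)
        // The protected base constructor owns pool allocation, index assignment,
        // and slicing; the derived class only supplies the buffer size.
        : BaseTransBufferManager(computeTransferBufferSize(stateByteSizePerRequest), dataType, maxNumTokens)
    {
    }

private:
    static size_t computeTransferBufferSize(size_t stateByteSizePerRequest)
    {
        // RNN caches transfer fixed-size recurrent state per request rather than
        // per-token KV blocks, so the sizing rule differs from the KV cache one.
        return stateByteSizePerRequest;
    }
};

} // namespace tensorrt_llm::batch_manager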
cpp/tensorrt_llm/batch_manager/CMakeLists.txt

@@ -21,6 +21,7 @@ set(TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR})
 set(SRCS
     allocateKvCache.cpp
     assignReqSeqSlots.cpp
+    baseTransBuffer.cpp
     cacheFormatter.cpp
     mlaCacheFormatter.cpp
     cacheTransceiver.cpp

285 cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp Normal file
@@ -0,0 +1,285 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "baseTransBuffer.h"
#include "cacheTransBuffer.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/opUtils.h"

#include <mutex>

namespace tensorrt_llm::batch_manager
{

BaseTransBufferManager::BaseTransBufferManager(
    size_t transferBufferSize, nvinfer1::DataType dataType, std::optional<size_t> maxNumTokens)
    : mDataType{dataType}
    , mBufferManager{std::make_shared<runtime::CudaStream>()}
    , mMaxNumTokens{maxNumTokens}
{
    mTransferBufferSize = transferBufferSize;
    mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
    mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
    mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
    mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
        && kv_cache_manager::FabricMemory::supportFbaricMemory();
    if (mUseFabricMemory)
    {
        mTransferBufferSize = kv_cache_manager::FabricMemory::getAlignedSize(mTransferBufferSize);
    }
    mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount);

    TLLM_LOG_INFO(
        "BaseTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
        "mSendBufferCount:%ld, mTransferBufferSize:%ld, mPreAllocBufferSize:%ld, mOnlyUseDynamicBuffer:%d, "
        "mUseFabricMemory:%d, mDataType:%d",
        maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
        mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, static_cast<int>(mDataType));

    allocateBuffer();
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForSend()
{
    return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
}

void BaseTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForRecv()
{
    return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
}

void BaseTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateSendBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse)
{
    return getOrAllocateBuffers(
        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceSendResource);
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateRecvBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse)
{
    return getOrAllocateBuffers(
        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceRecvResource);
}

runtime::ITensor::SharedPtr BaseTransBufferManager::getSendBuffer(std::optional<int> bufferId)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mSendBufferCount);
        return mConcurrenceSendResource.mBuffers[bufferId.value()];
    }
    return nullptr;
}

runtime::ITensor::SharedPtr BaseTransBufferManager::getRecvBuffer(std::optional<int> bufferId)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mRecvBufferCount);
        // TLLM_CHECK(mConcurrenceRecvResource.mBufferIndexFlag[bufferId.value()] == 1);
        return mConcurrenceRecvResource.mBuffers[bufferId.value()];
    }
    return nullptr;
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    TLLM_CHECK(requestedNumberOfElements.size() >= static_cast<size_t>(targetNum));
    std::vector<runtime::ITensor::SharedPtr> retSplitCaches;

    size_t bufferCoverTargetNum = 0;

    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < concurrenceResource.mBuffers.size());
        TLLM_CHECK(concurrenceResource.mBufferIndexFlag[bufferId.value()] == 1);
        size_t preBufferEleSize = 0;
        for (int i = 0; i < targetNum; i++)
        {
            // Strict checking.
            if (preBufferEleSize + requestedNumberOfElements[i] <= mNumberOfElements)
            {
                auto slice = runtime::ITensor::slice(
                    concurrenceResource.mBuffers[bufferId.value()], preBufferEleSize, requestedNumberOfElements[i]);
                preBufferEleSize += requestedNumberOfElements[i];
                bufferCoverTargetNum++;
                retSplitCaches.push_back(std::move(slice));
            }
            else
            {
                retSplitCaches.push_back(bufferManagerToUse.gpu(
                    runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
            }
        }
        TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum);
        if (bufferCoverTargetNum < static_cast<size_t>(targetNum))
        {
            TLLM_LOG_WARNING(
                "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic "
                "buffer which will fail with NIXL backend. It is recommended to set "
                "cacheTransceiverConfig.MaxTokensInBuffer (cache_transceiver_config.max_tokens_in_buffer in config "
                "YAML file) to a value greater than the maximum ISL of the processed requests. Otherwise, performance "
                "may be degraded or transfer may fail. requestedNumberOfElements.size():%ld, "
                "mNumberOfElements:%ld, requestedNumberOfElements[0]:%ld",
                bufferCoverTargetNum, targetNum, requestedNumberOfElements.size(), mNumberOfElements,
                requestedNumberOfElements[0]);
        }
    }
    else
    {
        for (int i = 0; i < targetNum; i++)
        {
            retSplitCaches.push_back(bufferManagerToUse.gpu(
                runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
        }
        bufferCoverTargetNum = targetNum;
    }

    return std::make_tuple(retSplitCaches, bufferCoverTargetNum, mOnlyUseDynamicBuffer);
}

void BaseTransBufferManager::allocateBuffer()
{
    if (mOnlyUseDynamicBuffer)
    {
        return;
    }
    mNumberOfElements = mTransferBufferSize / common::getDTypeSize(mDataType);
    mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0);
    mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0);
    if (mUseFabricMemory)
    {
        mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount);
        for (size_t i = 0; i < mSendBufferCount; i++)
        {
            mFabricMemory.emplace_back(std::make_unique<kv_cache_manager::FabricMemory>(mTransferBufferSize));
            mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mNumberOfElements);
        }
        for (size_t i = 0; i < mRecvBufferCount; i++)
        {
            mFabricMemory.emplace_back(std::make_unique<kv_cache_manager::FabricMemory>(mTransferBufferSize));
            mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mNumberOfElements);
        }
    }
    else if (common::getEnvKVCacheTransferUseAsyncBuffer())
    {
        for (size_t i = 0; i < mSendBufferCount; i++)
        {
            mConcurrenceSendResource.mBuffers[i]
                = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
        for (size_t i = 0; i < mRecvBufferCount; i++)
        {
            mConcurrenceRecvResource.mBuffers[i]
                = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
        mBufferManager.getStream().synchronize();
    }
    else
    {
        for (size_t i = 0; i < mSendBufferCount; i++)
        {
            mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync(
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
        for (size_t i = 0; i < mRecvBufferCount; i++)
        {
            mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync(
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
    }
}

std::optional<int> BaseTransBufferManager::assignBufferIndex(
    ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer)
{
    if (onlyUseDynamicBuffer)
    {
        return std::nullopt;
    }
    std::unique_lock lk(resource.mBuffersMutex);
    resource.mBuffersCV.wait(
        lk, [&resource, bufferCount]() { return static_cast<size_t>(resource.mConcurrence) < bufferCount; });
    int bufferId = -1;
    for (size_t i = 0; i < bufferCount; i++)
    {
        if (resource.mBufferIndexFlag[i] == 0)
        {
            bufferId = i;
            resource.mBufferIndexFlag[bufferId] = 1;
            resource.mConcurrence++;
            break;
        }
    }
    TLLM_CHECK_WITH_INFO(bufferId >= 0 && static_cast<size_t>(bufferId) < bufferCount,
        " assignBufferIndex: Buffer index already assigned");

    return bufferId;
}

void BaseTransBufferManager::freeBufferIndex(
    ConcurrenceResource& resource, std::optional<int> bufferId, size_t bufferCount, bool onlyUseDynamicBuffer)
{
    if (onlyUseDynamicBuffer)
    {
        return;
    }
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < bufferCount);
        {
            std::scoped_lock lk(resource.mBuffersMutex);
            resource.mBufferIndexFlag[bufferId.value()] = 0;
        }
        resource.mConcurrence--;
        resource.mBuffersCV.notify_one();
    }
}

size_t BaseTransBufferManager::getRecvBufferCount()
{
    return mRecvBufferCount;
}

size_t BaseTransBufferManager::getSendBufferCount()
{
    return mSendBufferCount;
}

} // namespace tensorrt_llm::batch_manager
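For orientation, this is the send-side lifecycle the methods above implement, as a minimal sketch (the sendCaches function and its arguments are our illustration; the BaseTransBufferManager calls are the ones defined in this file):

// Illustrative caller: acquire a pre-allocated slot, slice it per target,
// fall back to dynamic GPU allocation for whatever does not fit, release.
void sendCaches(tensorrt_llm::batch_manager::BaseTransBufferManager& mgr,
    tensorrt_llm::runtime::BufferManager const& bufferManager, std::vector<size_t> const& elemCounts)
{
    // Blocks on the condition variable until one of the send slots is free;
    // returns nullopt when the manager runs in dynamic-only mode.
    auto bufferId = mgr.assignBufferIndexForSend();

    // Slices the pre-allocated buffer for as many targets as fit
    // (bufferCoverTargetNum) and allocates fresh GPU tensors for the rest.
    auto [buffers, bufferCoverTargetNum, onlyDynamic] = mgr.getOrAllocateSendBuffers(
        bufferId, static_cast<int>(elemCounts.size()), elemCounts, bufferManager);

    // ... copy cache contents into `buffers` and hand them to the transceiver ...

    // Releases the slot and wakes one waiter blocked in assignBufferIndexForSend.
    mgr.freeBufferIndexForSend(bufferId);
}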
144 cpp/tensorrt_llm/batch_manager/baseTransBuffer.h Normal file
@@ -0,0 +1,144 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/iTensor.h"

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <memory>
#include <mutex>
#include <optional>
#include <tuple>
#include <unordered_map>
#include <vector>

namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class FabricMemory;
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

namespace tensorrt_llm::batch_manager
{

/// @brief Base class for cache transfer buffer management.
/// Handles buffer pool allocation, index assignment, and slicing.
/// Derived classes provide cache-specific size calculations.
class BaseTransBufferManager
{
public:
    virtual ~BaseTransBufferManager() = default;

    /// @brief Assign a buffer index for sending.
    /// @return Assigned buffer index, or nullopt if using dynamic buffers.
    std::optional<int> assignBufferIndexForSend();

    /// @brief Free a buffer index used for sending.
    /// @param bufferId The buffer index to free.
    void freeBufferIndexForSend(std::optional<int> bufferId);

    /// @brief Assign a buffer index for receiving.
    /// @return Assigned buffer index, or nullopt if using dynamic buffers.
    std::optional<int> assignBufferIndexForRecv();

    /// @brief Free a buffer index used for receiving.
    /// @param bufferId The buffer index to free.
    void freeBufferIndexForRecv(std::optional<int> bufferId);

    /// @brief Get or allocate send buffers for cache transfer.
    /// @param bufferId The assigned buffer ID.
    /// @param targetNum Number of target sequences.
    /// @param requestedNumberOfElements Sizes requested for each target.
    /// @param bufferManagerToUse Buffer manager for dynamic allocation.
    /// @return Tuple of (buffers, covered target count, is dynamic only).
    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateSendBuffers(
        std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
        runtime::BufferManager const& bufferManagerToUse);

    /// @brief Get or allocate receive buffers for cache transfer.
    /// @param bufferId The assigned buffer ID.
    /// @param targetNum Number of target sequences.
    /// @param requestedNumberOfElements Sizes requested for each target.
    /// @param bufferManagerToUse Buffer manager for dynamic allocation.
    /// @return Tuple of (buffers, covered target count, is dynamic only).
    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateRecvBuffers(
        std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
        runtime::BufferManager const& bufferManagerToUse);

    /// @brief Get the send buffer for a given buffer ID.
    runtime::ITensor::SharedPtr getSendBuffer(std::optional<int> bufferId);

    /// @brief Get the receive buffer for a given buffer ID.
    runtime::ITensor::SharedPtr getRecvBuffer(std::optional<int> bufferId);

    /// @brief Get the number of receive buffers.
    size_t getRecvBufferCount();

    /// @brief Get the number of send buffers.
    size_t getSendBufferCount();

    /// @brief Get the maximum number of tokens configured.
    std::optional<size_t> getMaxNumTokens()
    {
        return mMaxNumTokens;
    }

protected:
    /// @brief Constructor - derived classes call this after computing buffer sizes.
    /// @param transferBufferSize Size of each transfer buffer in bytes.
    /// @param dataType Data type for the buffers.
    /// @param maxNumTokens Optional max tokens for sizing.
    BaseTransBufferManager(
        size_t transferBufferSize, nvinfer1::DataType dataType, std::optional<size_t> maxNumTokens = std::nullopt);

    struct ConcurrenceResource
    {
        std::unordered_map<int, runtime::ITensor::SharedPtr> mBuffers;
        std::vector<int> mBufferIndexFlag;
        std::mutex mBuffersMutex;
        std::condition_variable mBuffersCV;
        std::atomic<int> mConcurrence{0};
    };

    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateBuffers(std::optional<int> bufferId,
        int targetNum, std::vector<size_t> const& requestedNumberOfElements,
        runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource);

    void allocateBuffer();
    std::optional<int> assignBufferIndex(ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer);
    void freeBufferIndex(
        ConcurrenceResource& resource, std::optional<int> bufferId, size_t bufferCount, bool onlyUseDynamicBuffer);

    size_t mPreAllocBufferSize;
    size_t mRecvBufferCount;
    size_t mSendBufferCount;
    size_t mTransferBufferSize;
    bool mOnlyUseDynamicBuffer;
    bool mUseFabricMemory;
    size_t mNumberOfElements;
    nvinfer1::DataType mDataType;
    ConcurrenceResource mConcurrenceSendResource;
    ConcurrenceResource mConcurrenceRecvResource;
    runtime::BufferManager mBufferManager;
    std::vector<std::unique_ptr<kv_cache_manager::FabricMemory>> mFabricMemory;
    std::optional<size_t> mMaxNumTokens;
};

} // namespace tensorrt_llm::batch_manager
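The ConcurrenceResource struct above is the heart of assignBufferIndex/freeBufferIndex: a fixed pool of slots whose acquisition blocks on a condition variable once every slot is in use. Stripped of tensors and TensorRT-LLM types, the pattern reduces to the self-contained sketch below (SlotPool and its members are our names, not from the header):

#include <condition_variable>
#include <mutex>
#include <vector>

// Fixed pool of N reusable slots: a flag per slot plus a condition variable
// that blocks acquire() while every slot is taken, as in assignBufferIndex.
class SlotPool
{
public:
    explicit SlotPool(size_t n)
        : mInUse(n, 0)
    {
    }

    int acquire()
    {
        std::unique_lock lk(mMutex);
        // Mirrors resource.mBuffersCV.wait(...): sleep until a slot frees up.
        mCv.wait(lk, [this] { return mCount < mInUse.size(); });
        for (size_t i = 0; i < mInUse.size(); ++i)
        {
            if (mInUse[i] == 0)
            {
                mInUse[i] = 1;
                ++mCount;
                return static_cast<int>(i);
            }
        }
        return -1; // unreachable: the wait predicate guarantees a free slot
    }

    void release(int id)
    {
        {
            std::scoped_lock lk(mMutex);
            mInUse[static_cast<size_t>(id)] = 0;
            --mCount;
        }
        mCv.notify_one(); // wake one waiter, as freeBufferIndex does
    }

private:
    std::vector<int> mInUse;
    size_t mCount{0};
    std::mutex mMutex;
    std::condition_variable mCv;
};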
cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp

@@ -20,12 +20,17 @@
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/common/opUtils.h"
 #include "tensorrt_llm/executor/executor.h"
 
 #include <NvInferRuntimeBase.h>
 #include <mutex>
 
 namespace tensorrt_llm::batch_manager::kv_cache_manager
 {
+
+// ============================================================================
+// FabricMemory Implementation
+// ============================================================================
+
 class FabricMemory::Impl
 {
 public:
@@ -182,45 +187,46 @@ bool FabricMemory::supportFbaricMemory()
 #endif
 }
 
-CacheTransBufferManager::CacheTransBufferManager(
+// ============================================================================
+// CacheTransBufferManager Implementation
+// ============================================================================
+
+size_t CacheTransBufferManager::computeTransferBufferSize(
     KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens, bool transferIndexerKCache)
-    : mCacheManager{cacheManager}
-    , mBufferManager{std::make_shared<runtime::CudaStream>()}
-    , mMaxNumTokens{maxNumTokens}
 {
-    // TODO: FP4 dataSize
-    TLLM_CHECK(mCacheManager);
+    nvinfer1::DataType dataType;
     if (transferIndexerKCache)
     {
-        mDataType = mCacheManager->getIndexerKCachePool()->getDataType();
+        dataType = cacheManager->getIndexerKCachePool()->getDataType();
     }
     else
     {
-        mDataType = mCacheManager->getPrimaryPool(0)->getDataType();
+        dataType = cacheManager->getPrimaryPool(0)->getDataType();
     }
 
-    auto tokensPerBlock = mCacheManager->getBlockManager().getTokensPerBlock();
+    auto tokensPerBlock = cacheManager->getBlockManager().getTokensPerBlock();
     size_t bufferSizeFromMaxNumToken = 0;
 
     if (maxNumTokens.has_value())
     {
         TLLM_CHECK(maxNumTokens.value() % tokensPerBlock == 0);
-        auto dataSize = common::getDTypeSize(mDataType);
+        auto dataSize = common::getDTypeSize(dataType);
         SizeType32 kvCacheByteSizePerTokenPerLayer = 0;
         if (transferIndexerKCache)
         {
             kvCacheByteSizePerTokenPerLayer
-                = mCacheManager->getIndexerKCachePool()->getDimension<-1>() * dataSize / tokensPerBlock;
+                = cacheManager->getIndexerKCachePool()->getDimension<-1>() * dataSize / tokensPerBlock;
         }
         else
         {
-            auto primaryPool = mCacheManager->getPrimaryPool(0);
+            auto primaryPool = cacheManager->getPrimaryPool(0);
             kvCacheByteSizePerTokenPerLayer
                 = primaryPool->getDimension<-1>() * primaryPool->getDimension<2>() * dataSize / tokensPerBlock;
         }
-        for (auto layerId = 0; layerId < mCacheManager->getBlockManager().getNumLayers(); layerId++)
+        for (auto layerId = 0; layerId < cacheManager->getBlockManager().getNumLayers(); layerId++)
         {
-            auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId);
-            auto windowSize = static_cast<size_t>(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx));
+            auto poolIdx = cacheManager->getBlockManager().getLayerPoolIdx(layerId);
+            auto windowSize = static_cast<size_t>(cacheManager->getBlockManager().getPoolWindowSize(poolIdx));
             auto alignedWindowSize = (windowSize + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock;
             auto validTokenNum = (alignedWindowSize < maxNumTokens.value() ? alignedWindowSize : maxNumTokens.value());
             if (common::getEnvKVCacheTransferAllBlocksForWindow())
@@ -233,26 +239,20 @@ CacheTransBufferManager::CacheTransBufferManager(
         }
     }
 
-    mTransferBufferSize
-        = maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();
-    mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
-    mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
-    mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
-    mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
-        && FabricMemory::supportFbaricMemory();
-    if (mUseFabricMemory)
-    {
-        mTransferBufferSize = FabricMemory::getAlignedSize(mTransferBufferSize);
-    }
-    mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount);
-    TLLM_LOG_INFO(
-        "CacheTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
-        "mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld,mOnlyUseDynamicBuffer:%d "
-        "mUseFabricMemory:%d mDataType:%d",
-        maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
-        mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, mDataType);
+    return maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();
+}
 
-    allocateBuffer();
+CacheTransBufferManager::CacheTransBufferManager(
+    KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens, bool transferIndexerKCache)
+    : BaseTransBufferManager(computeTransferBufferSize(cacheManager, maxNumTokens, transferIndexerKCache),
+          transferIndexerKCache ? cacheManager->getIndexerKCachePool()->getDataType()
+                                : cacheManager->getPrimaryPool(0)->getDataType(),
+          maxNumTokens)
+    , mCacheManager{cacheManager}
+{
+    // TODO: FP4 dataSize
+    TLLM_CHECK(mCacheManager);
+    TLLM_LOG_INFO("CacheTransBufferManager created for KV cache");
 }
 
 size_t CacheTransBufferManager::preAllocBufferSize(
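To make the sizing rule in computeTransferBufferSize concrete: on the non-indexer path, each layer contributes kvCacheByteSizePerTokenPerLayer bytes for min(alignedWindowSize, maxNumTokens) tokens. A toy, self-contained example with made-up model dimensions (every number below is a hypothetical assumption, not from the commit):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main()
{
    size_t const tokensPerBlock = 32;
    size_t const maxNumTokens = 4096;                      // must be a multiple of tokensPerBlock
    size_t const bytesPerTokenPerLayer = 2 * 8 * 128 * 2;  // K+V, 8 heads, head dim 128, fp16
    size_t const numLayers = 32;
    size_t const windowSize = 1024;                        // attention-window layers cap the valid tokens

    size_t bufferSize = 0;
    for (size_t layer = 0; layer < numLayers; ++layer)
    {
        // Round the window up to whole blocks, then clamp by maxNumTokens.
        size_t const alignedWindow = (windowSize + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock;
        size_t const validTokens = std::min(alignedWindow, maxNumTokens);
        bufferSize += bytesPerTokenPerLayer * validTokens;
    }
    // 4096 bytes/token/layer * 1024 tokens * 32 layers = 134,217,728 bytes (128 MiB)
    std::printf("transfer buffer size: %zu bytes\n", bufferSize);
}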
@@ -298,233 +298,4 @@ size_t CacheTransBufferManager::preAllocBufferSize(
     return preAllocBufferSize;
 }
 
-std::optional<int> CacheTransBufferManager::assignBufferIndexForSend()
-{
-    return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
-}
-
-void CacheTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
-{
-    freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
-}
-
-std::optional<int> CacheTransBufferManager::assignBufferIndexForRecv()
-{
-    return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
-}
-
-void CacheTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
-{
-    freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
-}
-
-std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBufferManager::getOrAllocateSendBuffers(
-    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-    runtime::BufferManager const& bufferManagerToUse)
-{
-    return getOrAllocateBuffers(
-        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceSendResource);
-}
-
-std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBufferManager::getOrAllocateRecvBuffers(
-    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-    runtime::BufferManager const& bufferManagerToUse)
-{
-    return getOrAllocateBuffers(
-        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceRecvResource);
-}
-
-runtime::ITensor::SharedPtr CacheTransBufferManager::getSendBuffer(std::optional<int> bufferId)
-{
-    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
-    if (bufferId.has_value())
-    {
-        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mSendBufferCount);
-        return mConcurrenceSendResource.mBuffers[bufferId.value()];
-    }
-    return nullptr;
-}
-
-runtime::ITensor::SharedPtr CacheTransBufferManager::getRecvBuffer(std::optional<int> bufferId)
-{
-    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
-    if (bufferId.has_value())
-    {
-        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mRecvBufferCount);
-        // TLLM_CHECK(mConcurrenceRecvResource.mBufferIndexFlag[bufferId.value()] == 1);
-        return mConcurrenceRecvResource.mBuffers[bufferId.value()];
-    }
-    return nullptr;
-}
-
-std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBufferManager::getOrAllocateBuffers(
-    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-    runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource)
-{
-    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
-    TLLM_CHECK(requestedNumberOfElements.size() >= static_cast<size_t>(targetNum));
-    std::vector<runtime::ITensor::SharedPtr> retSplitCaches;
-
-    size_t bufferCoverTargetNum = 0;
-
-    if (bufferId.has_value())
-    {
-        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < concurrenceResource.mBuffers.size());
-        TLLM_CHECK(concurrenceResource.mBufferIndexFlag[bufferId.value()] == 1);
-        size_t preBufferEleSize = 0;
-        for (int i = 0; i < targetNum; i++)
-        {
-            // Strict checking.
-            if (preBufferEleSize + requestedNumberOfElements[i] <= mNumberOfElements)
-            {
-                auto slice = runtime::ITensor::slice(
-                    concurrenceResource.mBuffers[bufferId.value()], preBufferEleSize, requestedNumberOfElements[i]);
-                preBufferEleSize += requestedNumberOfElements[i];
-                bufferCoverTargetNum++;
-                retSplitCaches.push_back(std::move(slice));
-            }
-            else
-            {
-                retSplitCaches.push_back(bufferManagerToUse.gpu(
-                    runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
-            }
-        }
-        TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum);
-        if (bufferCoverTargetNum < static_cast<size_t>(targetNum))
-        {
-            TLLM_LOG_WARNING(
-                "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic "
-                "buffer which will fail with NIXL backend. It is recommended to set "
-                "cacheTransceiverConfig.MaxTokensInBuffer (cache_transceiver_config.max_tokens_in_buffer in config "
-                "YAML file) to a value greater than the maximum ISL of the processed requests. Otherwise, performance "
-                "may be degraded or transfer may fail. requestedNumberOfElements.size():%ld, "
-                "mNumberOfElements:%ld, requestedNumberOfElements[0]:%ld",
-                bufferCoverTargetNum, targetNum, requestedNumberOfElements.size(), mNumberOfElements,
-                requestedNumberOfElements[0]);
-        }
-    }
-    else
-    {
-        for (int i = 0; i < targetNum; i++)
-        {
-            retSplitCaches.push_back(bufferManagerToUse.gpu(
-                runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
-        }
-        bufferCoverTargetNum = targetNum;
-    }
-
-    return std::make_tuple(retSplitCaches, bufferCoverTargetNum, mOnlyUseDynamicBuffer);
-}
-
-void CacheTransBufferManager::allocateBuffer()
-{
-    if (mOnlyUseDynamicBuffer)
-    {
-        return;
-    }
-    mNumberOfElements = mTransferBufferSize / common::getDTypeSize(mDataType);
-    mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0);
-    mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0);
-    if (mUseFabricMemory)
-    {
-        mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount);
-        for (size_t i = 0; i < mSendBufferCount; i++)
-        {
-            mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
-            mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
-                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mNumberOfElements);
-        }
-        for (size_t i = 0; i < mRecvBufferCount; i++)
-        {
-            mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
-            mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
-                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mNumberOfElements);
-        }
-    }
-    else if (common::getEnvKVCacheTransferUseAsyncBuffer())
-    {
-        for (size_t i = 0; i < mSendBufferCount; i++)
-        {
-            mConcurrenceSendResource.mBuffers[i]
-                = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
-        }
-        for (size_t i = 0; i < mRecvBufferCount; i++)
-        {
-            mConcurrenceRecvResource.mBuffers[i]
-                = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
-        }
-        mBufferManager.getStream().synchronize();
-    }
-    else
-    {
-        for (size_t i = 0; i < mSendBufferCount; i++)
-        {
-            mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync(
-                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
-        }
-        for (size_t i = 0; i < mRecvBufferCount; i++)
-        {
-            mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync(
-                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
-        }
-    }
-}
-
-std::optional<int> CacheTransBufferManager::assignBufferIndex(
-    ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer)
-{
-    if (onlyUseDynamicBuffer)
-    {
-        return std::nullopt;
-    }
-    std::unique_lock lk(resource.mBuffersMutex);
-    resource.mBuffersCV.wait(
-        lk, [&resource, bufferCount]() { return static_cast<size_t>(resource.mConcurrence) < bufferCount; });
-    int bufferId = -1;
-    for (size_t i = 0; i < bufferCount; i++)
-    {
-        if (resource.mBufferIndexFlag[i] == 0)
-        {
-            bufferId = i;
-            resource.mBufferIndexFlag[bufferId] = 1;
-            resource.mConcurrence++;
-            break;
-        }
-    }
-    TLLM_CHECK_WITH_INFO(bufferId >= 0 && static_cast<size_t>(bufferId) < bufferCount,
-        " assignBufferIndex: Buffer index already assigned");
-
-    return bufferId;
-}
-
-void CacheTransBufferManager::freeBufferIndex(
-    ConcurrenceResource& resource, std::optional<int> bufferId, size_t bufferCount, bool onlyUseDynamicBuffer)
-{
-    if (onlyUseDynamicBuffer)
-    {
-        return;
-    }
-    if (bufferId.has_value())
-    {
-        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < bufferCount);
-        {
-            std::scoped_lock lk(resource.mBuffersMutex);
-            resource.mBufferIndexFlag[bufferId.value()] = 0;
-        }
-        resource.mConcurrence--;
-        resource.mBuffersCV.notify_one();
-    }
-}
-
-size_t CacheTransBufferManager::getRecvBufferCount()
-{
-    return mRecvBufferCount;
-}
-
-size_t CacheTransBufferManager::getSendBufferCount()
-{
-    return mSendBufferCount;
-}
-
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h

@@ -17,13 +17,16 @@
 
 #pragma once
 
+#include "tensorrt_llm/batch_manager/baseTransBuffer.h"
 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
 #include "tensorrt_llm/executor/executor.h"
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 
 #include <atomic>
 #include <condition_variable>
 #include <cstddef>
 #include <map>
 #include <optional>
 #include <unordered_map>
 #include <vector>
@@ -54,7 +57,9 @@ private:
     std::unique_ptr<Impl> pImpl;
 };
 
-class CacheTransBufferManager
+/// @brief KV Cache specific transfer buffer manager.
+/// Inherits common buffer management from BaseTransBufferManager.
+class CacheTransBufferManager : public BaseTransBufferManager
 {
 public:
     CacheTransBufferManager(KVCacheManager::BaseKVCacheManager* cacheManager,
@@ -64,62 +69,18 @@ public:
         SizeType32 tokensPerBlock,
         std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig = std::nullopt);
 
-    std::optional<int> assignBufferIndexForSend();
-    void freeBufferIndexForSend(std::optional<int> bufferId);
-    std::optional<int> assignBufferIndexForRecv();
-    void freeBufferIndexForRecv(std::optional<int> bufferId);
-
-    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateSendBuffers(
-        std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-        runtime::BufferManager const& bufferManagerToUse);
-
-    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateRecvBuffers(
-        std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-        runtime::BufferManager const& bufferManagerToUse);
-
-    runtime::ITensor::SharedPtr getSendBuffer(std::optional<int> bufferId);
-    runtime::ITensor::SharedPtr getRecvBuffer(std::optional<int> bufferId);
-    size_t getRecvBufferCount();
-    size_t getSendBufferCount();
-
-    std::optional<size_t> getMaxNumTokens()
+    /// @brief Get the KV cache manager.
+    [[nodiscard]] KVCacheManager::BaseKVCacheManager* getCacheManager() const noexcept
     {
-        return mMaxNumTokens;
+        return mCacheManager;
     }
 
 private:
-    struct ConcurrenceResource
-    {
-        std::unordered_map<int, runtime::ITensor::SharedPtr> mBuffers;
-        std::vector<int> mBufferIndexFlag;
-        std::mutex mBuffersMutex;
-        std::condition_variable mBuffersCV;
-        std::atomic<int> mConcurrence = 0;
-    };
+    /// @brief Compute transfer buffer size from KV cache configuration.
+    static size_t computeTransferBufferSize(KVCacheManager::BaseKVCacheManager* cacheManager,
+        std::optional<size_t> maxNumTokens, bool transferIndexerKCache);
 
-    std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> getOrAllocateBuffers(std::optional<int> bufferId,
-        int targetNum, std::vector<size_t> const& requestedNumberOfElements,
-        runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource);
-
-    void allocateBuffer();
-    std::optional<int> assignBufferIndex(ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer);
-    void freeBufferIndex(
-        ConcurrenceResource& resource, std::optional<int> bufferId, size_t bufferCount, bool onlyUseDynamicBuffer);
-
-    size_t mPreAllocBufferSize;
-    size_t mRecvBufferCount;
-    size_t mSendBufferCount;
-    size_t mTransferBufferSize;
-    bool mOnlyUseDynamicBuffer;
-    bool mUseFabricMemory;
-    size_t mNumberOfElements;
-    nvinfer1::DataType mDataType;
-    ConcurrenceResource mConcurrenceSendResource;
-    ConcurrenceResource mConcurrenceRecvResource;
     KVCacheManager::BaseKVCacheManager* mCacheManager;
-    runtime::BufferManager mBufferManager;
-    std::vector<std::unique_ptr<FabricMemory>> mFabricMemory;
-    std::optional<size_t> mMaxNumTokens;
 };
 
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager