diff --git a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt index f62c2aaf7f..95b4f1c8ac 100644 --- a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt +++ b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt @@ -21,6 +21,7 @@ set(TARGET_DIR ${CMAKE_CURRENT_SOURCE_DIR}) set(SRCS allocateKvCache.cpp assignReqSeqSlots.cpp + baseTransBuffer.cpp cacheFormatter.cpp mlaCacheFormatter.cpp cacheTransceiver.cpp diff --git a/cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp b/cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp new file mode 100644 index 0000000000..e091295788 --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/baseTransBuffer.cpp @@ -0,0 +1,285 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "baseTransBuffer.h" +#include "cacheTransBuffer.h" +#include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/opUtils.h" + +#include + +namespace tensorrt_llm::batch_manager +{ + +BaseTransBufferManager::BaseTransBufferManager( + size_t transferBufferSize, nvinfer1::DataType dataType, std::optional maxNumTokens) + : mDataType{dataType} + , mBufferManager{std::make_shared()} + , mMaxNumTokens{maxNumTokens} +{ + mTransferBufferSize = transferBufferSize; + mOnlyUseDynamicBuffer = mTransferBufferSize == 0; + mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1; + mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum(); + mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer()) + && kv_cache_manager::FabricMemory::supportFbaricMemory(); + if (mUseFabricMemory) + { + mTransferBufferSize = kv_cache_manager::FabricMemory::getAlignedSize(mTransferBufferSize); + } + mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount); + + TLLM_LOG_INFO( + "BaseTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, " + "mSendBufferCount:%ld, mTransferBufferSize:%ld, mPreAllocBufferSize:%ld, mOnlyUseDynamicBuffer:%d, " + "mUseFabricMemory:%d, mDataType:%d", + maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize, + mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, static_cast(mDataType)); + + allocateBuffer(); +} + +std::optional BaseTransBufferManager::assignBufferIndexForSend() +{ + return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer); +} + +void BaseTransBufferManager::freeBufferIndexForSend(std::optional bufferId) +{ + freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer); +} + +std::optional BaseTransBufferManager::assignBufferIndexForRecv() +{ + return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer); +} + +void BaseTransBufferManager::freeBufferIndexForRecv(std::optional bufferId) +{ + freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer); +} + +std::tuple, size_t, bool> BaseTransBufferManager::getOrAllocateSendBuffers( + std::optional bufferId, int targetNum, std::vector const& requestedNumberOfElements, + runtime::BufferManager const& bufferManagerToUse) +{ + return getOrAllocateBuffers( + bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceSendResource); +} + +std::tuple, size_t, bool> BaseTransBufferManager::getOrAllocateRecvBuffers( + std::optional bufferId, int targetNum, std::vector const& requestedNumberOfElements, + runtime::BufferManager const& bufferManagerToUse) +{ + return getOrAllocateBuffers( + bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceRecvResource); +} + +runtime::ITensor::SharedPtr BaseTransBufferManager::getSendBuffer(std::optional bufferId) +{ + TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer); + if (bufferId.has_value()) + { + TLLM_CHECK(static_cast(bufferId.value()) < mSendBufferCount); + return mConcurrenceSendResource.mBuffers[bufferId.value()]; + } + return nullptr; +} + +runtime::ITensor::SharedPtr BaseTransBufferManager::getRecvBuffer(std::optional bufferId) +{ + TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer); + if (bufferId.has_value()) + { + TLLM_CHECK(static_cast(bufferId.value()) < mRecvBufferCount); + // TLLM_CHECK(mConcurrenceRecvResource.mBufferIndexFlag[bufferId.value()] == 1); + return mConcurrenceRecvResource.mBuffers[bufferId.value()]; + } + return nullptr; +} + +std::tuple, size_t, bool> BaseTransBufferManager::getOrAllocateBuffers( + std::optional bufferId, int targetNum, std::vector const& requestedNumberOfElements, + runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource) +{ + TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer); + TLLM_CHECK(requestedNumberOfElements.size() >= static_cast(targetNum)); + std::vector retSplitCaches; + + size_t bufferCoverTargetNum = 0; + + if (bufferId.has_value()) + { + TLLM_CHECK(static_cast(bufferId.value()) < concurrenceResource.mBuffers.size()); + TLLM_CHECK(concurrenceResource.mBufferIndexFlag[bufferId.value()] == 1); + size_t preBufferEleSize = 0; + for (int i = 0; i < targetNum; i++) + { + // Strict checking. + if (preBufferEleSize + requestedNumberOfElements[i] <= mNumberOfElements) + { + auto slice = runtime::ITensor::slice( + concurrenceResource.mBuffers[bufferId.value()], preBufferEleSize, requestedNumberOfElements[i]); + preBufferEleSize += requestedNumberOfElements[i]; + bufferCoverTargetNum++; + retSplitCaches.push_back(std::move(slice)); + } + else + { + retSplitCaches.push_back(bufferManagerToUse.gpu( + runtime::ITensor::makeShape({static_cast(requestedNumberOfElements[i])}), mDataType)); + } + } + TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum); + if (bufferCoverTargetNum < static_cast(targetNum)) + { + TLLM_LOG_WARNING( + "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic " + "buffer which will fail with NIXL backend. It is recommended to set " + "cacheTransceiverConfig.MaxTokensInBuffer (cache_transceiver_config.max_tokens_in_buffer in config " + "YAML file) to a value greater than the maximum ISL of the processed requests. Otherwise, performance " + "may be degraded or transfer may fail. requestedNumberOfElements.size():%ld, " + "mNumberOfElements:%ld, requestedNumberOfElements[0]:%ld", + bufferCoverTargetNum, targetNum, requestedNumberOfElements.size(), mNumberOfElements, + requestedNumberOfElements[0]); + } + } + else + { + for (int i = 0; i < targetNum; i++) + { + retSplitCaches.push_back(bufferManagerToUse.gpu( + runtime::ITensor::makeShape({static_cast(requestedNumberOfElements[i])}), mDataType)); + } + bufferCoverTargetNum = targetNum; + } + + return std::make_tuple(retSplitCaches, bufferCoverTargetNum, mOnlyUseDynamicBuffer); +} + +void BaseTransBufferManager::allocateBuffer() +{ + if (mOnlyUseDynamicBuffer) + { + return; + } + mNumberOfElements = mTransferBufferSize / common::getDTypeSize(mDataType); + mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0); + mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0); + if (mUseFabricMemory) + { + mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount); + for (size_t i = 0; i < mSendBufferCount; i++) + { + mFabricMemory.emplace_back(std::make_unique(mTransferBufferSize)); + mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType, + runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mNumberOfElements); + } + for (size_t i = 0; i < mRecvBufferCount; i++) + { + mFabricMemory.emplace_back(std::make_unique(mTransferBufferSize)); + mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType, + runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mNumberOfElements); + } + } + else if (common::getEnvKVCacheTransferUseAsyncBuffer()) + { + for (size_t i = 0; i < mSendBufferCount; i++) + { + mConcurrenceSendResource.mBuffers[i] + = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mDataType); + } + for (size_t i = 0; i < mRecvBufferCount; i++) + { + mConcurrenceRecvResource.mBuffers[i] + = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mDataType); + } + mBufferManager.getStream().synchronize(); + } + else + { + for (size_t i = 0; i < mSendBufferCount; i++) + { + mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync( + runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mDataType); + } + for (size_t i = 0; i < mRecvBufferCount; i++) + { + mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync( + runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mDataType); + } + } +} + +std::optional BaseTransBufferManager::assignBufferIndex( + ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer) +{ + if (onlyUseDynamicBuffer) + { + return std::nullopt; + } + std::unique_lock lk(resource.mBuffersMutex); + resource.mBuffersCV.wait( + lk, [&resource, bufferCount]() { return static_cast(resource.mConcurrence) < bufferCount; }); + int bufferId = -1; + for (size_t i = 0; i < bufferCount; i++) + { + if (resource.mBufferIndexFlag[i] == 0) + { + bufferId = i; + resource.mBufferIndexFlag[bufferId] = 1; + resource.mConcurrence++; + break; + } + } + TLLM_CHECK_WITH_INFO(bufferId >= 0 && static_cast(bufferId) < bufferCount, + " assignBufferIndex: Buffer index already assigned"); + + return bufferId; +} + +void BaseTransBufferManager::freeBufferIndex( + ConcurrenceResource& resource, std::optional bufferId, size_t bufferCount, bool onlyUseDynamicBuffer) +{ + if (onlyUseDynamicBuffer) + { + return; + } + if (bufferId.has_value()) + { + TLLM_CHECK(static_cast(bufferId.value()) < bufferCount); + { + std::scoped_lock lk(resource.mBuffersMutex); + resource.mBufferIndexFlag[bufferId.value()] = 0; + } + resource.mConcurrence--; + resource.mBuffersCV.notify_one(); + } +} + +size_t BaseTransBufferManager::getRecvBufferCount() +{ + return mRecvBufferCount; +} + +size_t BaseTransBufferManager::getSendBufferCount() +{ + return mSendBufferCount; +} + +} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/baseTransBuffer.h b/cpp/tensorrt_llm/batch_manager/baseTransBuffer.h new file mode 100644 index 0000000000..ec311e5c40 --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/baseTransBuffer.h @@ -0,0 +1,144 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/iTensor.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::batch_manager::kv_cache_manager +{ +class FabricMemory; +} // namespace tensorrt_llm::batch_manager::kv_cache_manager + +namespace tensorrt_llm::batch_manager +{ + +/// @brief Base class for cache transfer buffer management. +/// Handles buffer pool allocation, index assignment, and slicing. +/// Derived classes provide cache-specific size calculations. +class BaseTransBufferManager +{ +public: + virtual ~BaseTransBufferManager() = default; + + /// @brief Assign a buffer index for sending. + /// @return Assigned buffer index, or nullopt if using dynamic buffers. + std::optional assignBufferIndexForSend(); + + /// @brief Free a buffer index used for sending. + /// @param bufferId The buffer index to free. + void freeBufferIndexForSend(std::optional bufferId); + + /// @brief Assign a buffer index for receiving. + /// @return Assigned buffer index, or nullopt if using dynamic buffers. + std::optional assignBufferIndexForRecv(); + + /// @brief Free a buffer index used for receiving. + /// @param bufferId The buffer index to free. + void freeBufferIndexForRecv(std::optional bufferId); + + /// @brief Get or allocate send buffers for cache transfer. + /// @param bufferId The assigned buffer ID. + /// @param targetNum Number of target sequences. + /// @param requestedNumberOfElements Sizes requested for each target. + /// @param bufferManagerToUse Buffer manager for dynamic allocation. + /// @return Tuple of (buffers, covered target count, is dynamic only). + std::tuple, size_t, bool> getOrAllocateSendBuffers( + std::optional bufferId, int targetNum, std::vector const& requestedNumberOfElements, + runtime::BufferManager const& bufferManagerToUse); + + /// @brief Get or allocate receive buffers for cache transfer. + /// @param bufferId The assigned buffer ID. + /// @param targetNum Number of target sequences. + /// @param requestedNumberOfElements Sizes requested for each target. + /// @param bufferManagerToUse Buffer manager for dynamic allocation. + /// @return Tuple of (buffers, covered target count, is dynamic only). + std::tuple, size_t, bool> getOrAllocateRecvBuffers( + std::optional bufferId, int targetNum, std::vector const& requestedNumberOfElements, + runtime::BufferManager const& bufferManagerToUse); + + /// @brief Get the send buffer for a given buffer ID. + runtime::ITensor::SharedPtr getSendBuffer(std::optional bufferId); + + /// @brief Get the receive buffer for a given buffer ID. + runtime::ITensor::SharedPtr getRecvBuffer(std::optional bufferId); + + /// @brief Get the number of receive buffers. + size_t getRecvBufferCount(); + + /// @brief Get the number of send buffers. + size_t getSendBufferCount(); + + /// @brief Get the maximum number of tokens configured. + std::optional getMaxNumTokens() + { + return mMaxNumTokens; + } + +protected: + /// @brief Constructor - derived classes call this after computing buffer sizes. + /// @param transferBufferSize Size of each transfer buffer in bytes. + /// @param dataType Data type for the buffers. + /// @param maxNumTokens Optional max tokens for sizing. + BaseTransBufferManager( + size_t transferBufferSize, nvinfer1::DataType dataType, std::optional maxNumTokens = std::nullopt); + + struct ConcurrenceResource + { + std::unordered_map mBuffers; + std::vector mBufferIndexFlag; + std::mutex mBuffersMutex; + std::condition_variable mBuffersCV; + std::atomic mConcurrence{0}; + }; + + std::tuple, size_t, bool> getOrAllocateBuffers(std::optional bufferId, + int targetNum, std::vector const& requestedNumberOfElements, + runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource); + + void allocateBuffer(); + std::optional assignBufferIndex(ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer); + void freeBufferIndex( + ConcurrenceResource& resource, std::optional bufferId, size_t bufferCount, bool onlyUseDynamicBuffer); + + size_t mPreAllocBufferSize; + size_t mRecvBufferCount; + size_t mSendBufferCount; + size_t mTransferBufferSize; + bool mOnlyUseDynamicBuffer; + bool mUseFabricMemory; + size_t mNumberOfElements; + nvinfer1::DataType mDataType; + ConcurrenceResource mConcurrenceSendResource; + ConcurrenceResource mConcurrenceRecvResource; + runtime::BufferManager mBufferManager; + std::vector> mFabricMemory; + std::optional mMaxNumTokens; +}; + +} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp index 347325d4a1..fca4419f22 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp @@ -20,12 +20,17 @@ #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/opUtils.h" #include "tensorrt_llm/executor/executor.h" + #include #include namespace tensorrt_llm::batch_manager::kv_cache_manager { +// ============================================================================ +// FabricMemory Implementation +// ============================================================================ + class FabricMemory::Impl { public: @@ -182,45 +187,46 @@ bool FabricMemory::supportFbaricMemory() #endif } -CacheTransBufferManager::CacheTransBufferManager( +// ============================================================================ +// CacheTransBufferManager Implementation +// ============================================================================ + +size_t CacheTransBufferManager::computeTransferBufferSize( KVCacheManager::BaseKVCacheManager* cacheManager, std::optional maxNumTokens, bool transferIndexerKCache) - : mCacheManager{cacheManager} - , mBufferManager{std::make_shared()} - , mMaxNumTokens{maxNumTokens} { - // TODO: FP4 dataSize - TLLM_CHECK(mCacheManager); + nvinfer1::DataType dataType; if (transferIndexerKCache) { - mDataType = mCacheManager->getIndexerKCachePool()->getDataType(); + dataType = cacheManager->getIndexerKCachePool()->getDataType(); } else { - mDataType = mCacheManager->getPrimaryPool(0)->getDataType(); + dataType = cacheManager->getPrimaryPool(0)->getDataType(); } - auto tokensPerBlock = mCacheManager->getBlockManager().getTokensPerBlock(); + auto tokensPerBlock = cacheManager->getBlockManager().getTokensPerBlock(); size_t bufferSizeFromMaxNumToken = 0; + if (maxNumTokens.has_value()) { TLLM_CHECK(maxNumTokens.value() % tokensPerBlock == 0); - auto dataSize = common::getDTypeSize(mDataType); + auto dataSize = common::getDTypeSize(dataType); SizeType32 kvCacheByteSizePerTokenPerLayer = 0; if (transferIndexerKCache) { kvCacheByteSizePerTokenPerLayer - = mCacheManager->getIndexerKCachePool()->getDimension<-1>() * dataSize / tokensPerBlock; + = cacheManager->getIndexerKCachePool()->getDimension<-1>() * dataSize / tokensPerBlock; } else { - auto primaryPool = mCacheManager->getPrimaryPool(0); + auto primaryPool = cacheManager->getPrimaryPool(0); kvCacheByteSizePerTokenPerLayer = primaryPool->getDimension<-1>() * primaryPool->getDimension<2>() * dataSize / tokensPerBlock; } - for (auto layerId = 0; layerId < mCacheManager->getBlockManager().getNumLayers(); layerId++) + for (auto layerId = 0; layerId < cacheManager->getBlockManager().getNumLayers(); layerId++) { - auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId); - auto windowSize = static_cast(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx)); + auto poolIdx = cacheManager->getBlockManager().getLayerPoolIdx(layerId); + auto windowSize = static_cast(cacheManager->getBlockManager().getPoolWindowSize(poolIdx)); auto alignedWindowSize = (windowSize + tokensPerBlock - 1) / tokensPerBlock * tokensPerBlock; auto validTokenNum = (alignedWindowSize < maxNumTokens.value() ? alignedWindowSize : maxNumTokens.value()); if (common::getEnvKVCacheTransferAllBlocksForWindow()) @@ -233,26 +239,20 @@ CacheTransBufferManager::CacheTransBufferManager( } } - mTransferBufferSize - = maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer(); - mOnlyUseDynamicBuffer = mTransferBufferSize == 0; - mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1; - mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum(); - mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer()) - && FabricMemory::supportFbaricMemory(); - if (mUseFabricMemory) - { - mTransferBufferSize = FabricMemory::getAlignedSize(mTransferBufferSize); - } - mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount); - TLLM_LOG_INFO( - "CacheTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, " - "mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld,mOnlyUseDynamicBuffer:%d " - "mUseFabricMemory:%d mDataType:%d", - maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize, - mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, mDataType); + return maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer(); +} - allocateBuffer(); +CacheTransBufferManager::CacheTransBufferManager( + KVCacheManager::BaseKVCacheManager* cacheManager, std::optional maxNumTokens, bool transferIndexerKCache) + : BaseTransBufferManager(computeTransferBufferSize(cacheManager, maxNumTokens, transferIndexerKCache), + transferIndexerKCache ? cacheManager->getIndexerKCachePool()->getDataType() + : cacheManager->getPrimaryPool(0)->getDataType(), + maxNumTokens) + , mCacheManager{cacheManager} +{ + // TODO: FP4 dataSize + TLLM_CHECK(mCacheManager); + TLLM_LOG_INFO("CacheTransBufferManager created for KV cache"); } size_t CacheTransBufferManager::preAllocBufferSize( @@ -298,233 +298,4 @@ size_t CacheTransBufferManager::preAllocBufferSize( return preAllocBufferSize; } -std::optional CacheTransBufferManager::assignBufferIndexForSend() -{ - return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer); -} - -void CacheTransBufferManager::freeBufferIndexForSend(std::optional bufferId) -{ - freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer); -} - -std::optional CacheTransBufferManager::assignBufferIndexForRecv() -{ - return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer); -} - -void CacheTransBufferManager::freeBufferIndexForRecv(std::optional bufferId) -{ - freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer); -} - -std::tuple, size_t, bool> CacheTransBufferManager::getOrAllocateSendBuffers( - std::optional bufferId, int targetNum, std::vector const& requestedNumberOfElements, - runtime::BufferManager const& bufferManagerToUse) -{ - return getOrAllocateBuffers( - bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceSendResource); -} - -std::tuple, size_t, bool> CacheTransBufferManager::getOrAllocateRecvBuffers( - std::optional bufferId, int targetNum, std::vector const& requestedNumberOfElements, - runtime::BufferManager const& bufferManagerToUse) -{ - return getOrAllocateBuffers( - bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceRecvResource); -} - -runtime::ITensor::SharedPtr CacheTransBufferManager::getSendBuffer(std::optional bufferId) -{ - TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer); - if (bufferId.has_value()) - { - TLLM_CHECK(static_cast(bufferId.value()) < mSendBufferCount); - return mConcurrenceSendResource.mBuffers[bufferId.value()]; - } - return nullptr; -} - -runtime::ITensor::SharedPtr CacheTransBufferManager::getRecvBuffer(std::optional bufferId) -{ - TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer); - if (bufferId.has_value()) - { - TLLM_CHECK(static_cast(bufferId.value()) < mRecvBufferCount); - // TLLM_CHECK(mConcurrenceRecvResource.mBufferIndexFlag[bufferId.value()] == 1); - return mConcurrenceRecvResource.mBuffers[bufferId.value()]; - } - return nullptr; -} - -std::tuple, size_t, bool> CacheTransBufferManager::getOrAllocateBuffers( - std::optional bufferId, int targetNum, std::vector const& requestedNumberOfElements, - runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource) -{ - TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer); - TLLM_CHECK(requestedNumberOfElements.size() >= static_cast(targetNum)); - std::vector retSplitCaches; - - size_t bufferCoverTargetNum = 0; - - if (bufferId.has_value()) - { - TLLM_CHECK(static_cast(bufferId.value()) < concurrenceResource.mBuffers.size()); - TLLM_CHECK(concurrenceResource.mBufferIndexFlag[bufferId.value()] == 1); - size_t preBufferEleSize = 0; - for (int i = 0; i < targetNum; i++) - { - // Strict checking. - if (preBufferEleSize + requestedNumberOfElements[i] <= mNumberOfElements) - { - auto slice = runtime::ITensor::slice( - concurrenceResource.mBuffers[bufferId.value()], preBufferEleSize, requestedNumberOfElements[i]); - preBufferEleSize += requestedNumberOfElements[i]; - bufferCoverTargetNum++; - retSplitCaches.push_back(std::move(slice)); - } - else - { - retSplitCaches.push_back(bufferManagerToUse.gpu( - runtime::ITensor::makeShape({static_cast(requestedNumberOfElements[i])}), mDataType)); - } - } - TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum); - if (bufferCoverTargetNum < static_cast(targetNum)) - { - TLLM_LOG_WARNING( - "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic " - "buffer which will fail with NIXL backend. It is recommended to set " - "cacheTransceiverConfig.MaxTokensInBuffer (cache_transceiver_config.max_tokens_in_buffer in config " - "YAML file) to a value greater than the maximum ISL of the processed requests. Otherwise, performance " - "may be degraded or transfer may fail. requestedNumberOfElements.size():%ld, " - "mNumberOfElements:%ld, requestedNumberOfElements[0]:%ld", - bufferCoverTargetNum, targetNum, requestedNumberOfElements.size(), mNumberOfElements, - requestedNumberOfElements[0]); - } - } - else - { - for (int i = 0; i < targetNum; i++) - { - retSplitCaches.push_back(bufferManagerToUse.gpu( - runtime::ITensor::makeShape({static_cast(requestedNumberOfElements[i])}), mDataType)); - } - bufferCoverTargetNum = targetNum; - } - - return std::make_tuple(retSplitCaches, bufferCoverTargetNum, mOnlyUseDynamicBuffer); -} - -void CacheTransBufferManager::allocateBuffer() -{ - if (mOnlyUseDynamicBuffer) - { - return; - } - mNumberOfElements = mTransferBufferSize / common::getDTypeSize(mDataType); - mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0); - mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0); - if (mUseFabricMemory) - { - mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount); - for (size_t i = 0; i < mSendBufferCount; i++) - { - mFabricMemory.emplace_back(std::make_unique(mTransferBufferSize)); - mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType, - runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mNumberOfElements); - } - for (size_t i = 0; i < mRecvBufferCount; i++) - { - mFabricMemory.emplace_back(std::make_unique(mTransferBufferSize)); - mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType, - runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mNumberOfElements); - } - } - else if (common::getEnvKVCacheTransferUseAsyncBuffer()) - { - for (size_t i = 0; i < mSendBufferCount; i++) - { - mConcurrenceSendResource.mBuffers[i] - = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mDataType); - } - for (size_t i = 0; i < mRecvBufferCount; i++) - { - mConcurrenceRecvResource.mBuffers[i] - = mBufferManager.gpu(runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mDataType); - } - mBufferManager.getStream().synchronize(); - } - else - { - for (size_t i = 0; i < mSendBufferCount; i++) - { - mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync( - runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mDataType); - } - for (size_t i = 0; i < mRecvBufferCount; i++) - { - mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync( - runtime::ITensor::makeShape({static_cast(mNumberOfElements)}), mDataType); - } - } -} - -std::optional CacheTransBufferManager::assignBufferIndex( - ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer) -{ - if (onlyUseDynamicBuffer) - { - return std::nullopt; - } - std::unique_lock lk(resource.mBuffersMutex); - resource.mBuffersCV.wait( - lk, [&resource, bufferCount]() { return static_cast(resource.mConcurrence) < bufferCount; }); - int bufferId = -1; - for (size_t i = 0; i < bufferCount; i++) - { - if (resource.mBufferIndexFlag[i] == 0) - { - bufferId = i; - resource.mBufferIndexFlag[bufferId] = 1; - resource.mConcurrence++; - break; - } - } - TLLM_CHECK_WITH_INFO(bufferId >= 0 && static_cast(bufferId) < bufferCount, - " assignBufferIndex: Buffer index already assigned"); - - return bufferId; -} - -void CacheTransBufferManager::freeBufferIndex( - ConcurrenceResource& resource, std::optional bufferId, size_t bufferCount, bool onlyUseDynamicBuffer) -{ - if (onlyUseDynamicBuffer) - { - return; - } - if (bufferId.has_value()) - { - - TLLM_CHECK(static_cast(bufferId.value()) < bufferCount); - { - std::scoped_lock lk(resource.mBuffersMutex); - resource.mBufferIndexFlag[bufferId.value()] = 0; - } - resource.mConcurrence--; - resource.mBuffersCV.notify_one(); - } -} - -size_t CacheTransBufferManager::getRecvBufferCount() -{ - return mRecvBufferCount; -} - -size_t CacheTransBufferManager::getSendBufferCount() -{ - return mSendBufferCount; -} - } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h index 9e92914ec0..96c1314944 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h +++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h @@ -17,13 +17,16 @@ #pragma once +#include "tensorrt_llm/batch_manager/baseTransBuffer.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/iTensor.h" + #include #include #include +#include #include #include #include @@ -54,7 +57,9 @@ private: std::unique_ptr pImpl; }; -class CacheTransBufferManager +/// @brief KV Cache specific transfer buffer manager. +/// Inherits common buffer management from BaseTransBufferManager. +class CacheTransBufferManager : public BaseTransBufferManager { public: CacheTransBufferManager(KVCacheManager::BaseKVCacheManager* cacheManager, @@ -64,62 +69,18 @@ public: SizeType32 tokensPerBlock, std::optional const& cacheTransceiverConfig = std::nullopt); - std::optional assignBufferIndexForSend(); - void freeBufferIndexForSend(std::optional bufferId); - std::optional assignBufferIndexForRecv(); - void freeBufferIndexForRecv(std::optional bufferId); - - std::tuple, size_t, bool> getOrAllocateSendBuffers( - std::optional bufferId, int targetNum, std::vector const& requestedNumberOfElements, - runtime::BufferManager const& bufferManagerToUse); - - std::tuple, size_t, bool> getOrAllocateRecvBuffers( - std::optional bufferId, int targetNum, std::vector const& requestedNumberOfElements, - runtime::BufferManager const& bufferManagerToUse); - - runtime::ITensor::SharedPtr getSendBuffer(std::optional bufferId); - runtime::ITensor::SharedPtr getRecvBuffer(std::optional bufferId); - size_t getRecvBufferCount(); - size_t getSendBufferCount(); - - std::optional getMaxNumTokens() + /// @brief Get the KV cache manager. + [[nodiscard]] KVCacheManager::BaseKVCacheManager* getCacheManager() const noexcept { - return mMaxNumTokens; + return mCacheManager; } private: - struct ConcurrenceResource - { - std::unordered_map mBuffers; - std::vector mBufferIndexFlag; - std::mutex mBuffersMutex; - std::condition_variable mBuffersCV; - std::atomic mConcurrence = 0; - }; + /// @brief Compute transfer buffer size from KV cache configuration. + static size_t computeTransferBufferSize(KVCacheManager::BaseKVCacheManager* cacheManager, + std::optional maxNumTokens, bool transferIndexerKCache); - std::tuple, size_t, bool> getOrAllocateBuffers(std::optional bufferId, - int targetNum, std::vector const& requestedNumberOfElements, - runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource); - - void allocateBuffer(); - std::optional assignBufferIndex(ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer); - void freeBufferIndex( - ConcurrenceResource& resource, std::optional bufferId, size_t bufferCount, bool onlyUseDynamicBuffer); - - size_t mPreAllocBufferSize; - size_t mRecvBufferCount; - size_t mSendBufferCount; - size_t mTransferBufferSize; - bool mOnlyUseDynamicBuffer; - bool mUseFabricMemory; - size_t mNumberOfElements; - nvinfer1::DataType mDataType; - ConcurrenceResource mConcurrenceSendResource; - ConcurrenceResource mConcurrenceRecvResource; KVCacheManager::BaseKVCacheManager* mCacheManager; - runtime::BufferManager mBufferManager; - std::vector> mFabricMemory; - std::optional mMaxNumTokens; }; } // namespace tensorrt_llm::batch_manager::kv_cache_manager