/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "baseTransBuffer.h"
#include "cacheTransBuffer.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/opUtils.h"

#include <mutex>

namespace tensorrt_llm::batch_manager
{

BaseTransBufferManager::BaseTransBufferManager(
    size_t transferBufferSize, nvinfer1::DataType dataType, std::optional<size_t> maxNumTokens)
    : mDataType{dataType}
    , mBufferManager{std::make_shared<runtime::CudaStream>()}
    , mMaxNumTokens{maxNumTokens}
{
    mTransferBufferSize = transferBufferSize;
    mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
    mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
    mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
    mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
        && kv_cache_manager::FabricMemory::supportFbaricMemory();
    if (mUseFabricMemory)
    {
        mTransferBufferSize = kv_cache_manager::FabricMemory::getAlignedSize(mTransferBufferSize);
    }
    mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount);
    TLLM_LOG_INFO(
        "BaseTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
        "mSendBufferCount:%ld, mTransferBufferSize:%ld, mPreAllocBufferSize:%ld, mOnlyUseDynamicBuffer:%d, "
        "mUseFabricMemory:%d, mDataType:%d",
        maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
        mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, static_cast<int>(mDataType));
    allocateBuffer();
}
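// Hypothetical usage sketch (caller-side code, not part of this translation
// unit; `mgr`, `numCaches`, `elemCounts`, and `bufMgr` are made-up names): a
// send path is expected to bracket each transfer with an assign/free pair, so
// that at most mSendBufferCount transfers hold a pre-allocated buffer at once.
//
//   auto bufferId = mgr.assignBufferIndexForSend(); // blocks until a slot frees up
//   auto [buffers, coveredNum, onlyDynamic]
//       = mgr.getOrAllocateSendBuffers(bufferId, numCaches, elemCounts, bufMgr);
//   // ... copy KV-cache blocks into `buffers` and perform the transfer ...
//   mgr.freeBufferIndexForSend(bufferId); // releases the slot and wakes one waiter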
std::optional<int> BaseTransBufferManager::assignBufferIndexForSend()
{
    return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
}

void BaseTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
}

std::optional<int> BaseTransBufferManager::assignBufferIndexForRecv()
{
    return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
}

void BaseTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateSendBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse)
{
    return getOrAllocateBuffers(
        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceSendResource);
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateRecvBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse)
{
    return getOrAllocateBuffers(
        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceRecvResource);
}

runtime::ITensor::SharedPtr BaseTransBufferManager::getSendBuffer(std::optional<int> bufferId)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mSendBufferCount);
        return mConcurrenceSendResource.mBuffers[bufferId.value()];
    }
    return nullptr;
}

runtime::ITensor::SharedPtr BaseTransBufferManager::getRecvBuffer(std::optional<int> bufferId)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mRecvBufferCount);
        // TLLM_CHECK(mConcurrenceRecvResource.mBufferIndexFlag[bufferId.value()] == 1);
        return mConcurrenceRecvResource.mBuffers[bufferId.value()];
    }
    return nullptr;
}
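// getOrAllocateBuffers below serves the first requests out of the pre-allocated
// transfer buffer and falls back to dynamic device allocations once it is full.
// A worked example with made-up numbers: if mNumberOfElements == 1000 and
// requestedNumberOfElements == {400, 400, 400}, the first two requests become
// slices of the shared buffer at element offsets 0 and 400, bufferCoverTargetNum
// ends up as 2, and the third request is allocated dynamically via
// bufferManagerToUse.gpu(...).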
std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> BaseTransBufferManager::getOrAllocateBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    TLLM_CHECK(requestedNumberOfElements.size() >= static_cast<size_t>(targetNum));
    std::vector<runtime::ITensor::SharedPtr> retSplitCaches;
    size_t bufferCoverTargetNum = 0;
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < concurrenceResource.mBuffers.size());
        TLLM_CHECK(concurrenceResource.mBufferIndexFlag[bufferId.value()] == 1);
        size_t preBufferEleSize = 0;
        for (int i = 0; i < targetNum; i++)
        {
            // Strict checking: only slice from the pre-allocated buffer if the request still fits.
            if (preBufferEleSize + requestedNumberOfElements[i] <= mNumberOfElements)
            {
                auto slice = runtime::ITensor::slice(
                    concurrenceResource.mBuffers[bufferId.value()], preBufferEleSize, requestedNumberOfElements[i]);
                preBufferEleSize += requestedNumberOfElements[i];
                bufferCoverTargetNum++;
                retSplitCaches.push_back(std::move(slice));
            }
            else
            {
                retSplitCaches.push_back(bufferManagerToUse.gpu(
                    runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
            }
        }
        TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%zu", bufferCoverTargetNum);
        if (bufferCoverTargetNum < static_cast<size_t>(targetNum))
        {
            TLLM_LOG_WARNING(
                "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%zu < targetNum:%d, may use dynamic "
                "buffers, which will fail with the NIXL backend. It is recommended to set "
                "cacheTransceiverConfig.MaxTokensInBuffer (cache_transceiver_config.max_tokens_in_buffer in the "
                "config YAML file) to a value greater than the maximum ISL of the processed requests. Otherwise, "
                "performance may be degraded or the transfer may fail. requestedNumberOfElements.size():%ld, "
                "mNumberOfElements:%ld, requestedNumberOfElements[0]:%ld",
                bufferCoverTargetNum, targetNum, requestedNumberOfElements.size(), mNumberOfElements,
                requestedNumberOfElements[0]);
        }
    }
    else
    {
        for (int i = 0; i < targetNum; i++)
        {
            retSplitCaches.push_back(bufferManagerToUse.gpu(
                runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
        }
        bufferCoverTargetNum = targetNum;
    }
    return std::make_tuple(retSplitCaches, bufferCoverTargetNum, mOnlyUseDynamicBuffer);
}
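// allocateBuffer below selects one of three pre-allocation strategies from the
// flags computed in the constructor (this comment only summarizes the code that
// follows):
//   - mUseFabricMemory: back each buffer with kv_cache_manager::FabricMemory
//     and wrap the raw pointer as a tensor via runtime::ITensor::wrap.
//   - getEnvKVCacheTransferUseAsyncBuffer(): allocate asynchronously on the
//     manager's stream, then synchronize once after all allocations.
//   - otherwise: allocate synchronously with mBufferManager.gpuSync().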
void BaseTransBufferManager::allocateBuffer()
{
    if (mOnlyUseDynamicBuffer)
    {
        return;
    }
    mNumberOfElements = mTransferBufferSize / common::getDTypeSize(mDataType);
    mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0);
    mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0);
    if (mUseFabricMemory)
    {
        mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount);
        for (size_t i = 0; i < mSendBufferCount; i++)
        {
            mFabricMemory.emplace_back(std::make_unique<kv_cache_manager::FabricMemory>(mTransferBufferSize));
            mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mNumberOfElements);
        }
        for (size_t i = 0; i < mRecvBufferCount; i++)
        {
            mFabricMemory.emplace_back(std::make_unique<kv_cache_manager::FabricMemory>(mTransferBufferSize));
            mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mNumberOfElements);
        }
    }
    else if (common::getEnvKVCacheTransferUseAsyncBuffer())
    {
        for (size_t i = 0; i < mSendBufferCount; i++)
        {
            mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpu(
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
        for (size_t i = 0; i < mRecvBufferCount; i++)
        {
            mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpu(
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
        mBufferManager.getStream().synchronize();
    }
    else
    {
        for (size_t i = 0; i < mSendBufferCount; i++)
        {
            mConcurrenceSendResource.mBuffers[i] = mBufferManager.gpuSync(
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
        for (size_t i = 0; i < mRecvBufferCount; i++)
        {
            mConcurrenceRecvResource.mBuffers[i] = mBufferManager.gpuSync(
                runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)}), mDataType);
        }
    }
}

std::optional<int> BaseTransBufferManager::assignBufferIndex(
    ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer)
{
    if (onlyUseDynamicBuffer)
    {
        return std::nullopt;
    }
    std::unique_lock<std::mutex> lk(resource.mBuffersMutex);
    // Block until fewer than bufferCount transfers are in flight.
    resource.mBuffersCV.wait(
        lk, [&resource, bufferCount]() { return static_cast<size_t>(resource.mConcurrence) < bufferCount; });
    int bufferId = -1;
    for (size_t i = 0; i < bufferCount; i++)
    {
        if (resource.mBufferIndexFlag[i] == 0)
        {
            bufferId = static_cast<int>(i);
            resource.mBufferIndexFlag[bufferId] = 1;
            resource.mConcurrence++;
            break;
        }
    }
    TLLM_CHECK_WITH_INFO(bufferId >= 0 && static_cast<size_t>(bufferId) < bufferCount,
        "assignBufferIndex: no free buffer index could be assigned");
    return bufferId;
}

void BaseTransBufferManager::freeBufferIndex(
    ConcurrenceResource& resource, std::optional<int> bufferId, size_t bufferCount, bool onlyUseDynamicBuffer)
{
    if (onlyUseDynamicBuffer)
    {
        return;
    }
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < bufferCount);
        {
            std::scoped_lock<std::mutex> lk(resource.mBuffersMutex);
            resource.mBufferIndexFlag[bufferId.value()] = 0;
        }
        resource.mConcurrence--;
        resource.mBuffersCV.notify_one();
    }
}

size_t BaseTransBufferManager::getRecvBufferCount()
{
    return mRecvBufferCount;
}

size_t BaseTransBufferManager::getSendBufferCount()
{
    return mSendBufferCount;
}

} // namespace tensorrt_llm::batch_manager