/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "cacheTransBuffer.h"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/executor/executor.h"

#include <NvInferRuntimeBase.h>
#include <algorithm>
#include <cstdint>
#include <mutex>

namespace tensorrt_llm::batch_manager::kv_cache_manager
{

namespace
{

/// Round @p value up to the next multiple of @p multiple (which must be non-zero).
constexpr size_t roundUpToMultiple(size_t value, size_t multiple)
{
    return (value + multiple - 1) / multiple * multiple;
}

/// Allocation properties shared by every fabric-memory allocation on device @p deviceIdx:
/// pinned device memory that requests the fabric handle type and is flagged GPUDirect-RDMA capable.
CUmemAllocationProp makeFabricAllocationProp(int deviceIdx)
{
    CUmemAllocationProp prop = {};
    prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC;
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = deviceIdx;
    prop.allocFlags.gpuDirectRDMACapable = 1;
    return prop;
}

/// Minimum allocation granularity the driver reports for @p prop. Throws (via TLLM_CU_CHECK) on driver error.
size_t getFabricGranularity(CUmemAllocationProp const& prop)
{
    size_t granularity{0};
    TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
    return granularity;
}

/// Whether transfer buffers should be backed by fabric memory: only when neither the sync- nor the async-buffer
/// environment override is set AND the current device supports fabric memory. The environment is checked first so
/// the fabric-support probe is skipped when an override already rules fabric memory out.
bool shouldUseFabricMemory()
{
    return !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
        && FabricMemory::supportFbaricMemory();
}

/// Number of tokens of one attention window that a transfer buffer must hold.
/// This is the window rounded up to whole blocks and capped at @p maxNumTokens, or @p maxNumTokens outright when
/// getEnvKVCacheTransferAllBlocksForWindow() is set, plus one extra block.
/// All arithmetic is done in size_t so very large window sizes cannot overflow a 32-bit int.
size_t getValidTokenNumForWindow(size_t windowSize, size_t tokensPerBlock, size_t maxNumTokens)
{
    size_t validTokenNum = std::min(roundUpToMultiple(windowSize, tokensPerBlock), maxNumTokens);
    if (common::getEnvKVCacheTransferAllBlocksForWindow())
    {
        validTokenNum = maxNumTokens;
    }
    return validTokenNum + tokensPerBlock; // add one more block
}

} // namespace

/// Owns one fabric-handle-capable device allocation.
/// The allocation is created with cuMemCreate, mapped into a reserved VA range and made read/write accessible from
/// the current device. All driver resources are released in the destructor.
class FabricMemory::Impl
{
public:
    /// Allocates at least @p size bytes, rounded up to the driver's minimum allocation granularity.
    /// Throws (via TLLM_CU_CHECK / TLLM_CUDA_CHECK) on failure without leaking partially created resources.
    explicit Impl(size_t size)
        : mSize(size)
    {
        TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceIdx));
        CUmemAllocationProp const prop = makeFabricAllocationProp(mDeviceIdx);
        mGranularity = getFabricGranularity(prop);
        mAllocSize = roundUpToMultiple(size, mGranularity);

        TLLM_CU_CHECK(cuMemCreate(&mHandle, mAllocSize, &prop, 0));
        bool reserved = false;
        bool mapped = false;
        try
        {
            TLLM_CU_CHECK(cuMemAddressReserve(&mDevicePtr, mAllocSize, mGranularity, 0, 0));
            reserved = true;
            TLLM_CU_CHECK(cuMemMap(mDevicePtr, mAllocSize, 0, mHandle, 0));
            mapped = true;

            CUmemAccessDesc accessDesc = {};
            accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
            accessDesc.location.id = mDeviceIdx;
            accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
            TLLM_CU_CHECK(cuMemSetAccess(mDevicePtr, mAllocSize, &accessDesc, 1));
        }
        catch (...)
        {
            // The destructor never runs for a partially constructed object, so undo whatever succeeded by hand.
            // Secondary errors are deliberately ignored so the original exception is the one that propagates.
            if (mapped)
            {
                static_cast<void>(cuMemUnmap(mDevicePtr, mAllocSize));
            }
            if (reserved)
            {
                static_cast<void>(cuMemAddressFree(mDevicePtr, mAllocSize));
            }
            static_cast<void>(cuMemRelease(mHandle));
            throw;
        }
        mPtr = reinterpret_cast<void*>(mDevicePtr);
        TLLM_LOG_DEBUG("FabricMemory::Impl::Impl mAllocSize:%zu", mAllocSize);
    }

    // Owns raw driver handles: copying would lead to a double release.
    Impl(Impl const&) = delete;
    Impl& operator=(Impl const&) = delete;

    ~Impl()
    {
        TLLM_LOG_DEBUG("FabricMemory::Impl::~Impl mAllocSize:%zu", mAllocSize);
        TLLM_CU_CHECK(cuMemUnmap(mDevicePtr, mAllocSize));
        TLLM_CU_CHECK(cuMemRelease(mHandle));
        TLLM_CU_CHECK(cuMemAddressFree(mDevicePtr, mAllocSize));
    }

    /// Device pointer to the start of the mapped allocation.
    [[nodiscard]] void* getPtr() const
    {
        return mPtr;
    }

    /// Size requested by the caller (not the granularity-rounded allocation size).
    [[nodiscard]] size_t getSize() const
    {
        return mSize;
    }

private:
    size_t mSize{0};
    size_t mAllocSize{0};
    size_t mGranularity{0};
    void* mPtr{nullptr};
    CUdeviceptr mDevicePtr{0};
    CUmemGenericAllocationHandle mHandle{0};
    int mDeviceIdx{-1};
};

FabricMemory::FabricMemory(size_t size)
    : pImpl(std::make_unique<Impl>(size))
{
}

FabricMemory::~FabricMemory() = default;

FabricMemory::FabricMemory(FabricMemory&&) noexcept = default;

FabricMemory& FabricMemory::operator=(FabricMemory&&) noexcept = default;

void* FabricMemory::getPtr() const
{
    return pImpl->getPtr();
}

size_t FabricMemory::getSize() const
{
    return pImpl->getSize();
}

/// Round @p size up to the minimum fabric-memory allocation granularity of the current device.
size_t FabricMemory::getAlignedSize(size_t size)
{
    int deviceIdx = -1;
    TLLM_CUDA_CHECK(cudaGetDevice(&deviceIdx));
    return roundUpToMultiple(size, getFabricGranularity(makeFabricAllocationProp(deviceIdx)));
}

/// Whether fabric memory can be used on the current device.
/// Only ever true on aarch64 builds. The device must report both fabric-handle and GPUDirect-RDMA-with-VMM support,
/// and a probe cuMemCreate of one granule must succeed. The result is computed on the first call and cached in a
/// function-local static, so it reflects the device that was current at that first call.
bool FabricMemory::supportFbaricMemory()
{
#ifdef __aarch64__
    auto support_fun = []()
    {
        int fabric_handle_supported{0};
        int gpu_direct_rdma_with_cuda_vmm_supported{0};
        int deviceIdx = 0;
        TLLM_CUDA_CHECK(cudaGetDevice(&deviceIdx));
        CUresult ret0 = cuDeviceGetAttribute(
            &fabric_handle_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, deviceIdx);
        CUresult ret1 = cuDeviceGetAttribute(&gpu_direct_rdma_with_cuda_vmm_supported,
            CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED, deviceIdx);
        TLLM_LOG_DEBUG("FabricMemory::supportFabricMemory fabric_handle_supported:%d", fabric_handle_supported);
        TLLM_LOG_DEBUG("FabricMemory::supportFabricMemory gpu_direct_rdma_with_cuda_vmm_supported:%d",
            gpu_direct_rdma_with_cuda_vmm_supported);
        if (ret0 != CUresult::CUDA_SUCCESS || ret1 != CUresult::CUDA_SUCCESS || fabric_handle_supported == 0
            || gpu_direct_rdma_with_cuda_vmm_supported == 0)
        {
            return false;
        }

        // Attribute support alone is not enough (e.g. an IMEX channel may be missing), so try a real allocation.
        CUmemAllocationProp const prop = makeFabricAllocationProp(deviceIdx);
        size_t const granularity = getFabricGranularity(prop);
        CUmemGenericAllocationHandle handle;
        auto cuRet = cuMemCreate(&handle, granularity, &prop, 0);
        if (cuRet == CUresult::CUDA_SUCCESS)
        {
            TLLM_CU_CHECK(cuMemRelease(handle));
            return true;
        }
        if (cuRet == CUresult::CUDA_ERROR_NOT_PERMITTED)
        {
            TLLM_LOG_WARNING("Try to creat fabric memory failed , setting imex channel may be required");
            return false;
        }
        TLLM_CU_CHECK(cuRet);
        return false;
    };
    static bool support = support_fun();
    return support;
#else
    return false;
#endif
}

/// Computes the per-buffer transfer size and buffer counts, then pre-allocates the send/recv staging buffers.
/// @param cacheManager           KV-cache manager whose pools determine the data type and per-token size; must be
///                               non-null.
/// @param maxNumTokens           When set, the buffer size is derived from it per layer (must be a multiple of the
///                               block size). Otherwise the size comes from getEnvMemSizeForKVCacheTransferBuffer().
/// @param transferIndexerKCache  Size the buffers for the indexer K-cache pool instead of primary pool 0.
CacheTransBufferManager::CacheTransBufferManager(
    KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens, bool transferIndexerKCache)
    : mCacheManager{cacheManager}
    , mBufferManager{std::make_shared<runtime::CudaStream>()}
    , mMaxNumTokens{maxNumTokens}
{
    // TODO: FP4 dataSize
    TLLM_CHECK(mCacheManager);
    if (transferIndexerKCache)
    {
        mDataType = mCacheManager->getIndexerKCachePool()->getDataType();
    }
    else
    {
        mDataType = mCacheManager->getPrimaryPool(0)->getDataType();
    }

    auto tokensPerBlock = mCacheManager->getBlockManager().getTokensPerBlock();
    size_t bufferSizeFromMaxNumToken = 0;
    if (maxNumTokens.has_value())
    {
        TLLM_CHECK(maxNumTokens.value() % tokensPerBlock == 0);
        auto dataSize = common::getDTypeSize(mDataType);
        SizeType32 kvCacheByteSizePerTokenPerLayer = 0;
        if (transferIndexerKCache)
        {
            kvCacheByteSizePerTokenPerLayer
                = mCacheManager->getIndexerKCachePool()->getDimension<-1>() * dataSize / tokensPerBlock;
        }
        else
        {
            auto primaryPool = mCacheManager->getPrimaryPool(0);
            kvCacheByteSizePerTokenPerLayer
                = primaryPool->getDimension<-1>() * primaryPool->getDimension<2>() * dataSize / tokensPerBlock;
        }
        auto const& blockManager = mCacheManager->getBlockManager();
        for (auto layerId = 0; layerId < blockManager.getNumLayers(); layerId++)
        {
            auto poolIdx = blockManager.getLayerPoolIdx(layerId);
            auto windowSize = static_cast<size_t>(blockManager.getPoolWindowSize(poolIdx));
            size_t const validTokenNum = getValidTokenNumForWindow(
                windowSize, static_cast<size_t>(tokensPerBlock), static_cast<size_t>(maxNumTokens.value()));
            bufferSizeFromMaxNumToken += validTokenNum * static_cast<size_t>(kvCacheByteSizePerTokenPerLayer);
        }
    }

    mTransferBufferSize
        = maxNumTokens.has_value() ? bufferSizeFromMaxNumToken : common::getEnvMemSizeForKVCacheTransferBuffer();
    mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
    mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
    mSendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
    mUseFabricMemory = shouldUseFabricMemory();
    if (mUseFabricMemory)
    {
        mTransferBufferSize = FabricMemory::getAlignedSize(mTransferBufferSize);
    }
    mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount);
    TLLM_LOG_INFO(
        "CacheTransBufferManager: mMaxNumTokens:%zu, mRecvBufferCount:%zu, "
        "mSendBufferCount:%zu,mTransferBufferSize:%zu, mPreAllocBufferSize:%zu,mOnlyUseDynamicBuffer:%d "
        "mUseFabricMemory:%d mDataType:%d",
        static_cast<size_t>(maxNumTokens.has_value() ? maxNumTokens.value() : 0), mRecvBufferCount, mSendBufferCount,
        mTransferBufferSize, mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory,
        static_cast<int>(mDataType));
    allocateBuffer();
}

/// Total bytes the constructor would pre-allocate for the given per-window cache sizes, or 0 when no cache
/// transceiver backend is configured. Mirrors the constructor's sizing rules (window cap, extra block,
/// fabric-granularity alignment, recv + send buffer counts).
size_t CacheTransBufferManager::preAllocBufferSize(
    std::map<SizeType32, SizeType32> const& cacheSizeBytesPerTokenPerWindow, SizeType32 tokensPerBlock,
    std::optional<executor::CacheTransceiverConfig> const& cacheTransceiverConfig)
{
    if (!cacheTransceiverConfig.has_value())
    {
        return 0;
    }
    if (!cacheTransceiverConfig->getBackendType().has_value())
    {
        return 0;
    }
    auto maxNumTokens = cacheTransceiverConfig->getMaxTokensInBuffer();
    size_t transferBufferSize = common::getEnvMemSizeForKVCacheTransferBuffer();
    if (maxNumTokens.has_value())
    {
        transferBufferSize = 0;
        for (auto const& [windowSize, cacheSizeBytesPerToken] : cacheSizeBytesPerTokenPerWindow)
        {
            // Widen before rounding: the previous SizeType32 arithmetic could overflow for very large windows.
            size_t const validTokenNum = getValidTokenNumForWindow(static_cast<size_t>(windowSize),
                static_cast<size_t>(tokensPerBlock), static_cast<size_t>(maxNumTokens.value()));
            transferBufferSize += validTokenNum * static_cast<size_t>(cacheSizeBytesPerToken);
        }
    }
    if (shouldUseFabricMemory())
    {
        transferBufferSize = FabricMemory::getAlignedSize(transferBufferSize);
    }
    size_t recvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
    size_t sendBufferCount = common::getEnvKVCacheSendMaxConcurrenceNum();
    return transferBufferSize * (recvBufferCount + sendBufferCount);
}

std::optional<int> CacheTransBufferManager::assignBufferIndexForSend()
{
    return assignBufferIndex(mConcurrenceSendResource, mSendBufferCount, mOnlyUseDynamicBuffer);
}

void CacheTransBufferManager::freeBufferIndexForSend(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceSendResource, bufferId, mSendBufferCount, mOnlyUseDynamicBuffer);
}

std::optional<int> CacheTransBufferManager::assignBufferIndexForRecv()
{
    return assignBufferIndex(mConcurrenceRecvResource, mRecvBufferCount, mOnlyUseDynamicBuffer);
}

void CacheTransBufferManager::freeBufferIndexForRecv(std::optional<int> bufferId)
{
    freeBufferIndex(mConcurrenceRecvResource, bufferId, mRecvBufferCount, mOnlyUseDynamicBuffer);
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBufferManager::getOrAllocateSendBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse)
{
    return getOrAllocateBuffers(
        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceSendResource);
}

std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBufferManager::getOrAllocateRecvBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse)
{
    return getOrAllocateBuffers(
        bufferId, targetNum, requestedNumberOfElements, bufferManagerToUse, mConcurrenceRecvResource);
}

/// Pre-allocated send buffer for @p bufferId, or nullptr when no id is given (dynamic-buffer-only mode).
runtime::ITensor::SharedPtr CacheTransBufferManager::getSendBuffer(std::optional<int> bufferId)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mSendBufferCount);
        return mConcurrenceSendResource.mBuffers[bufferId.value()];
    }
    return nullptr;
}

/// Pre-allocated recv buffer for @p bufferId, or nullptr when no id is given (dynamic-buffer-only mode).
runtime::ITensor::SharedPtr CacheTransBufferManager::getRecvBuffer(std::optional<int> bufferId)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mRecvBufferCount);
        // TLLM_CHECK(mConcurrenceRecvResource.mBufferIndexFlag[bufferId.value()] == 1);
        return mConcurrenceRecvResource.mBuffers[bufferId.value()];
    }
    return nullptr;
}

/// Returns one tensor per target (the first @p targetNum entries of @p requestedNumberOfElements).
/// When @p bufferId is set, each request that still fits in the pre-allocated buffer gets a slice of it, and any
/// request that does not fit gets a fresh allocation from @p bufferManagerToUse. Without an id, every tensor is
/// freshly allocated.
/// @return (tensors, number of requests served from the pre-allocated buffer, mOnlyUseDynamicBuffer).
std::tuple<std::vector<runtime::ITensor::SharedPtr>, size_t, bool> CacheTransBufferManager::getOrAllocateBuffers(
    std::optional<int> bufferId, int targetNum, std::vector<size_t> const& requestedNumberOfElements,
    runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource)
{
    TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer);
    TLLM_CHECK(requestedNumberOfElements.size() >= static_cast<size_t>(targetNum));
    std::vector<runtime::ITensor::SharedPtr> retSplitCaches;
    retSplitCaches.reserve(static_cast<size_t>(targetNum));
    size_t bufferCoverTargetNum = 0;
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < concurrenceResource.mBuffers.size());
        TLLM_CHECK(concurrenceResource.mBufferIndexFlag[bufferId.value()] == 1);
        auto const& preAllocatedBuffer = concurrenceResource.mBuffers[bufferId.value()];
        size_t preBufferEleSize = 0;
        for (int i = 0; i < targetNum; i++)
        {
            // Strict checking.
            if (preBufferEleSize + requestedNumberOfElements[i] <= mNumberOfElements)
            {
                auto slice = runtime::ITensor::slice(preAllocatedBuffer, preBufferEleSize, requestedNumberOfElements[i]);
                preBufferEleSize += requestedNumberOfElements[i];
                bufferCoverTargetNum++;
                retSplitCaches.push_back(std::move(slice));
            }
            else
            {
                retSplitCaches.push_back(bufferManagerToUse.gpu(
                    runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
            }
        }
        TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%zu", bufferCoverTargetNum);
        if (bufferCoverTargetNum < static_cast<size_t>(targetNum))
        {
            TLLM_LOG_WARNING(
                "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%zu < targetNum:%d, may use dynamic "
                "buffer which will fail with NIXL backend. It is recommended to set "
                "cacheTransceiverConfig.MaxTokensInBuffer (cache_transceiver_config.max_tokens_in_buffer in config "
                "YAML file) to a value greater than the maximum ISL of the processed requests. Otherwise, performance "
                "may be degraded or transfer may fail. requestedNumberOfElements.size():%zu, "
                "mNumberOfElements:%zu, requestedNumberOfElements[0]:%zu",
                bufferCoverTargetNum, targetNum, requestedNumberOfElements.size(), mNumberOfElements,
                static_cast<size_t>(requestedNumberOfElements[0]));
        }
    }
    else
    {
        for (int i = 0; i < targetNum; i++)
        {
            retSplitCaches.push_back(bufferManagerToUse.gpu(
                runtime::ITensor::makeShape({static_cast<int64_t>(requestedNumberOfElements[i])}), mDataType));
        }
        bufferCoverTargetNum = targetNum;
    }
    // Move rather than copy: avoids one atomic refcount increment/decrement per shared_ptr.
    return std::make_tuple(std::move(retSplitCaches), bufferCoverTargetNum, mOnlyUseDynamicBuffer);
}

/// Allocates mSendBufferCount send and mRecvBufferCount recv staging buffers of mTransferBufferSize bytes each.
/// Allocation is skipped entirely in dynamic-buffer-only mode. The backing store is chosen with precedence
/// fabric memory > async allocation (followed by one stream synchronize) > synchronous allocation.
void CacheTransBufferManager::allocateBuffer()
{
    if (mOnlyUseDynamicBuffer)
    {
        return;
    }
    mNumberOfElements = mTransferBufferSize / common::getDTypeSize(mDataType);
    mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0);
    mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0);

    auto const bufferShape = runtime::ITensor::makeShape({static_cast<int64_t>(mNumberOfElements)});
    // The async-buffer env var is only consulted when fabric memory is off (fabric memory takes precedence).
    bool const useAsyncBuffer = !mUseFabricMemory && common::getEnvKVCacheTransferUseAsyncBuffer();
    if (mUseFabricMemory)
    {
        mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount);
    }

    // Allocates one staging buffer from the selected backing store.
    auto allocateOne = [&]() -> runtime::ITensor::SharedPtr
    {
        if (mUseFabricMemory)
        {
            mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
            return runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType, bufferShape, mNumberOfElements);
        }
        if (useAsyncBuffer)
        {
            return mBufferManager.gpu(bufferShape, mDataType);
        }
        return mBufferManager.gpuSync(bufferShape, mDataType);
    };

    // Fills indices [0, count) of one resource; send buffers are allocated before recv buffers.
    auto fillBuffers = [&](ConcurrenceResource& resource, size_t count)
    {
        for (size_t i = 0; i < count; i++)
        {
            resource.mBuffers[i] = allocateOne();
        }
    };
    fillBuffers(mConcurrenceSendResource, mSendBufferCount);
    fillBuffers(mConcurrenceRecvResource, mRecvBufferCount);

    if (useAsyncBuffer)
    {
        mBufferManager.getStream().synchronize();
    }
}

/// Blocks until fewer than @p bufferCount buffers are in use, then claims the lowest free index.
/// Returns std::nullopt in dynamic-buffer-only mode.
std::optional<int> CacheTransBufferManager::assignBufferIndex(
    ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer)
{
    if (onlyUseDynamicBuffer)
    {
        return std::nullopt;
    }
    std::unique_lock lk(resource.mBuffersMutex);
    resource.mBuffersCV.wait(
        lk, [&resource, bufferCount]() { return static_cast<size_t>(resource.mConcurrence) < bufferCount; });
    int bufferId = -1;
    for (size_t i = 0; i < bufferCount; i++)
    {
        if (resource.mBufferIndexFlag[i] == 0)
        {
            bufferId = static_cast<int>(i);
            resource.mBufferIndexFlag[bufferId] = 1;
            resource.mConcurrence++;
            break;
        }
    }
    TLLM_CHECK_WITH_INFO(bufferId >= 0 && static_cast<size_t>(bufferId) < bufferCount,
        " assignBufferIndex: Buffer index already assigned");

    return bufferId;
}

/// Releases @p bufferId (if set) and wakes one thread waiting in assignBufferIndex.
/// No-op in dynamic-buffer-only mode.
void CacheTransBufferManager::freeBufferIndex(
    ConcurrenceResource& resource, std::optional<int> bufferId, size_t bufferCount, bool onlyUseDynamicBuffer)
{
    if (onlyUseDynamicBuffer)
    {
        return;
    }
    if (bufferId.has_value())
    {
        TLLM_CHECK(static_cast<size_t>(bufferId.value()) < bufferCount);
        {
            std::scoped_lock lk(resource.mBuffersMutex);
            resource.mBufferIndexFlag[bufferId.value()] = 0;
            // The counter must change under the same mutex the waiter in assignBufferIndex holds while evaluating its
            // predicate. Updating it outside the lock lets notify_one() fire between that check and the waiter
            // blocking, which loses the wakeup.
            resource.mConcurrence--;
        }
        resource.mBuffersCV.notify_one();
    }
}

size_t CacheTransBufferManager::getRecvBufferCount()
{
    return mRecvBufferCount;
}

size_t CacheTransBufferManager::getSendBufferCount()
{
    return mSendBufferCount;
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager