/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/loraCache.h"
#include "tensorrt_llm/runtime/loraUtils.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/utils/numpyUtils.h"
#include "tensorrt_llm/runtime/utils/runtimeUtils.h"
#include "tensorrt_llm/runtime/workerPool.h"
#include "tensorrt_llm/runtime/worldConfig.h"

// NOTE(review): the names of the system headers were lost from the reviewed copy of this file; the list
// below covers every standard facility used in this translation unit -- confirm against the repository.
#include <NvInferRuntimeBase.h>
#include <algorithm>
#include <cctype>
#include <cstdint>
#include <filesystem>
#include <future>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tensorrt_llm::batch_manager
{

PeftTaskNotCachedException::PeftTaskNotCachedException(std::string const& msg)
    : runtime::LoraExpectedException(msg)
{
}

PeftTaskNotCachedException::~PeftTaskNotCachedException() noexcept = default;

// Computes how many slots the host and the device LoRA caches may hold.
//
// For each cache, if the corresponding `num*ModuleLayer` setting is positive the slot count is
// `num*ModuleLayer * ceil(max1dModSize / pageWidth)`. Otherwise the host count is derived from
// `hostCacheSize` (or its default) and the device count from `deviceCachePercent` (or its default)
// applied to the free device memory plus the free memory of the buffer manager's pool.
//
// Returns {totalHostSlots, totalDeviceSlots}.
std::pair<uint64_t, uint64_t> PeftCacheManager::getMaxNumSlots(PeftCacheManagerConfig const& config,
    nvinfer1::DataType dataType, uint64_t pageWidth, uint64_t max1dModSize,
    runtime::BufferManager const& bufferManager)
{
    TLLM_LOG_DEBUG("max1dModeSize=%llu", max1dModSize);
    TLLM_LOG_DEBUG("pageWidth=%llu", pageWidth);
    // size in bytes of one slot (one row of `pageWidth` elements)
    uint64_t const pageWidthSize = pageWidth * static_cast<uint64_t>(runtime::BufferDataType(dataType).getSize());
    // every branch below divides by pageWidth or pageWidthSize
    TLLM_CHECK_WITH_INFO(
        pageWidth > 0 && pageWidthSize > 0, "LoRA cache page width must be non-zero to compute the number of slots");

    uint64_t totalHostSlots = 0;
    uint64_t totalDeviceSlots = 0;

    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    if (config.numHostModuleLayer > 0)
    {
        totalHostSlots
            = static_cast<uint64_t>(config.numHostModuleLayer) * common::ceilDiv(max1dModSize, pageWidth);
        TLLM_LOG_DEBUG("totalHostSlots=%llu", totalHostSlots);
    }
    else
    {
        auto const memSize = config.hostCacheSize.value_or(PeftCacheManagerConfig::kDefaultHostCacheSize);
        totalHostSlots = memSize / pageWidthSize;
        TLLM_LOG_DEBUG("memSize: %llu, totalHostSlots: %llu", memSize, totalHostSlots);
    }

    if (config.numDeviceModuleLayer > 0)
    {
        totalDeviceSlots
            = static_cast<uint64_t>(config.numDeviceModuleLayer) * common::ceilDiv(max1dModSize, pageWidth);
    }
    else
    {
        auto const memPercent
            = config.deviceCachePercent.value_or(PeftCacheManagerConfig::kDefaultDeviceCachePercent);
        auto const [freeMem, totalMem] = common::getDeviceMemoryInfo(false);
        totalDeviceSlots = static_cast<uint64_t>(static_cast<double>(memPercent)
            * static_cast<double>(freeMem + bufferManager.memoryPoolFree()) / static_cast<double>(pageWidthSize));
    }

    auto const hostMem = totalHostSlots * pageWidthSize;
    auto const deviceMem = totalDeviceSlots * pageWidthSize;

    TLLM_LOG_INFO("Using " + std::to_string(hostMem) + " bytes for LoRA host cache");
    TLLM_LOG_INFO("Using " + std::to_string(deviceMem) + " bytes for LoRA device cache");

    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
    return std::make_pair(totalHostSlots, totalDeviceSlots);
}

// Builds the page manager configurations of the host and the device LoRA caches.
//
// The page width is the smallest `localInDim + localOutDim` over all LoRA modules of the model; the
// number of slots per page is sized from `optimalAdapterSize` across all modules and local layers, but
// never below what one module of `maxAdapterSize` needs. Fails (TLLM_CHECK) if either cache cannot hold
// one max sized LoRA, or if the model/config would lead to a division by zero below.
//
// Returns {hostPageConfig, devicePageConfig}.
std::pair<runtime::LoraCachePageManagerConfig, runtime::LoraCachePageManagerConfig>
PeftCacheManager::getPageManagerConfig(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
    runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager)
{
    auto const tpSize = worldConfig.getTensorParallelism();
    auto const ppSize = worldConfig.getPipelineParallelism();
    auto const numLocalLayers = modelConfig.getNbAttentionLayers(ppSize);
    auto const& loraModules = modelConfig.getLoraModules();

    uint64_t min1dModSize = std::numeric_limits<uint64_t>::max(); // used to setup the pageWidth
    uint64_t total1dModSize = 0;
    uint64_t total1lSlots = 0; // the slots we need for each layer, summing the slots of all modules
    uint64_t max1dModSize = 0; // used to compute the totalHostSlots and totalDeviceSlots
    for (auto const& module : loraModules)
    {
        uint64_t const oneDSize = static_cast<uint64_t>(module.localInDim(tpSize) + module.localOutDim(tpSize));
        TLLM_LOG_DEBUG("oneDSize=%llu", oneDSize);
        min1dModSize = std::min(min1dModSize, oneDSize);
        max1dModSize = std::max(max1dModSize, oneDSize);
        total1dModSize += oneDSize;
    }
    TLLM_LOG_DEBUG("total1dModSize=%llu", total1dModSize);

    // Without modules max1dModSize stays 0; a zero-sized module makes the page width 0. Both cases end in
    // an integer division by zero further down, so reject them here with a readable message.
    TLLM_CHECK_WITH_INFO(max1dModSize > 0 && min1dModSize > 0,
        "Cannot size the LoRA cache: the model has no LoRA modules or a LoRA module with zero size");

    auto const pageWidth = min1dModSize;
    for (auto const& module : loraModules)
    {
        uint64_t const oneDSize = static_cast<uint64_t>(module.localInDim(tpSize) + module.localOutDim(tpSize));
        total1lSlots += config.optimalAdapterSize * common::ceilDiv(oneDSize, pageWidth);
    }

    uint64_t const max1dLoraSize = total1dModSize * static_cast<uint64_t>(numLocalLayers);
    uint64_t const totalSlots = total1lSlots * static_cast<uint64_t>(numLocalLayers);
    uint64_t const maxLoraSize = config.maxAdapterSize * max1dLoraSize;

    uint64_t const minNumSlots = common::ceilDiv(config.maxAdapterSize * max1dModSize, pageWidth);
    uint64_t const numSlotsPerPage = std::max(totalSlots, minNumSlots);
    uint64_t const minTotalSlots = common::ceilDiv(config.maxAdapterSize * max1dLoraSize, pageWidth);

    // numSlotsPerPage and maxLoraSize are used as divisors below
    TLLM_CHECK_WITH_INFO(numSlotsPerPage > 0,
        "Cannot size the LoRA cache: number of slots per page is zero (check adapter sizes and number of layers)");
    TLLM_CHECK_WITH_INFO(maxLoraSize > 0,
        "Cannot size the LoRA cache: max LoRA size is zero (check max adapter size and number of layers)");

    TLLM_LOG_DEBUG("max1dModSize=%llu", max1dModSize);
    auto [totalHostSlots, totalDeviceSlots]
        = getMaxNumSlots(config, modelConfig.getDataType(), pageWidth, max1dModSize, bufferManager);

    TLLM_CHECK_WITH_INFO(totalHostSlots >= minTotalSlots,
        "Not enough space allocated to host LoRA cache to hold 1 max sized LoRA %lu < %lu", totalHostSlots,
        minTotalSlots);
    TLLM_CHECK_WITH_INFO(totalDeviceSlots >= minTotalSlots,
        "Not enough space allocated to device LoRA cache to hold 1 max sized LoRA %lu < %lu", totalDeviceSlots,
        minTotalSlots);

    uint64_t const totalHostPages = common::ceilDiv(totalHostSlots, numSlotsPerPage);
    uint64_t const totalDevicePages = common::ceilDiv(totalDeviceSlots, numSlotsPerPage);

    uint64_t const totalMaxLorasHost = totalHostPages * numSlotsPerPage * pageWidth / maxLoraSize;
    uint64_t const totalMaxLorasDevice = totalDevicePages * numSlotsPerPage * pageWidth / maxLoraSize;

    TLLM_LOG_INFO("Max LoRA size is %llu", maxLoraSize);
    TLLM_LOG_INFO("LoRA host Cache can hold %llu max sized LoRAs", totalMaxLorasHost);
    TLLM_LOG_INFO("LoRA device Cache can hold %llu max sized LoRAs", totalMaxLorasDevice);

    // the page manager config takes these values as SizeType32
    TLLM_CHECK(std::numeric_limits<SizeType32>::max() >= totalHostPages);
    TLLM_CHECK(std::numeric_limits<SizeType32>::max() >= totalDevicePages);
    TLLM_CHECK(std::numeric_limits<SizeType32>::max() >= numSlotsPerPage);
    TLLM_CHECK(std::numeric_limits<SizeType32>::max() >= pageWidth);

    runtime::LoraCachePageManagerConfig hostPageConfig(runtime::MemoryType::kCPU, modelConfig.getDataType(),
        totalHostPages, config.maxPagesPerBlockHost, numSlotsPerPage, pageWidth, 0);
    runtime::LoraCachePageManagerConfig devicePageConfig(runtime::MemoryType::kGPU, modelConfig.getDataType(),
        totalDevicePages, config.maxPagesPerBlockDevice, numSlotsPerPage, pageWidth, config.numCopyStreams);
    return std::make_pair(hostPageConfig, devicePageConfig);
}

void PeftCacheManager::prefetchLoraWeights(std::string const& modelDir, runtime::BufferManager const& bufferManager)
{
    // This function loads LoRA weights from modelDir. In the folder, users can put many
    // folders for different lora tasks.
    // For example, assume we want to store lora weights in modelDir and there are three
    // lora tasks `0`, `1` and `3`, then the architecture of the folder would be like
    // modelDir
    // ├── 0
    // │   ├── model.lora_config.npy
    // │   └── model.lora_weights.npy
    // ├── 1
    // │   ├── model.lora_config.npy
    // │   └── model.lora_weights.npy
    // └── 3
    //     ├── model.lora_config.npy
    //     └── model.lora_weights.npy
    //
    // If the name of the lora task is not made of decimal digits only, or does not fit in a 64-bit
    // task id, a warning is printed and the folder is skipped.
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    namespace fs = std::filesystem;

    if (!fs::exists(modelDir) || !fs::is_directory(modelDir))
    {
        TLLM_LOG_DEBUG("Cannot find the %s, skipping prefetching the lora weights.", modelDir.c_str());
        return;
    }

    // std::isdigit has undefined behavior for negative char values (e.g. bytes of a UTF-8 folder name),
    // so the character is converted to unsigned char first.
    auto const isDecimalDigit = [](unsigned char c) { return std::isdigit(c) != 0; };

    // collect the lora tasks under modelDir
    std::vector<std::string> tasks;
    for (auto const& entry : fs::directory_iterator(modelDir))
    {
        if (!fs::is_directory(entry.path()))
        {
            continue;
        }
        std::string taskName = entry.path().filename().string();
        if (std::all_of(taskName.begin(), taskName.end(), isDecimalDigit))
        {
            tasks.push_back(std::move(taskName));
        }
        else
        {
            TLLM_LOG_WARNING(
                "lora task name %s is invalid, skipping to load lora weight from this folder.", taskName.c_str());
        }
    }
    TLLM_LOG_DEBUG("find %zu lora tasks to prefetch.", tasks.size());

    // load lora tasks one by one
    using TensorPtr = runtime::ITensor::SharedPtr;
    auto const multiLoraPath = fs::path{modelDir};
    for (auto const& taskName : tasks)
    {
        // Task ids are 64-bit unsigned, so parse with stoull rather than stoi.
        uint64_t taskId = 0;
        try
        {
            taskId = std::stoull(taskName);
        }
        catch (std::out_of_range const&)
        {
            TLLM_LOG_WARNING(
                "lora task name %s does not fit in a 64-bit task id, skipping to load lora weight from this folder.",
                taskName.c_str());
            continue;
        }

        auto const weightsFn = (multiLoraPath / taskName / "model.lora_weights.npy").string();
        auto const configFn = (multiLoraPath / taskName / "model.lora_config.npy").string();
        TensorPtr weights = runtime::utils::loadNpy(bufferManager, weightsFn, runtime::MemoryType::kCPU);
        TensorPtr config = runtime::utils::loadNpy(bufferManager, configFn, runtime::MemoryType::kCPU);
        TLLM_LOG_DEBUG("prefetch lora task %s", taskName.c_str());
        mHostLoraCache->put(taskId, weights, config, true);
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

// Creates the host and device LoRA caches and the two worker pools, then optionally prefetches the
// LoRA tasks found under `config.loraPrefetchDir`. The adapter sizes of `config` are clamped to the
// model's max LoRA rank before the caches are sized.
PeftCacheManager::PeftCacheManager(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
    runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager)
    : mModelConfig(modelConfig)
    , mWorldConfig(worldConfig)
    , mDevice{runtime::utils::initDevice(worldConfig)}
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto cfg = config;
    cfg.optimalAdapterSize = std::min(cfg.optimalAdapterSize, modelConfig.getMaxLoraRank());
    cfg.maxAdapterSize = std::min(cfg.maxAdapterSize, modelConfig.getMaxLoraRank());

    auto [hostCacheConfig, deviceCacheConfig] = getPageManagerConfig(cfg, modelConfig, worldConfig, bufferManager);

    mHostLoraCache = std::make_unique<runtime::LoraCache>(hostCacheConfig, modelConfig, worldConfig, bufferManager);
    mDeviceLoraCache
        = std::make_unique<runtime::LoraCache>(deviceCacheConfig, modelConfig, worldConfig, bufferManager);

    mPutWorkerPool = std::make_shared<runtime::WorkerPool>(cfg.numPutWorkers, mDevice);
    mEnsureWorkerPool = std::make_unique<runtime::WorkerPool>(cfg.numEnsureWorkers, mDevice);

    if (cfg.loraPrefetchDir.has_value())
    {
        prefetchLoraWeights(cfg.loraPrefetchDir.value(), bufferManager);
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

// Registers the LoRA task of `llmRequest` with the host cache and schedules loading of its weights.
//
// Does nothing for requests without task id, weights and config. Throws PeftTaskNotCachedException if
// the request carries no weights/config and the task is not in the host cache, rethrows
// runtime::LoraCacheFullException when the host cache is full, and wraps other std::runtime_error from
// the cache via TLLM_THROW. `tryGpuCache` is currently not used by this implementation.
void PeftCacheManager::addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bool tryGpuCache)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto optTaskId = llmRequest->getLoraTaskId();
    auto optLoraWeights = llmRequest->getLoraWeights();
    auto optLoraConfig = llmRequest->getLoraConfig();

    if (!(optTaskId || optLoraWeights || optLoraConfig))
    {
        // no lora to add to cache so we're done
        return;
    }
    runtime::lora::loraValidateRequestTensors(optTaskId, optLoraWeights, optLoraConfig, mModelConfig, mWorldConfig);

    auto const taskId = optTaskId.value();
    bool const hasWeightsAndConfig = optLoraWeights && optLoraConfig;

    if (!hasWeightsAndConfig && !isTaskCached(taskId))
    {
        // Throw special exception that's logged as warning in executor
        std::string errMsg
            = "LoRA task " + std::to_string(taskId) + " not found in cache. Please send LoRA weights with request";
        throw PeftTaskNotCachedException(errMsg);
    }

    auto const reqId = llmRequest->mRequestId;
    TLLM_LOG_DEBUG("addRequestPeft taskId=" + std::to_string(taskId) + " reqId=" + std::to_string(reqId));

    updateTaskState(taskId, reqId);

    {
        // if we are already processing this task we are done
        std::lock_guard<std::mutex> lk(mPutFuturesMutex);
        if (mPutFutures.count(taskId))
        {
            TLLM_LOG_DEBUG(
                "addRequestPeft haveFuture skip taskId=" + std::to_string(taskId) + " reqId=" + std::to_string(reqId));
            return;
        }
    }

    bool loadNeeded = false;
    try
    {
        if (hasWeightsAndConfig)
        {
            TLLM_LOG_DEBUG("addRequestPeft put taskId=" + std::to_string(taskId) + " reqId=" + std::to_string(reqId));
            mHostLoraCache->put(taskId, optLoraWeights.value(), optLoraConfig.value(), false);
            loadNeeded = true;
        }
        else
        {
            TLLM_LOG_DEBUG("addRequestPeft bump taskId=" + std::to_string(taskId) + " reqId=" + std::to_string(reqId));
            mHostLoraCache->bump(taskId);
            loadNeeded = false;
        }
    }
    catch (runtime::LoraCacheFullException const&)
    {
        std::string errMsg("PEFT Cache is full. Could not store taskId=" + std::to_string(taskId));
        TLLM_LOG_ERROR(errMsg);
        // undo the registration done by updateTaskState above
        updateTaskState(taskId, reqId, true, false);
        // re-throw with better error message
        throw runtime::LoraCacheFullException(errMsg);
    }
    catch (std::runtime_error const& e)
    {
        updateTaskState(taskId, reqId, true, false);
        TLLM_THROW("Error storing task=%lu in PEFT cache -- %s", taskId, e.what());
    }

    // The request is captured by value so it stays alive until the worker has run.
    auto fn = [taskId, req = llmRequest, loadNeeded, this]()
    {
        auto optWeights = req->getLoraWeights();
        auto optConfig = req->getLoraConfig();
        if (loadNeeded && optWeights.has_value() && optConfig.has_value())
        {
            TLLM_LOG_DEBUG(
                "addRequestPeft load taskId=" + std::to_string(taskId) + " reqId=" + std::to_string(req->mRequestId));
            mHostLoraCache->loadWeights(taskId, optWeights.value(), optConfig.value());
            // free memory associated with lora weights in llmRequest
            req->clearLoraWeights();
            req->clearLoraConfig();
        }

        // TODO (grclark) pre-loading nvbug 4547061
        if (false)
        {
            try
            {
                mHostLoraCache->copyTask(taskId, *mDeviceLoraCache, false);
            }
            catch (std::runtime_error& e)
            {
                TLLM_LOG_DEBUG("failed to copy task " + std::to_string(taskId) + " to gpu cache -- " + e.what());
            }
        }
#ifndef NDEBUG
        if (!mHostLoraCache->isLoaded(taskId))
        {
            throw std::runtime_error("Expected task to be loaded at the end of put " + std::to_string(taskId));
        }
        if (mHostLoraCache->isDone(taskId))
        {
            throw std::runtime_error("Expected task to be in progress at the end of put " + std::to_string(taskId));
        }
#endif
    };

    auto putFuture = mPutWorkerPool->enqueue(fn);
    {
        std::lock_guard<std::mutex> lk(mPutFuturesMutex);
        mPutFutures.try_emplace(taskId, std::move(putFuture));
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

// Groups the request ids of both request vectors by LoRA task id and collects one "put" future per task.
//
// Futures found in mPutFutures are moved out of it (and erased there); tasks without a pending future get
// an already-satisfied dummy future. Returns {taskId -> future, taskId -> request ids}.
std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>>
PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVector const& generationRequests)
{
    std::map<uint64_t, std::vector<uint64_t>> taskIdToReqIds;
    std::map<uint64_t, std::future<void>> taskIdToFuture;
    std::lock_guard<std::mutex> futuresLock(mPutFuturesMutex);
    // iterate over pointers: a braced list of the vectors themselves would copy both of them
    for (auto const* requests : {&contextRequests, &generationRequests})
    {
        for (auto const& llmReq : *requests)
        {
            auto const optTaskId = llmReq->getLoraTaskId();
            if (!optTaskId.has_value())
            {
                continue;
            }
            auto const taskId = optTaskId.value();
            taskIdToReqIds[taskId].push_back(llmReq->mRequestId);

            if (auto putIt = mPutFutures.find(taskId); putIt != mPutFutures.end())
            {
                TLLM_LOG_DEBUG("Ensure batch, has future for " + std::to_string(taskId));
                taskIdToFuture.try_emplace(taskId, std::move(putIt->second));
                mPutFutures.erase(putIt);
            }
            else if (!taskIdToFuture.count(taskId))
            {
                /*
                 * if we don't find a future in mPutFutures, we may have already put one in
                 * taskIdToFutures (ie if 2 requests have the same taskId)
                 * If no future is found we create a dummy future for the task
                 */
                TLLM_LOG_DEBUG("Ensure batch, do not have future for " + std::to_string(taskId));
                std::promise<void> p;
                p.set_value();
                taskIdToFuture.try_emplace(taskId, p.get_future());
            }
        }
    }
    return {std::move(taskIdToFuture), std::move(taskIdToReqIds)};
}

// Makes sure the LoRA tasks of all given requests are present in the device cache and returns, per
// request id, the values obtained from the device cache for the request's task.
//
// For every task a job is enqueued on the ensure worker pool which waits for the pending "put" of the
// task, copies the task from the host to the device cache and reads it back. Exceptions of type
// std::runtime_error raised by these steps are rethrown as std::runtime_error with a prefixed message.
PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
    RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    if (resetGpuCache)
    {
        resetDeviceCache();
    }
    auto [taskIdToFuture, taskIdToReqIds] = getTaskMaps(contextRequests, generationRequests);

    using TaskValues = std::vector<runtime::LoraCache::TaskLayerModuleConfig>;
    std::map<uint64_t, std::future<TaskValues>> ensureFutures;
    for (auto& [taskId, putFuture] : taskIdToFuture)
    {
        // The worker owns its own handle to the put result. Capturing the local map by reference would
        // leave queued workers with a dangling reference if this function unwinds early (e.g. when get()
        // of another task throws below). shared_future::get() is const, so the lambda need not be mutable.
        std::shared_future<void> const sharedPutFuture = putFuture.share();
        auto fn = [sharedPutFuture, taskId = taskId, this]() -> TaskValues
        {
            try
            {
                sharedPutFuture.get();
            }
            catch (std::runtime_error& e)
            {
                throw std::runtime_error("Caught Exception while storing peft weights -- " + std::string(e.what()));
            }

            try
            {
                mHostLoraCache->copyTask(taskId, *mDeviceLoraCache);
            }
            catch (std::runtime_error& e)
            {
                throw std::runtime_error(
                    "Caught exception during copyTask ensure batch -- " + std::string(e.what()));
            }
            return mDeviceLoraCache->get(taskId);
        };
        ensureFutures.try_emplace(taskId, mEnsureWorkerPool->enqueue(fn));
    }

    PeftTable peftTable{};
    for (auto const& [taskId, reqIds] : taskIdToReqIds)
    {
        auto const values = ensureFutures.at(taskId).get();
        for (auto const& reqId : reqIds)
        {
            peftTable.try_emplace(reqId, values);
        }
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
    return peftTable;
}

bool PeftCacheManager::isTaskCached(uint64_t taskId) const
{
    return mHostLoraCache->has(taskId);
}

bool PeftCacheManager::isTaskDone(uint64_t taskId) const
{
    return mHostLoraCache->isDone(taskId);
}

bool PeftCacheManager::isTaskDoneDevice(uint64_t taskId) const
{
    return mDeviceLoraCache->isDone(taskId);
}

// Updates the bookkeeping of which requests use (or paused on) which task.
//
// terminate == false: `reqId` becomes an active user of `taskId` and is removed from the paused set.
// terminate == true: `reqId` is removed from the active set and is either added to (pause == true) or
// removed from (pause == false) the paused set. Once a task has no active request it is marked done in
// the device cache; once it has neither active nor paused requests its pending put future is dropped
// and it is marked done in the host cache as well.
void PeftCacheManager::updateTaskState(uint64_t taskId, uint64_t reqId, bool terminate, bool pause)
{
    if (!terminate)
    {
        mTaskIdToReqIds[taskId].insert(reqId);
        if (auto pausedIt = mTaskIdToPausedReqIds.find(taskId); pausedIt != mTaskIdToPausedReqIds.end())
        {
            pausedIt->second.erase(reqId);
            if (pausedIt->second.empty())
            {
                mTaskIdToPausedReqIds.erase(pausedIt);
            }
        }
        return;
    }

    if (auto activeIt = mTaskIdToReqIds.find(taskId); activeIt != mTaskIdToReqIds.end())
    {
        activeIt->second.erase(reqId);
        if (activeIt->second.empty())
        {
            mTaskIdToReqIds.erase(activeIt);
        }
    }

    if (pause)
    {
        mTaskIdToPausedReqIds[taskId].insert(reqId);
    }
    else if (auto pausedIt = mTaskIdToPausedReqIds.find(taskId); pausedIt != mTaskIdToPausedReqIds.end())
    {
        pausedIt->second.erase(reqId);
        if (pausedIt->second.empty())
        {
            mTaskIdToPausedReqIds.erase(pausedIt);
        }
    }

    if (mTaskIdToReqIds.count(taskId) == 0)
    {
        // paused taskIds get removed from gpu cache but not host cache
        mDeviceLoraCache->markTaskDone(taskId);
        if (mTaskIdToPausedReqIds.count(taskId) == 0)
        {
            {
                std::lock_guard<std::mutex> lk(mPutFuturesMutex);
                mPutFutures.erase(taskId);
                TLLM_LOG_DEBUG(
                    "erase task " + std::to_string(taskId) + " future size=" + std::to_string(mPutFutures.size()));
            }
            mHostLoraCache->markTaskDone(taskId);
        }
    }
}

void PeftCacheManager::resetDeviceCache()
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    mDeviceLoraCache->markAllDone();
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

// Removes the request from the active set of its LoRA task; with `pause` the request is remembered as
// paused instead of being dropped. Requests without a LoRA task id are ignored.
void PeftCacheManager::markRequestDone(LlmRequest const& llmReq, bool pause)
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    auto const optTaskId = llmReq.getLoraTaskId();
    if (!optTaskId.has_value())
    {
        return;
    }
    updateTaskState(optTaskId.value(), llmReq.mRequestId, true, pause);
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}

SizeType32 PeftCacheManager::getMaxDevicePages() const
{
    return mDeviceLoraCache->getNumPages();
}

SizeType32 PeftCacheManager::getMaxHostPages() const
{
    return mHostLoraCache->getNumPages();
}

// Returns the number of cache pages needed by the LoRA task of `llmRequest`, or 0 if the request has no
// task id. The count is looked up by task id first; if that throws std::runtime_error and the request
// carries a LoRA config, the count is computed from the config instead, otherwise the error propagates.
SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> llmRequest) const
{
    TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
    if (llmRequest->getLoraTaskId().has_value())
    {
        try
        {
            return mHostLoraCache->determineNumPages(llmRequest->getLoraTaskId().value());
        }
        catch (std::runtime_error const&)
        {
            if (llmRequest->getLoraConfig().has_value())
            {
                return mHostLoraCache->determineNumPages(llmRequest->getLoraConfig().value());
            }
            throw;
        }
    }
    TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
    return 0;
}

std::unordered_map<uint64_t, std::unordered_set<uint64_t>> const& PeftCacheManager::getActiveTasks() const
{
    return mTaskIdToReqIds;
}

std::unordered_map<uint64_t, std::unordered_set<uint64_t>> const& PeftCacheManager::getPausedTasks() const
{
    return mTaskIdToPausedReqIds;
}

// NoOpPeftCacheManager: every operation is a no-op; page limits are reported as the SizeType32 maximum
// and every request needs zero pages.

void NoOpPeftCacheManager::addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bool tryGpuCache) {}

PeftCacheManager::PeftTable NoOpPeftCacheManager::ensureBatch(
    RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
{
    return PeftTable{};
}

void NoOpPeftCacheManager::resetDeviceCache() {}

void NoOpPeftCacheManager::markRequestDone(LlmRequest const& llmReq, bool pause) {}

SizeType32 NoOpPeftCacheManager::getMaxDevicePages() const
{
    return std::numeric_limits<SizeType32>::max();
}

SizeType32 NoOpPeftCacheManager::getMaxHostPages() const
{
    return std::numeric_limits<SizeType32>::max();
}

SizeType32 NoOpPeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> llmReqeust) const
{
    return 0;
}

} // namespace tensorrt_llm::batch_manager