From 1dc49b266e5423cfc463a42d6b6b6711acbb850d Mon Sep 17 00:00:00 2001 From: Jiayu Chang <40067028+JyChang012@users.noreply.github.com> Date: Thu, 22 Jan 2026 21:01:18 +0800 Subject: [PATCH] [https://nvbugs/5322131][feat] Multi-LoRA serving with CUDA Graph (#8279) Signed-off-by: Jiayu Chang --- .../batch_manager/peftCacheManager.h | 14 +- .../batch_manager/peftCacheManager.cpp | 42 +- .../kernels/cuda_graph_grouped_gemm.cu | 372 ++++++++++++++ .../kernels/cuda_graph_grouped_gemm.h | 63 +++ cpp/tensorrt_llm/kernels/lora/lora.cpp | 2 +- .../loraGroupGEMMParamFillRowReorderFusion.cu | 408 +++++++++++++++ .../loraGroupGEMMParamFillRowReorderFusion.h | 77 +++ .../nanobind/batch_manager/kvCacheManager.cpp | 6 + .../pybind/batch_manager/kvCacheManager.cpp | 5 + cpp/tensorrt_llm/runtime/loraManager.h | 2 +- cpp/tensorrt_llm/thop/loraOp.cpp | 224 +++++++++ tensorrt_llm/_torch/modules/attention.py | 3 - .../_torch/peft/lora/adapter_slot_manager.py | 130 +++++ .../peft/lora/cuda_graph_lora_manager.py | 175 +++++++ .../peft/lora/cuda_graph_lora_params.py | 341 +++++++++++++ tensorrt_llm/_torch/peft/lora/layer.py | 468 ++++++++++++++++-- tensorrt_llm/_torch/pyexecutor/_util.py | 2 + tensorrt_llm/_torch/pyexecutor/llm_request.py | 3 - .../_torch/pyexecutor/model_engine.py | 97 +++- .../_torch/pyexecutor/resource_manager.py | 24 +- .../_torch/thop/parallel/test_custom_ops.py | 2 + tests/unittest/llmapi/lora_test_utils.py | 242 ++++++++- .../llmapi/test_llm_multi_gpu_pytorch.py | 25 +- tests/unittest/llmapi/test_llm_pytorch.py | 181 ++++--- tests/unittest/llmapi/test_utils.py | 30 ++ 25 files changed, 2766 insertions(+), 172 deletions(-) create mode 100644 cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.cu create mode 100644 cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.h create mode 100644 cpp/tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.cu create mode 100644 cpp/tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.h create mode 100644 tensorrt_llm/_torch/peft/lora/adapter_slot_manager.py create mode 100644 tensorrt_llm/_torch/peft/lora/cuda_graph_lora_manager.py create mode 100644 tensorrt_llm/_torch/peft/lora/cuda_graph_lora_params.py diff --git a/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h index 0f844ac0aa..cf65753783 100644 --- a/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/peftCacheManager.h @@ -57,7 +57,10 @@ class BasePeftCacheManager public: using LlmRequestPtr = std::shared_ptr; using RequestVector = std::vector; - using PeftTable = std::map>; + using PeftTable = std::unordered_map>; + using TaskPeftTable = std::unordered_map>; + using TaskIdToReqIds = std::unordered_map>; + using EnsureBatchTaskResult = std::tuple; virtual ~BasePeftCacheManager() = default; @@ -99,6 +102,8 @@ public: class PeftCacheManager : public BasePeftCacheManager { public: + using EnsureBatchTaskResult = BasePeftCacheManager::EnsureBatchTaskResult; + PeftCacheManager(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager); @@ -109,12 +114,17 @@ public: PeftTable ensureBatch(RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache = false) override; + EnsureBatchTaskResult ensureBatchMapTaskId( + RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache = false); + 
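+    // NOTE: ensureBatchMapTaskId returns the PEFT values keyed by LoRA task id together with the
+    // task-id -> request-ids mapping; ensureBatch flattens this back into the per-request PeftTable.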
[[nodiscard]] bool isTaskCached(uint64_t taskId) const; [[nodiscard]] bool isTaskDone(uint64_t taskId) const; [[nodiscard]] bool isTaskDoneDevice(uint64_t taskId) const; + [[nodiscard]] bool isTaskCachedDevice(uint64_t const taskId) const; + void resetDeviceCache() override; void markRequestDone(LlmRequest const& llmReq, bool pause = false) override; @@ -159,7 +169,7 @@ private: std::unordered_map> mTaskIdToReqIds; std::unordered_map> mTaskIdToPausedReqIds; - std::tuple>, std::map>> getTaskMaps( + std::tuple>, TaskIdToReqIds> getTaskMaps( RequestVector const& contextRequests, RequestVector const& generationRequests); runtime::ModelConfig mModelConfig; diff --git a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp index cc62bd3eb0..0bf9a989fd 100644 --- a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp @@ -373,11 +373,11 @@ void PeftCacheManager::addRequestPeft(std::shared_ptr llmRequest, bo TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } -std::tuple>, std::map>> +std::tuple>, BasePeftCacheManager::TaskIdToReqIds> PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVector const& generationRequests) { - std::map> taskIdToReqIds; - std::map> taskIdToFuture; + TaskIdToReqIds taskIdToReqIds; + std::unordered_map> taskIdToFuture; std::lock_guard futuresLock(mPutFuturesMutex); for (auto const& requests : {contextRequests, generationRequests}) { @@ -415,7 +415,7 @@ PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVecto return {std::move(taskIdToFuture), taskIdToReqIds}; } -PeftCacheManager::PeftTable PeftCacheManager::ensureBatch( +PeftCacheManager::EnsureBatchTaskResult PeftCacheManager::ensureBatchMapTaskId( RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache) { TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); @@ -426,7 +426,7 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch( auto [taskIdToFuture_, taskIdToReqIds] = getTaskMaps(contextRequests, generationRequests); auto taskIdToFuture = std::move(taskIdToFuture_); // captured structured bindings are a C++20 extension - std::map>> ensureFutures; + std::unordered_map>> ensureFutures; for (auto& [taskId, taskFuture] : taskIdToFuture) { auto fn = [&taskIdToFuture, taskId = taskId, this]() -> std::vector @@ -457,18 +457,31 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch( ensureFutures.try_emplace(taskId, std::move(f)); } - PeftTable peftTable{}; + TaskPeftTable peftTable{}; for (auto const& [taskId, reqIds] : taskIdToReqIds) { auto&& f = ensureFutures.at(taskId); auto const values = f.get(); - for (auto const& reqId : reqIds) - { - peftTable.try_emplace(reqId, values); - } + peftTable.try_emplace(taskId, values); } TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); - return peftTable; + return {std::move(peftTable), std::move(taskIdToReqIds)}; +} + +PeftCacheManager::PeftTable PeftCacheManager::ensureBatch( + RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache) +{ + auto [taskTable, taskIdToReqIds] = ensureBatchMapTaskId(contextRequests, generationRequests, resetGpuCache); + PeftTable requestTable{}; + for (auto const& [taskId, values] : taskTable) + { + auto const& reqIds = taskIdToReqIds.at(taskId); + for (auto const reqId : reqIds) + { + requestTable.try_emplace(reqId, values); + } + } + return requestTable; } bool PeftCacheManager::isTaskCached(uint64_t taskId) const 
@@ -486,6 +499,11 @@ bool PeftCacheManager::isTaskDoneDevice(uint64_t taskId) const return mDeviceLoraCache->isDone(taskId); } +bool PeftCacheManager::isTaskCachedDevice(uint64_t const taskId) const +{ + return mDeviceLoraCache->has(taskId); +} + void PeftCacheManager::updateTaskState(uint64_t taskId, uint64_t reqId, bool terminate, bool pause) { if (!terminate) @@ -645,3 +663,5 @@ SizeType32 NoOpPeftCacheManager::determineNumPages(std::shared_ptr l return 0; } } // namespace tensorrt_llm::batch_manager + +// TODO: merge C++ LoRA caching status with Py Slot manager diff --git a/cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.cu b/cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.cu new file mode 100644 index 0000000000..81e9479777 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.cu @@ -0,0 +1,372 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cuda_graph_grouped_gemm.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h" +#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h" + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels +{ + +/** + * Template for CUDA Graph compatible grouped GEMM that directly uses GPU tensors + */ +template +void cudaGraphGroupedGemmTemplate(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, + void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, + cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream) +{ + using ElementA = cutlassType; + using ElementB = cutlassType; + using ElementOutput = cutlassType; + using ElementAccumulator = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped, cutlass::gemm::GemmShape, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, kStages, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmGrouped; + + float alpha = 1.0f; + float beta = 0.0f; + typename Gemm::EpilogueOutputOp::Params epilogueOp(alpha, beta); + + auto ptrA = reinterpret_cast(ptrAGpu); + auto ptrB = reinterpret_cast(ptrBGpu); + auto ptrC = reinterpret_cast(ptrCGpu); + auto ptrD = reinterpret_cast(ptrDGpu); + + Gemm gemmOp; + + int threadblockCount = Gemm::sufficient(nullptr, problemCount); + + typename Gemm::Arguments args(problemSizesPtr, // GPU problem sizes + problemCount, // Problem count + 
threadblockCount, // Threadblock count + epilogueOp, // Epilogue operation + ptrA, // GPU pointer array A + ptrB, // GPU pointer array B + ptrC, // GPU pointer array C (can be nullptr) + ptrD, // GPU pointer array D + ldaGpu, // Precomputed leading dimension A (on GPU) + ldbGpu, // Precomputed leading dimension B (on GPU) + ldcGpu, // Precomputed leading dimension C (on GPU) + lddGpu, // Precomputed leading dimension D (on GPU) + hostMaxProblemSizesPtr); + + static_assert(Gemm::BaseKernel::ProblemVisitor::kRequiresPrecomputation == false, + "Grouped GEMM with CUDA Graph cannot use precomputation."); + { + cutlass::Status status = gemmOp.can_implement(args); + TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, + "Grouped GEMM cannot be implemented with the given arguments, Error: %s", + cutlass::cutlassGetStatusString(status)); + } + + at::Tensor workspace; + void* gemmWorkspace = nullptr; + size_t const requiredWorkspace = gemmOp.get_workspace_size(args); + if (requiredWorkspace > 0) + { + auto const workspaceTensorOptions = at::TensorOptions().dtype(at::kByte).device(at::kCUDA); + workspace = at::empty({static_cast(requiredWorkspace)}, workspaceTensorOptions); + gemmWorkspace = workspace.data_ptr(); + } + + cutlass::Status status = gemmOp.initialize(args, gemmWorkspace); + TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to initialize grouped GEMM"); + + status = gemmOp.run(stream); + sync_check_cuda_error(stream); + TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to execute grouped GEMM"); +} + +template +void cudaGraphGroupedGemmType(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, + void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, + nvinfer1::DataType dataType, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream) +{ + if (dataType == nvinfer1::DataType::kHALF) + { + cudaGraphGroupedGemmTemplate( + problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, + hostMaxProblemSizesPtr, stream); + } +#ifdef ENABLE_BF16 + else if (dataType == nvinfer1::DataType::kBF16) + { + cudaGraphGroupedGemmTemplate( + problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, + hostMaxProblemSizesPtr, stream); + } +#endif + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported data type for CUDA Graph grouped GEMM"); + } +} + +void cudaGraphGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, void** ptrBGpu, + void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, bool isLoraIn, + nvinfer1::DataType dataType, int minKN, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream) +{ + if (isLoraIn) + { + if (minKN >= 8) + { + cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, + ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream); + } + else if (minKN >= 4) + { + cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 4, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, + ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream); + } + else if (minKN >= 2) + { + cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 2, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, + ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream); + } 
+ else if (minKN >= 1) + { + cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 1, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, + ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream); + } + } + else + { + if (minKN >= 8) + { + cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, + ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream); + } + else if (minKN >= 4) + { + cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 4, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, + ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream); + } + else if (minKN >= 2) + { + cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 2, 8, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, + ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream); + } + else + { + cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 1, 8, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, + ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream); + } + } +} + +/** + * Template for CUDA Graph compatible split-K grouped GEMM + */ +template +void cudaGraphSplitKGroupedGemmTemplate(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, + void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, + int splitKSlices, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, int64_t* splitKOffsetsGpu, cudaStream_t stream) +{ + using ElementA = cutlassType; + using ElementB = cutlassType; + using ElementOutput = cutlassType; + using ElementAccumulator = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using GemmKernel = typename cutlass::gemm::kernel::DefaultSplitkGemmGrouped, cutlass::gemm::GemmShape, cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, kStages, + cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>::GemmKernel; + + using Gemm = cutlass::gemm::device::SplitkGemmGrouped; + + float alpha = 1.0f; + float beta = 0.0f; + typename Gemm::EpilogueOutputOp::Params epilogueOp(alpha, beta); + + auto ptrA = reinterpret_cast(ptrAGpu); + auto ptrB = reinterpret_cast(ptrBGpu); + auto ptrC = reinterpret_cast(ptrCGpu); + auto ptrD = reinterpret_cast(ptrDGpu); + + Gemm gemmOp; + + int threadblockCount = Gemm::sufficient(nullptr, problemCount); + + // Setup arguments for split-K grouped GEMM - using precomputed leading dimensions from GPU tensors + typename Gemm::Arguments args(problemSizesPtr, // GPU problem sizes + problemCount, // Problem count + threadblockCount, // Threadblock count + epilogueOp, // Epilogue operation + ptrA, // GPU pointer array A + ptrB, // GPU pointer array B + ptrC, // GPU pointer array C + ptrD, // GPU pointer array D + ldaGpu, // Precomputed leading dimension A (on GPU) + ldbGpu, // Precomputed leading dimension B (on GPU) + ldcGpu, // Precomputed leading dimension C (on GPU) + lddGpu, // Precomputed leading dimension D (on GPU) + hostMaxProblemSizesPtr, // Host problem sizes + splitKSlices, // Split-K factor + splitKOffsetsGpu); + + { + cutlass::Status status = gemmOp.can_implement(args); + TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, + "Split-K grouped GEMM cannot be implemented 
with the given arguments. Problem count: %d, Split-K slices: " + "%d, Error: %s", + problemCount, splitKSlices, cutlass::cutlassGetStatusString(status)); + } + + at::Tensor workspace; + void* gemmWorkspace = nullptr; + size_t const requiredWorkspace = gemmOp.get_workspace_size(args); + if (requiredWorkspace > 0) + { + workspace = at::empty( + {static_cast(requiredWorkspace)}, at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + gemmWorkspace = workspace.data_ptr(); + } + + cutlass::Status status = gemmOp.initialize(args, gemmWorkspace); + TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, + "Failed to initialize split-K grouped GEMM. Problem count: %d, Split-K slices: %d, Error: %s", problemCount, + splitKSlices, cutlass::cutlassGetStatusString(status)); + + status = gemmOp.run(stream); + sync_check_cuda_error(stream); + TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, + "Failed to execute split-K grouped GEMM. Problem count: %d, Split-K slices: %d, Error: %s", problemCount, + splitKSlices, cutlass::cutlassGetStatusString(status)); +} + +template +void cudaGraphSplitKGroupedGemmType(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, + void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, + nvinfer1::DataType dataType, int splitKSlices, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, + int64_t* splitKOffsetsGpu, cudaStream_t stream) +{ + if (dataType == nvinfer1::DataType::kHALF) + { + cudaGraphSplitKGroupedGemmTemplate( + problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, + splitKSlices, hostMaxProblemSizesPtr, splitKOffsetsGpu, stream); + } +#ifdef ENABLE_BF16 + else if (dataType == nvinfer1::DataType::kBF16) + { + cudaGraphSplitKGroupedGemmTemplate(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, + splitKSlices, hostMaxProblemSizesPtr, splitKOffsetsGpu, stream); + } +#endif + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported data type for CUDA Graph split-K grouped GEMM"); + } +} + +void cudaGraphSplitKGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, + void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, + bool isLoraIn, nvinfer1::DataType dataType, int splitKSlices, int minKN, + cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, int64_t* splitKOffsetsGpu, cudaStream_t stream) +{ + if (isLoraIn) + { + if (minKN >= 8) + { + cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, + ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices, + hostMaxProblemSizesPtr, splitKOffsetsGpu, stream); + } + else if (minKN >= 4) + { + cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 4, 4>(problemSizesPtr, problemCount, ptrAGpu, + ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices, + hostMaxProblemSizesPtr, splitKOffsetsGpu, stream); + } + else if (minKN >= 2) + { + cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 2, 2>(problemSizesPtr, problemCount, ptrAGpu, + ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices, + hostMaxProblemSizesPtr, splitKOffsetsGpu, stream); + } + else if (minKN >= 1) + { + cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 1, 2>(problemSizesPtr, problemCount, ptrAGpu, + ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, 
ldcGpu, lddGpu, dataType, splitKSlices, + hostMaxProblemSizesPtr, splitKOffsetsGpu, stream); + } + } + else + { + if (minKN >= 8) + { + cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, + ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices, + hostMaxProblemSizesPtr, splitKOffsetsGpu, stream); + } + else if (minKN >= 4) + { + cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 4, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, + ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices, + hostMaxProblemSizesPtr, splitKOffsetsGpu, stream); + } + else if (minKN >= 2) + { + cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 2, 8, 2>(problemSizesPtr, problemCount, ptrAGpu, + ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices, + hostMaxProblemSizesPtr, splitKOffsetsGpu, stream); + } + else + { + cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 1, 8, 2>(problemSizesPtr, problemCount, ptrAGpu, + ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices, + hostMaxProblemSizesPtr, splitKOffsetsGpu, stream); + } + } +} + +} // namespace kernels + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.h b/cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.h new file mode 100644 index 0000000000..0eecccb788 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass/gemm_coord.h" +#include "tensorrt_llm/common/config.h" +#include +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels +{ + +/** + * @brief CUDA Graph compatible wrapper for grouped GEMM operations. + * + * This function accepts GPU pointers directly without any workspace for parameters, + * making it fully compatible with CUDA Graph capture and replay. 
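+ * Leading dimensions (ldaGpu/ldbGpu/ldcGpu/lddGpu) are likewise precomputed arrays resident in GPU memory,
+ * and hostMaxProblemSizesPtr provides host-side maximum problem sizes for the group.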
+ * + * @param problemSizesPtr GPU pointer to array of cutlass::gemm::GemmCoord + * @param problemCount Number of GEMM problems + * @param ptrAGpu GPU pointer to array of A matrix pointers + * @param ptrBGpu GPU pointer to array of B matrix pointers + * @param ptrCGpu GPU pointer to array of C matrix pointers (can be nullptr) + * @param ptrDGpu GPU pointer to array of D matrix pointers + * @param isLoraIn Whether this is for LoRA input transformation + * @param dataType Data type of the matrices + * @param minKN Minimum K*N value for kernel selection + * @param stream CUDA stream + */ +void cudaGraphGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, void** ptrBGpu, + void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, bool isLoraIn, + nvinfer1::DataType dataType, int minKN, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream); + +/** + * @brief CUDA Graph compatible wrapper for split-K grouped GEMM operations. + * + * Similar to cudaGraphGroupedGemm but uses split-K algorithm for better + * performance with certain problem sizes. No parameter workspace needed. + */ +void cudaGraphSplitKGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, + void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, + bool isLoraIn, nvinfer1::DataType dataType, int splitKSlices, int minKN, + cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, int64_t* splitKOffsetsGpu, cudaStream_t stream); + +} // namespace kernels + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/kernels/lora/lora.cpp b/cpp/tensorrt_llm/kernels/lora/lora.cpp index 167826be62..61f6af00fe 100644 --- a/cpp/tensorrt_llm/kernels/lora/lora.cpp +++ b/cpp/tensorrt_llm/kernels/lora/lora.cpp @@ -299,7 +299,7 @@ int LoraImpl::run(int64_t numTokens, int64_t numReqs, void const* input, int32_t + (loraModuleIdx * numTokens * mMaxLowRank + handled_token_num * mMaxLowRank) * typeSize)); auto const N2 = mOutHiddenSizes[loraModuleIdx]; - cutlass::gemm::GemmCoord problem_2(M, N2, N); + cutlass::gemm::GemmCoord problem_2(M, N2, N); // token_num, module_output_size, lora_rank problem_sizes_2.push_back(problem_2); ptrA_2.push_back(static_cast(static_cast(lowRankWorkSpace) + (loraModuleIdx * numTokens * mMaxLowRank + handled_token_num * mMaxLowRank) * typeSize)); diff --git a/cpp/tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.cu b/cpp/tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.cu new file mode 100644 index 0000000000..c3276ea487 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.cu @@ -0,0 +1,408 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "loraGroupGEMMParamFillRowReorderFusion.h" + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" + +#include +#include +#include + +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels +{ +namespace +{ + +template +__forceinline__ constexpr T (&as_singleton_array(T& obj))[1] +{ + return reinterpret_cast(obj); +} + +enum ParamIndex +{ + IN_SIZES_INDEX, + OUT_SIZES_INDEX, + LDA_INDEX, + LDD_INDEX, + LDB_PRIME_INDEX, + LDD_PRIME_INDEX, + A_PTRS_INDEX, + D_PTRS_INDEX, + D_PRIME_PTRS_INDEX, + SPLITK_OFFSETS_INDEX, + PARAM_COUNT, +}; + +int constexpr VECTOR_LOAD_WIDTH = 16; +} // namespace + +/** + * Fused kernel for LoRA group GEMM parameter filling, row gather and zero fillings. + * Needs to be called with at least PARAM_COUNT blocks. And total number of threads need to be enough to reorder `input` + * and fill zeros for intermediate and output buffers. Specifically, (n_threads >= max(input_bytes_to_reorder, + * intermediate_bytes_to_fill, output_bytes_to_fill) / VECTOR_LOAD_WIDTH) + * + * Template parameters: + * - BlockDim: Number of threads per block (1D, >= max_lora_count * module_count, >= 256, divisible by 32) + * - MODULE_COUNT: Number of modules per layer + */ +template +__global__ void loraGroupGEMMParamFillRowReorderFusionKernel( + // Output parameters + int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs, int64_t* d_ptrs, int64_t* d_prime_ptrs, int64_t* lda, + int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime, int64_t* splitk_offsets, uint8_t* reordered_input, + // Input parameters + int32_t max_lora_count, int32_t max_lora_rank, int32_t sum_output_hidden_size, int32_t input_hidden_size, + int64_t dtype_element_size, int64_t batch_size, int64_t a_base, int64_t d_base, int64_t d_prime_base, + int32_t const* slot_counts, int32_t const* slot_ranks, int64_t const* slot_offsets, int32_t const* module_out_sizes, + int64_t const* module_out_prefix, int64_t const* b_ptrs, int64_t const* b_prime_ptrs, uint8_t const* input, + int64_t const* sorted_ids) +{ + int const linearIdx = threadIdx.x; + int const blockLinearIdx = blockIdx.x + blockIdx.y * gridDim.x; + int constexpr THREADS_PER_BLOCK = BlockDim; + + // Calculate lora_id and module_id from linearIdx + int const lora_id = linearIdx % max_lora_count; + int const module_id = linearIdx / max_lora_count; + + using BlockLoad = cub::BlockLoad; + using BlockStore = cub::BlockStore; + using BlockLoad64 = cub::BlockLoad; + using BlockStore64 = cub::BlockStore; + using BlockScan = cub::BlockScan; + using BlockStore3 = cub::BlockStore; + + __shared__ union + { + typename BlockStore3::TempStorage storage3; + typename BlockScan::TempStorage scan; + } large_shared; + + __shared__ int32_t + row_count_to_gather; // rows > row_count_to_gather belong to non-LoRA requests, write zero directly + + switch (blockLinearIdx) + { + case IN_SIZES_INDEX: + { + int32_t slot_count = slot_counts[lora_id]; + int32_t rank = slot_ranks[lora_id]; + int64_t b_ptr = b_ptrs[linearIdx % (max_lora_count * MODULE_COUNT)]; + int32_t row[3] = {0}; + if (b_ptr != 0) + { + row[0] = slot_count; + row[1] = rank; + row[2] = input_hidden_size; + } + BlockStore3(large_shared.storage3).Store(in_sizes, row, max_lora_count * MODULE_COUNT * 3); + } + break; + case OUT_SIZES_INDEX: + { + int32_t slot_count = slot_counts[lora_id]; + int32_t output_hidden_size = module_out_sizes[module_id]; + int32_t rank = slot_ranks[lora_id]; + int64_t b_ptr = b_ptrs[linearIdx % (max_lora_count * MODULE_COUNT)]; + int32_t row[3] = {0}; + if (b_ptr != 
0) + { + row[0] = slot_count; + row[1] = output_hidden_size; + row[2] = rank; + } + BlockStore3(large_shared.storage3).Store(out_sizes, row, max_lora_count * MODULE_COUNT * 3); + } + break; + case LDA_INDEX: + { + int64_t input_hidden_size_64[1] = {static_cast(input_hidden_size)}; + BlockStore64().Store(lda, input_hidden_size_64, max_lora_count * MODULE_COUNT); + } + break; + case LDD_INDEX: + { + int64_t max_lora_rank_64[1] = {static_cast(max_lora_rank)}; + BlockStore64().Store(ldd, max_lora_rank_64, max_lora_count * MODULE_COUNT); + } + break; + case LDB_PRIME_INDEX: + { + int64_t rank = slot_ranks[lora_id]; + BlockStore64().Store(ldb_prime, as_singleton_array(rank), max_lora_count * MODULE_COUNT); + } + break; + case LDD_PRIME_INDEX: + { + int64_t sum_output_hidden_size_64[1] = {static_cast(sum_output_hidden_size)}; + BlockStore64().Store(ldd_prime, sum_output_hidden_size_64, max_lora_count * MODULE_COUNT); + } + break; + case A_PTRS_INDEX: + { + int64_t slot_offset = 0; + BlockLoad64().Load(slot_offsets, as_singleton_array(slot_offset), max_lora_count + 1); + if (linearIdx == max_lora_count) + { + row_count_to_gather = static_cast(slot_offset); + } + slot_offset *= input_hidden_size; + slot_offset *= dtype_element_size; + slot_offset += a_base; + for (int i = 0; i < MODULE_COUNT; i++) + { + BlockStore64().Store(a_ptrs + i * max_lora_count, as_singleton_array(slot_offset), max_lora_count); + } + } + break; + case D_PTRS_INDEX: + { + int64_t slot_offset = 0; + BlockLoad64().Load(slot_offsets, as_singleton_array(slot_offset), max_lora_count + 1); + for (int i = 0; i < MODULE_COUNT; i++) + { + int64_t offset = slot_offset; + offset += i * batch_size; + offset *= max_lora_rank; + offset *= dtype_element_size; + offset += d_base; + BlockStore64().Store(d_ptrs + i * max_lora_count, as_singleton_array(offset), max_lora_count); + } + if (linearIdx == max_lora_count) + { + row_count_to_gather = static_cast(slot_offset); + } + } + break; + case D_PRIME_PTRS_INDEX: + { + int64_t slot_offset = 0; + BlockLoad64().Load(slot_offsets, as_singleton_array(slot_offset), max_lora_count + 1); + if (linearIdx == max_lora_count) + { + row_count_to_gather = static_cast(slot_offset); + } + slot_offset *= sum_output_hidden_size; + for (int i = 0; i < MODULE_COUNT; i++) + { + int64_t offset = slot_offset; + offset += module_out_prefix[i]; + offset *= dtype_element_size; + offset += d_prime_base; + BlockStore64().Store(d_prime_ptrs + i * max_lora_count, as_singleton_array(offset), max_lora_count); + } + } + break; + case SPLITK_OFFSETS_INDEX: + { + int64_t slot_count = slot_counts[lora_id]; + int64_t rank = slot_ranks[lora_id]; + int64_t b_ptr = b_ptrs[linearIdx % (max_lora_count * MODULE_COUNT)]; + int64_t splitk_offset = (b_ptr == 0) ? 
0 : (slot_count * rank); + BlockScan(large_shared.scan).ExclusiveSum(splitk_offset, splitk_offset); + BlockStore64().Store(splitk_offsets, as_singleton_array(splitk_offset), max_lora_count * MODULE_COUNT); + } + break; + } + + // Set row_count_to_gather for non-pointer blocks + switch (blockLinearIdx) + { + case A_PTRS_INDEX: + case D_PTRS_INDEX: + case D_PRIME_PTRS_INDEX: break; + default: + if (linearIdx == 0) + { + row_count_to_gather = static_cast(slot_offsets[max_lora_count]); + } + } + + int constexpr ITEM_PER_THREAD = VECTOR_LOAD_WIDTH; + using BlockStoreRow = cub::BlockStore; + + { + // Write zero to intermediate buffer and output buffer + auto intermediate_cast = reinterpret_cast(d_base); + auto model_output_cast = reinterpret_cast(d_prime_base); + + int intermediate_size = MODULE_COUNT * batch_size * max_lora_rank * dtype_element_size; + int output_size = batch_size * sum_output_hidden_size * dtype_element_size; + + uint8_t all_zeroes[ITEM_PER_THREAD] = {0}; + + int const blockOffset = THREADS_PER_BLOCK * ITEM_PER_THREAD * blockLinearIdx; + BlockStoreRow().Store(intermediate_cast + blockOffset, all_zeroes, intermediate_size - blockOffset); + BlockStoreRow().Store(model_output_cast + blockOffset, all_zeroes, output_size - blockOffset); + } + + __syncthreads(); + + // Row gather + if (blockIdx.y < batch_size) + { + using BlockLoadRow = cub::BlockLoad; + + auto const row_size = input_hidden_size * dtype_element_size; + + auto output_cast = reinterpret_cast(reordered_input); + + uint8_t tile[ITEM_PER_THREAD] = {0}; + int constexpr x_stride = THREADS_PER_BLOCK * ITEM_PER_THREAD; + int const y_stride = row_size; + int tail = row_size - blockIdx.x * x_stride; + if (blockIdx.y < row_count_to_gather) + { + auto const input_cast = reinterpret_cast(input); + auto const src_row = sorted_ids[blockIdx.y]; + BlockLoadRow().Load(input_cast + blockIdx.x * x_stride + src_row * y_stride, tile, tail); + } + BlockStoreRow().Store(output_cast + blockIdx.x * x_stride + blockIdx.y * y_stride, tile, tail); + } +} + +/** + * Launch function that instantiates the appropriate kernel based on module count. 
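+ * The fused kernel fills every grouped-GEMM parameter array, zero-fills the intermediate and output
+ * buffers, and gathers input rows according to sorted_ids in a single launch.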
+ */ +template +void launchKernelWithModuleCount(int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs, int64_t* d_ptrs, + int64_t* d_prime_ptrs, int64_t* lda, int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime, int64_t* splitk_offsets, + void* reordered_input, int32_t max_lora_count, int32_t max_lora_rank, int32_t sum_output_hidden_size, + int32_t input_hidden_size, int64_t dtype_element_size, int64_t batch_size, int64_t a_base, int64_t d_base, + int64_t d_prime_base, int32_t const* slot_counts, int32_t const* slot_ranks, int64_t const* slot_offsets, + int32_t const* module_out_sizes, int64_t const* module_out_prefix, int64_t const* b_ptrs, + int64_t const* b_prime_ptrs, void const* input, int64_t const* sorted_ids, int32_t module_count, + cudaStream_t stream) +{ + int constexpr THREADS_PER_BLOCK = BlockDim; + + // Grid dimensions for row gather + int constexpr ITEMS_PER_BLOCK = THREADS_PER_BLOCK * VECTOR_LOAD_WIDTH; + int const gridDimX = common::ceilDiv(input_hidden_size * dtype_element_size, ITEMS_PER_BLOCK); + int gridDimY = std::max( + static_cast(common::ceilDiv(static_cast(PARAM_COUNT), gridDimX)), static_cast(batch_size)); + + // calculate threads needed for writing zeros to intermediate buffer and output buffer + int const itemsPerRow = ITEMS_PER_BLOCK * gridDimX; + gridDimY = std::max(gridDimY, + common::ceilDiv(static_cast(module_count * batch_size * max_lora_rank * dtype_element_size), itemsPerRow)); + gridDimY = std::max(gridDimY, + common::ceilDiv(static_cast(batch_size * sum_output_hidden_size * dtype_element_size), itemsPerRow)); + + dim3 grid(gridDimX, gridDimY); + dim3 block(BlockDim); + + auto* reordered_input_cast = reinterpret_cast(reordered_input); + auto const* input_cast = reinterpret_cast(input); + + // Dispatch based on module count + switch (module_count) + { + case 1: + loraGroupGEMMParamFillRowReorderFusionKernel<<>>(in_sizes, out_sizes, + a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast, + max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size, + a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix, + b_ptrs, b_prime_ptrs, input_cast, sorted_ids); + break; + case 2: + loraGroupGEMMParamFillRowReorderFusionKernel<<>>(in_sizes, out_sizes, + a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast, + max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size, + a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix, + b_ptrs, b_prime_ptrs, input_cast, sorted_ids); + break; + case 3: + loraGroupGEMMParamFillRowReorderFusionKernel<<>>(in_sizes, out_sizes, + a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast, + max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size, + a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix, + b_ptrs, b_prime_ptrs, input_cast, sorted_ids); + break; + case 4: + loraGroupGEMMParamFillRowReorderFusionKernel<<>>(in_sizes, out_sizes, + a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast, + max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size, + a_base, d_base, d_prime_base, slot_counts, slot_ranks, 
slot_offsets, module_out_sizes, module_out_prefix, + b_ptrs, b_prime_ptrs, input_cast, sorted_ids); + break; + default: TLLM_CHECK_WITH_INFO(false, "Unsupported module_count: %d (max 4)", module_count); + } +} + +void launchLoraGroupGEMMParamFillRowReorderFusion(int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs, + int64_t* d_ptrs, int64_t* d_prime_ptrs, int64_t* lda, int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime, + int64_t* splitk_offsets, void* reordered_input, int32_t max_lora_count, int32_t max_lora_rank, + int32_t sum_output_hidden_size, int32_t input_hidden_size, int64_t dtype_element_size, int64_t batch_size, + int64_t a_base, int64_t d_base, int64_t d_prime_base, int32_t const* slot_counts, int32_t const* slot_ranks, + int64_t const* slot_offsets, int32_t const* module_out_sizes, int64_t const* module_out_prefix, + int64_t const* b_ptrs, int64_t const* b_prime_ptrs, void const* input, int64_t const* sorted_ids, + int32_t module_count, nvinfer1::DataType dtype, cudaStream_t stream) +{ + // Determine block dimensions (1D) + // Requirements: 1) >= max_lora_count * module_count 2) >= 256 3) divisible by 32 + + int constexpr MIN_THREADS = 256; + int constexpr WARP_SIZE = 32; + + int const min_threads_needed = max_lora_count * module_count; + int const threads_per_block = std::max(MIN_THREADS, common::ceilDiv(min_threads_needed, WARP_SIZE) * WARP_SIZE); + + if (threads_per_block == 256) + { + launchKernelWithModuleCount<256>(in_sizes, out_sizes, a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, + ldd_prime, splitk_offsets, reordered_input, max_lora_count, max_lora_rank, sum_output_hidden_size, + input_hidden_size, dtype_element_size, batch_size, a_base, d_base, d_prime_base, slot_counts, slot_ranks, + slot_offsets, module_out_sizes, module_out_prefix, b_ptrs, b_prime_ptrs, input, sorted_ids, module_count, + stream); + } + else if (threads_per_block == 288) + { + launchKernelWithModuleCount<288>(in_sizes, out_sizes, a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, + ldd_prime, splitk_offsets, reordered_input, max_lora_count, max_lora_rank, sum_output_hidden_size, + input_hidden_size, dtype_element_size, batch_size, a_base, d_base, d_prime_base, slot_counts, slot_ranks, + slot_offsets, module_out_sizes, module_out_prefix, b_ptrs, b_prime_ptrs, input, sorted_ids, module_count, + stream); + } + else if (threads_per_block == 320) + { + launchKernelWithModuleCount<320>(in_sizes, out_sizes, a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, + ldd_prime, splitk_offsets, reordered_input, max_lora_count, max_lora_rank, sum_output_hidden_size, + input_hidden_size, dtype_element_size, batch_size, a_base, d_base, d_prime_base, slot_counts, slot_ranks, + slot_offsets, module_out_sizes, module_out_prefix, b_ptrs, b_prime_ptrs, input, sorted_ids, module_count, + stream); + } + else + { + TLLM_CHECK_WITH_INFO(false, + "Unsupported threads_per_block: %d (calculated from max_lora_count=%d * module_count=%d)", + threads_per_block, max_lora_count, module_count); + } + + sync_check_cuda_error(stream); +} + +} // namespace kernels + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.h b/cpp/tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.h new file mode 100644 index 0000000000..3043054ca4 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/config.h" +#include +#include +#include + +TRTLLM_NAMESPACE_BEGIN + +namespace kernels +{ + +/** + * @brief Fused kernel that fills group GEMM parameters, performs row reordering, zero fillings for CUDA graph + * compatible LoRA. + * + * @param in_sizes Output: [module_count, max_lora_count, 3] problem sizes for first GEMM + * @param out_sizes Output: [module_count, max_lora_count, 3] problem sizes for second GEMM + * @param a_ptrs Output: [module_count, max_lora_count] input matrix pointers + * @param d_ptrs Output: [module_count, max_lora_count] intermediate output pointers + * @param d_prime_ptrs Output: [module_count, max_lora_count] final output pointers + * @param lda Output: [module_count, max_lora_count] leading dimensions for A matrices + * @param ldd Output: [module_count, max_lora_count] leading dimensions for D matrices + * @param ldb_prime Output: [module_count, max_lora_count] leading dimensions for B' matrices + * @param ldd_prime Output: [module_count, max_lora_count] leading dimensions for D' matrices + * @param splitk_offsets Output: [module_count, max_lora_count] split-K work offsets + * @param reordered_input Output: [batch_size, input_hidden_size] reordered input matrix + * @param max_lora_count Maximum number of LoRA adapters + * @param max_lora_rank Maximum rank of LoRA adapters + * @param sum_output_hidden_size Sum of output hidden sizes across modules + * @param input_hidden_size Input hidden dimension + * @param dtype_element_size Size of data type in bytes + * @param batch_size Batch size + * @param a_base Base pointer for input matrices + * @param d_base Base pointer for intermediate output matrices + * @param d_prime_base Base pointer for final output matrices + * @param slot_counts Input: [max_lora_count] number of requests per LoRA slot + * @param slot_ranks Input: [max_lora_count] rank of each LoRA adapter + * @param slot_offsets Input: [max_lora_count + 1] cumulative offsets (last element = total rows) + * @param module_out_sizes Input: [module_count] output hidden size per module + * @param module_out_prefix Input: [module_count] prefix sum of output hidden sizes + * @param b_ptrs Input: [module_count, max_lora_count] weight pointers for first GEMM + * @param b_prime_ptrs Input: [module_count, max_lora_count] weight pointers for second GEMM + * @param input Input: [batch_size, input_hidden_size] original input matrix + * @param sorted_ids Input: [batch_size] indices for row reordering + * @param module_count Number of modules per layer + * @param dtype Data type of matrices + * @param stream CUDA stream + */ +void launchLoraGroupGEMMParamFillRowReorderFusion(int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs, + int64_t* d_ptrs, int64_t* d_prime_ptrs, int64_t* lda, int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime, + int64_t* splitk_offsets, void* reordered_input, int32_t max_lora_count, int32_t max_lora_rank, + int32_t 
sum_output_hidden_size, int32_t input_hidden_size, int64_t dtype_element_size, int64_t batch_size, + int64_t a_base, int64_t d_base, int64_t d_prime_base, int32_t const* slot_counts, int32_t const* slot_ranks, + int64_t const* slot_offsets, int32_t const* module_out_sizes, int64_t const* module_out_prefix, + int64_t const* b_ptrs, int64_t const* b_prime_ptrs, void const* input, int64_t const* sorted_ids, + int32_t module_count, nvinfer1::DataType dtype, cudaStream_t stream); + +} // namespace kernels + +TRTLLM_NAMESPACE_END diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 0dcf82da89..d642e609b9 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -563,6 +564,11 @@ void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m) nb::arg("config"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"), nb::call_guard()) .def("is_task_cached", &tb::PeftCacheManager::isTaskCached, nb::arg("taskId"), + nb::call_guard()) + .def("is_task_cached_device", &tb::PeftCacheManager::isTaskCachedDevice, nb::arg("taskId"), + nb::call_guard()) // ; + .def("ensure_batch_map_task_id", &tb::PeftCacheManager::ensureBatchMapTaskId, nb::arg("context_requests"), + nb::arg("generation_requests"), nb::arg("reset_gpu_cache") = false, nb::call_guard()); nb::class_(m, "NoOpPeftCacheManager") diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 307eee166f..35126b93c7 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -558,6 +558,11 @@ void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m) py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"), py::call_guard()) .def("is_task_cached", &tb::PeftCacheManager::isTaskCached, py::arg("taskId"), + py::call_guard()) + .def("is_task_cached_device", &tb::PeftCacheManager::isTaskCachedDevice, py::arg("taskId"), + py::call_guard()) + .def("ensure_batch_map_task_id", &tb::PeftCacheManager::ensureBatchMapTaskId, py::arg("context_requests"), + py::arg("generation_requests"), py::arg("reset_gpu_cache") = false, py::call_guard()); py::classh(m, "NoOpPeftCacheManager") diff --git a/cpp/tensorrt_llm/runtime/loraManager.h b/cpp/tensorrt_llm/runtime/loraManager.h index e713c653f1..7b18efef0c 100644 --- a/cpp/tensorrt_llm/runtime/loraManager.h +++ b/cpp/tensorrt_llm/runtime/loraManager.h @@ -42,7 +42,7 @@ public: using LoraReqTensors = std::tuple; using TaskIdType = std::int64_t; using PeftValues = std::vector; - using PeftTable = std::map>; + using PeftTable = std::unordered_map>; explicit LoraManager() {} diff --git a/cpp/tensorrt_llm/thop/loraOp.cpp b/cpp/tensorrt_llm/thop/loraOp.cpp index 08cf10decf..b35ca26086 100644 --- a/cpp/tensorrt_llm/thop/loraOp.cpp +++ b/cpp/tensorrt_llm/thop/loraOp.cpp @@ -18,13 +18,16 @@ #include "tensorrt_llm/common/cublasMMWrapper.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/kernels/cuda_graph_grouped_gemm.h" #include "tensorrt_llm/kernels/lora/lora.h" +#include "tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.h" #include "tensorrt_llm/kernels/selectiveScan/selectiveScan.h" #include 
"tensorrt_llm/thop/thUtils.h" namespace th = torch; namespace tk = tensorrt_llm::kernels; using tensorrt_llm::common::fmtstr; +namespace tc = tensorrt_llm::common; TRTLLM_NAMESPACE_BEGIN @@ -174,6 +177,171 @@ std::vector lora_grouped_gemm(th::Tensor const& input, th::Tensor co return output_torch; } +void lora_grouped_gemm_cuda_graph(th::Tensor const& lora_in_sizes, // [layer_module_num, max_lora_size, 3] + th::Tensor const& lora_out_sizes, // [layer_module_num, max_lora_size, 3] + th::Tensor const& a_offsets, // [layer_module_num, max_lora_size] + th::Tensor const& b_ptrs, // [layer_module_num, max_lora_size] + th::Tensor const& d_offsets, // [layer_module_num, max_lora_size] + th::Tensor const& b_prime_ptrs, // [layer_module_num, max_lora_size] + th::Tensor const& d_prime_offsets, // [layer_module_num, max_lora_size] + int64_t problem_count, + th::Tensor const& lda, // Leading dimensions for A matrices [layer_module_num, max_lora_size] + th::Tensor const& ldb, // Leading dimensions for B matrices [layer_module_num, max_lora_size] + th::Tensor const& ldd, // Leading dimensions for C matrices [layer_module_num, max_lora_size] (unused) + th::Tensor const& ldb_prime, th::Tensor const& ldd_prime, th::Tensor const& host_max_in_sizes, + th::Tensor const& host_max_out_sizes, th::Tensor const& splitk_offsets, c10::ScalarType dtype, int64_t minKN, + int64_t splitKSlices = 16) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + sync_check_cuda_error(stream); + + auto* a_ptrs_gpu = reinterpret_cast(const_cast(a_offsets.data_ptr())); + auto* d_ptrs_gpu = reinterpret_cast(const_cast(d_offsets.data_ptr())); + auto* a_prime_ptrs_gpu = reinterpret_cast(const_cast(d_offsets.data_ptr())); + auto* d_prime_ptrs_gpu = reinterpret_cast(const_cast(d_prime_offsets.data_ptr())); + + auto* problem_sizes_1_ptr = reinterpret_cast(lora_in_sizes.data_ptr()); + auto* problem_sizes_2_ptr = reinterpret_cast(lora_out_sizes.data_ptr()); + + auto* host_max_in_sizes_ptr = reinterpret_cast(host_max_in_sizes.data_ptr()); + auto* host_max_out_sizes_ptr = reinterpret_cast(host_max_out_sizes.data_ptr()); + + auto* b_ptrs_gpu = reinterpret_cast(const_cast(b_ptrs.data_ptr())); + auto* b_prime_ptrs_gpu = reinterpret_cast(const_cast(b_prime_ptrs.data_ptr())); + + auto* lda_gpu = reinterpret_cast(const_cast(lda.data_ptr())); + auto* ldb_gpu = reinterpret_cast(const_cast(ldb.data_ptr())); + auto* ldd_gpu = reinterpret_cast(const_cast(ldd.data_ptr())); + auto* ldb_prime_gpu = reinterpret_cast(const_cast(ldb_prime.data_ptr())); + auto* ldd_prime_gpu = reinterpret_cast(const_cast(ldd_prime.data_ptr())); + + auto* splitk_offsets_gpu = reinterpret_cast(const_cast(splitk_offsets.data_ptr())); + + // Get data type + nvinfer1::DataType loraRuntimeDataType; + switch (dtype) + { + case torch::kFloat16: loraRuntimeDataType = nvinfer1::DataType::kHALF; break; + case torch::kBFloat16: loraRuntimeDataType = nvinfer1::DataType::kBF16; break; + default: TORCH_CHECK(false, "Invalid dtype, only supports float16, bfloat16, got %s", c10::toString(dtype)); + } + + int const minKnInt = std::max(1, static_cast(minKN)); + + // Call CUDA Graph compatible grouped GEMM for lora_in (split-K) + if (problem_count > 0) + { + TLLM_LOG_TRACE("Start Grouped GEMM for LoRA in."); + + tk::cudaGraphSplitKGroupedGemm(problem_sizes_1_ptr, problem_count, a_ptrs_gpu, b_ptrs_gpu, + d_ptrs_gpu, // ptrC (no bias) + d_ptrs_gpu, lda_gpu, ldb_gpu, ldd_gpu, ldd_gpu, // Precomputed leading dimensions + true, // 
isLoraIn + loraRuntimeDataType, + static_cast(splitKSlices), // splitKSlices + minKnInt, // minKN + host_max_in_sizes_ptr, splitk_offsets_gpu, stream); + sync_check_cuda_error(stream); + + // Call CUDA Graph compatible grouped GEMM for lora_out + TLLM_LOG_TRACE("Start Grouped GEMM for LoRA out."); + tk::cudaGraphGroupedGemm(problem_sizes_2_ptr, problem_count, a_prime_ptrs_gpu, b_prime_ptrs_gpu, + d_prime_ptrs_gpu, // ptrC (no bias) + d_prime_ptrs_gpu, ldd_gpu, ldb_prime_gpu, ldd_prime_gpu, ldd_prime_gpu, // Precomputed leading dimensions + false, // isLoraIn + loraRuntimeDataType, + minKnInt, // minKN + host_max_out_sizes_ptr, stream); + sync_check_cuda_error(stream); + } + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +void lora_group_gemm_param_fill_row_reorder_fusion(th::Tensor const& in_sizes, // [module_count, max_lora_count, 3] + th::Tensor const& out_sizes, // [module_count, max_lora_count, 3] + th::Tensor const& a_ptrs, // [module_count, max_lora_count] + th::Tensor const& d_ptrs, // [module_count, max_lora_count] + th::Tensor const& d_prime_ptrs, // [module_count, max_lora_count] + th::Tensor const& lda, // [module_count, max_lora_count] + th::Tensor const& ldd, // [module_count, max_lora_count] + th::Tensor const& ldb_prime, // [module_count, max_lora_count] + th::Tensor const& ldd_prime, // [module_count, max_lora_count] + th::Tensor const& splitk_offsets, // [module_count, max_lora_count] + th::Tensor const& reordered_input, // [batch_size, input_hidden_size] + int64_t max_lora_count, int64_t max_lora_rank, int64_t sum_output_hidden_size, int64_t input_hidden_size, + int64_t batch_size, + th::Tensor const& slot_counts, // [max_lora_count] + th::Tensor const& slot_ranks, // [max_lora_count] + th::Tensor const& slot_offsets, // [max_lora_count + 1] + th::Tensor const& module_out_sizes, // [module_count] + th::Tensor const& module_out_prefix, // [module_count] + th::Tensor const& b_ptrs, // [module_count, max_lora_count] + th::Tensor const& b_prime_ptrs, // [module_count, max_lora_count] + th::Tensor const& input, // [batch_size, input_hidden_size] + th::Tensor const& sorted_ids, // [batch_size] + th::Tensor const& intermediate_buffer, // [batch_size, max_lora_rank] + th::Tensor const& output_buffer, // [batch_size, sum_output_hidden_size] + c10::ScalarType dtype) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // Validate inputs + TORCH_CHECK(in_sizes.device().is_cuda(), "in_sizes must be a CUDA tensor"); + TORCH_CHECK(out_sizes.device().is_cuda(), "out_sizes must be a CUDA tensor"); + TORCH_CHECK(reordered_input.device().is_cuda(), "reordered_input must be a CUDA tensor"); + TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor"); + + // Get module count from tensor shapes + int32_t const module_count = static_cast(in_sizes.size(0)); + + // Get data type info + nvinfer1::DataType loraRuntimeDataType; + switch (dtype) + { + case torch::kFloat16: loraRuntimeDataType = nvinfer1::DataType::kHALF; break; + case torch::kBFloat16: loraRuntimeDataType = nvinfer1::DataType::kBF16; break; + default: TORCH_CHECK(false, "Invalid dtype, only supports float16, bfloat16, got %s", c10::toString(dtype)); + } + + int64_t const dtype_element_size = input.element_size(); + + int64_t const a_base_ptr = reinterpret_cast(reordered_input.data_ptr()); + int64_t const d_base_ptr = reinterpret_cast(intermediate_buffer.data_ptr()); + int64_t const d_prime_base_ptr = reinterpret_cast(output_buffer.data_ptr()); + + 
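+    // One fused launch fills the grouped-GEMM parameter tensors, zero-fills the intermediate and output
+    // buffers, and reorders the input rows, so this step stays CUDA Graph capturable.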
tk::launchLoraGroupGEMMParamFillRowReorderFusion(reinterpret_cast(const_cast(in_sizes.data_ptr())), + reinterpret_cast(const_cast(out_sizes.data_ptr())), + reinterpret_cast(const_cast(a_ptrs.data_ptr())), + reinterpret_cast(const_cast(d_ptrs.data_ptr())), + reinterpret_cast(const_cast(d_prime_ptrs.data_ptr())), + reinterpret_cast(const_cast(lda.data_ptr())), + reinterpret_cast(const_cast(ldd.data_ptr())), + reinterpret_cast(const_cast(ldb_prime.data_ptr())), + reinterpret_cast(const_cast(ldd_prime.data_ptr())), + reinterpret_cast(const_cast(splitk_offsets.data_ptr())), + const_cast(reordered_input.data_ptr()), static_cast(max_lora_count), + static_cast(max_lora_rank), static_cast(sum_output_hidden_size), + static_cast(input_hidden_size), dtype_element_size, batch_size, a_base_ptr, d_base_ptr, + d_prime_base_ptr, reinterpret_cast(slot_counts.data_ptr()), + reinterpret_cast(slot_ranks.data_ptr()), + reinterpret_cast(slot_offsets.data_ptr()), + reinterpret_cast(module_out_sizes.data_ptr()), + reinterpret_cast(module_out_prefix.data_ptr()), + reinterpret_cast(b_ptrs.data_ptr()), reinterpret_cast(b_prime_ptrs.data_ptr()), + input.data_ptr(), reinterpret_cast(sorted_ids.data_ptr()), module_count, loraRuntimeDataType, + stream); + + sync_check_cuda_error(stream); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + } // namespace torch_ext TRTLLM_NAMESPACE_END @@ -192,9 +360,65 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) "int max_low_rank, " "int weight_index, " "bool isRemoveInputPadding) -> Tensor[]"); + + m.def( + "lora_grouped_gemm_cuda_graph(" + "Tensor lora_in_sizes, " + "Tensor lora_out_sizes, " + "Tensor a_offsets, " + "Tensor b_ptrs, " + "Tensor d_offsets, " + "Tensor b_prime_ptrs, " + "Tensor d_prime_offsets, " + "int problem_count, " + "Tensor lda, " + "Tensor ldb, " + "Tensor ldd, " + "Tensor ldb_prime, " + "Tensor ldd_prime, " + "Tensor host_max_in_sizes, " + "Tensor host_max_out_sizes, " + "Tensor splitk_offsets, " + "ScalarType dtype, " + "int minKN, " + "int splitKSlices=16) -> ()"); + + m.def( + "lora_group_gemm_param_fill_row_reorder_fusion(" + "Tensor in_sizes, " + "Tensor out_sizes, " + "Tensor a_ptrs, " + "Tensor d_ptrs, " + "Tensor d_prime_ptrs, " + "Tensor lda, " + "Tensor ldd, " + "Tensor ldb_prime, " + "Tensor ldd_prime, " + "Tensor splitk_offsets, " + "Tensor reordered_input, " + "int max_lora_count, " + "int max_lora_rank, " + "int sum_output_hidden_size, " + "int input_hidden_size, " + "int batch_size, " + "Tensor slot_counts, " + "Tensor slot_ranks, " + "Tensor slot_offsets, " + "Tensor module_out_sizes, " + "Tensor module_out_prefix, " + "Tensor b_ptrs, " + "Tensor b_prime_ptrs, " + "Tensor input, " + "Tensor sorted_ids, " + "Tensor intermediate_buffer, " + "Tensor output_buffer, " + "ScalarType dtype) -> ()"); } TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { m.impl("lora_grouped_gemm", &tensorrt_llm::torch_ext::lora_grouped_gemm); + m.impl("lora_grouped_gemm_cuda_graph", &tensorrt_llm::torch_ext::lora_grouped_gemm_cuda_graph); + m.impl("lora_group_gemm_param_fill_row_reorder_fusion", + &tensorrt_llm::torch_ext::lora_group_gemm_param_fill_row_reorder_fusion); } diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 69ae313713..746ff12050 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -317,9 +317,6 @@ class Attention(nn.Module): self.fused_qkv_lora = LoraLayer([LoraModuleType.ATTENTION_QKV], [self.q_size + 2 * self.kv_size]) - self.o_lora = 
LoraLayer([LoraModuleType.ATTENTION_DENSE], - [self.hidden_size]) - # Whether to fuse RoPE into the attention OP. # If true, RoPE will be applied in self.attn.forward. # If false, RoPE will be applied in self.apply_rope. diff --git a/tensorrt_llm/_torch/peft/lora/adapter_slot_manager.py b/tensorrt_llm/_torch/peft/lora/adapter_slot_manager.py new file mode 100644 index 0000000000..69f25179be --- /dev/null +++ b/tensorrt_llm/_torch/peft/lora/adapter_slot_manager.py @@ -0,0 +1,130 @@ +""" +AdapterSlotManager for managing slots that stores LoRA indices. +""" + +from collections import OrderedDict +from typing import List, Optional + +from ...pyexecutor.resource_manager import PeftCacheManager +from ...pyexecutor.scheduler import RequestList + + +class AdapterSlotManager: + """ + Manages max_num_adapters ordered slots for distinct task_ids to enable CUDA Graph compatibility. + + Each slot can hold one adapter (task_id) and maintains a consistent ordering that allows + the CUDA Graph to be captured with fixed buffer layouts. + """ + + def __init__(self, max_num_adapters: int): + """ + Initialize the AdapterSlotManager. + + Args: + max_num_adapters: Maximum number of LoRA adapters that can be active simultaneously + """ + self.max_num_adapters = max_num_adapters + + # Slot management + self.slot2task: List[Optional[int]] = [None] * max_num_adapters + self.task2slot: OrderedDict[int, int] = OrderedDict() # represent LRU order + + # State tracking + self.slots_changed = False + + def find_free_slot(self) -> int: + """ + Find a free slot. Return slot_id if found, otherwise return None. + """ + return self.slot2task.index(None) + + def remove_task(self, task_id: int) -> Optional[int]: + """ + Remove a task_id from slots. Return its slot_id if present otherwise return None. + """ + slot_id = self.task2slot.pop(task_id, None) + if slot_id is not None: + self.slots_changed = True + self.slot2task[slot_id] = None + return slot_id + + def get_or_assign_task(self, task_id: int) -> tuple[int, Optional[int]]: + """ + Assign a task_id to a slot and do LRU eviction if necessary. + If already in any slot, update LRU order. + Return: pair (assigned slot_id, evicted task_id) + """ + evicted_task = None + if task_id in self.task2slot: + self.task2slot.move_to_end(task_id) + else: + self.slots_changed = True + if len(self.task2slot) < self.max_num_adapters: + free_slot = self.find_free_slot() + self.slot2task[free_slot] = task_id + self.task2slot[task_id] = free_slot + else: + # evict lru + evicted_task, evicted_slot = self.task2slot.popitem(last=False) + self.slot2task[evicted_slot] = task_id + self.task2slot[task_id] = evicted_slot + return self.task2slot[task_id], evicted_task + + def remove_evicted_slots_in_cpp(self, peft_cache_manager: PeftCacheManager): + """ + Validate slots by removing tasks that are not cached in PeftCacheManager. + """ + for task_id in self.slot2task: + if task_id is not None: + if not peft_cache_manager.is_task_cached_device(task_id): + self.remove_task(task_id) + + def update_slots( + self, requests: RequestList, peft_cache_manager: PeftCacheManager + ) -> list[int]: + """ + Get slot mapping for all requests in a scheduled batch. 
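+        For example, with max_num_adapters=2 and requests whose lora_task_ids are
+        [7, None, 7, 9], the returned slot ids (starting from empty slots) are
+        [0, 2, 0, 1], where slot_id == max_num_adapters marks base-model requests.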
+ + Args: + scheduled_requests: The scheduled requests for the current batch + + Returns: + Dict mapping request_id to slot_id, with slot_id=max_num_adapters for base model requests + """ + # remove task evicted in PeftCacheManager in C++ + self.remove_evicted_slots_in_cpp(peft_cache_manager) + + # check if total number of unique tasks in the requests is not larger than max_num_adapters + tasks = [request.lora_task_id for request in requests] + unique_tasks = {t for t in tasks if t is not None} + assert len(unique_tasks) <= self.max_num_adapters, ( + f"Batch with more unique LoRA adapters ({len(unique_tasks)}) than max_num_adapters={self.max_num_adapters} " + "is not supported" + ) + + # assign slots to tasks + for i, task in enumerate(tasks): + if task is None: + tasks[i] = self.max_num_adapters + else: + tasks[i], evicted_task = self.get_or_assign_task(task) + + return tasks + + def get_slot_to_task_mapping(self) -> tuple[Optional[int], ...]: + """ + Get current slot to task mapping. + + Returns: + Tuple mapping slot_id to task_id (or None if slot is empty) + """ + return tuple(self.slot2task) + + def has_slots_changed(self) -> bool: + """Check if slot assignments have changed since last check.""" + return self.slots_changed + + def reset_slots_changed(self): + """Reset the slots_changed flag.""" + self.slots_changed = False diff --git a/tensorrt_llm/_torch/peft/lora/cuda_graph_lora_manager.py b/tensorrt_llm/_torch/peft/lora/cuda_graph_lora_manager.py new file mode 100644 index 0000000000..302e401914 --- /dev/null +++ b/tensorrt_llm/_torch/peft/lora/cuda_graph_lora_manager.py @@ -0,0 +1,175 @@ +from typing import Dict, Optional + +import torch + +from ...._utils import nvtx_range +from ....logger import logger +from ....lora_manager import LoraManager, LoraModelConfig +from ...attention_backend.interface import AttentionMetadata +from ...pyexecutor.resource_manager import PeftCacheManager +from ...pyexecutor.scheduler import ScheduledRequests +from .adapter_slot_manager import AdapterSlotManager +from .cuda_graph_lora_params import CudaGraphLoraParams +from .layer import LoraLayer + + +class CudaGraphLoraManager: + """ + Manager that coordinates adapter slots and CUDA Graph compatible LoRA parameters. + + This class bridges the gap between the current LoRA implementation and the new + CUDA Graph compatible design by managing adapter slots and preparing persistent + device tensors for group GEMM operations. + """ + + def __init__( + self, + max_lora_size: int, + max_batch_size: int, + max_lora_rank: int, + model: torch.nn.Module, + lora_model_config: Optional[LoraModelConfig], + device: str = "cuda", + ): + """ + Initialize the CUDA Graph LoRA manager. 
+ + Args: + max_lora_size: Maximum number of LoRA adapters that can be active + max_batch_size: Maximum batch size for CUDA graphs + max_lora_rank: Maximum LoRA rank across all layers + model: Model to get layerwise LoRA info + lora_model_config: LoRA model configuration + device: Device to allocate tensors on + """ + self.max_lora_size = max_lora_size + self.max_batch_size = max_batch_size + self.max_lora_rank = max_lora_rank + self.device = device + + self.adapter_slot_manager = AdapterSlotManager(max_lora_size) + self.lora_model_config = lora_model_config + lora_target_modules = lora_model_config.lora_target_modules + self.target_modules_ids: Optional[tuple[int, ...]] = ( + tuple(map(LoraManager.LORA_MODULE_IDS.__getitem__, lora_target_modules)) + if bool(lora_target_modules) + else None + ) + if not self.target_modules_ids: + logger.debug( + "No LoRA target modules provided in LoRA config, using all modules in PyTorch Module!" + ) + + # Single CudaGraphLoraParams instance for all batch sizes. Its shape will be allocated to accommodate the max + # batch size. + self.cuda_graph_lora_params: Optional[CudaGraphLoraParams] = None + self.layer_info: ( + Dict[CudaGraphLoraParams.LoraLayerKey, CudaGraphLoraParams.LoraLayerInfo] | None + ) = None + # Initialize layer_info from model. + self._initialize_from_model(model) + self.cuda_graph_lora_params = CudaGraphLoraParams( + max_batch_size=self.max_batch_size, + max_lora_size=self.max_lora_size, + max_rank=self.max_lora_rank, + layer_info=self.layer_info, + device=self.device, + ) + + def _initialize_from_model(self, model: torch.nn.Module): + """ + Initialize LoRALayerInfo from model. + """ + self.layer_info = dict() + + def get_layer_idx( + model: torch.nn.Module, lora_module: LoraLayer, lora_module_name: str + ) -> Optional[int]: + """Find the layer index of the given LoRA module in the model.""" + module = lora_module + name = lora_module_name + while module is not None and ( + (not hasattr(module, "layer_idx")) or module.layer_idx is None + ): + name = name.rsplit(".", 1) + name = name[0] if len(name) > 1 else None + if name is not None: + module = model.get_submodule(name) + else: + module = None + if hasattr(module, "layer_idx") and module.layer_idx is not None: + return module.layer_idx + return None + + # Ignore LoRA layers without at least one of the target modules. + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + layer_idx = get_layer_idx(model, module, name) + # if target_modules_ids is None, by default enable all modules + if self.target_modules_ids and not any( + module_id in self.target_modules_ids for module_id in module.lora_module_types + ): + logger.debug(f"Layer {name} does not have any of the target modules, skipping") + continue + layer_key = CudaGraphLoraParams.LoraLayerKey( + layer_idx=layer_idx, module_ids=tuple(module.lora_module_types) + ) + assert layer_key not in self.layer_info, f"Layer {layer_key} already exists" + + self.layer_info[layer_key] = CudaGraphLoraParams.LoraLayerInfo( + module_num=len(module.lora_module_types), + output_sizes=module.output_hidden_sizes, + ) + + @nvtx_range("prepare_cuda_graph_lora_params") + def prepare_cuda_graph_lora_params( + self, + scheduled_requests: "ScheduledRequests", + attn_metadata: "AttentionMetadata", + peft_cache_manager: PeftCacheManager, + ) -> Optional[Dict]: + """ + Prepare LoRA parameters from scheduled requests. 
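+        Consumes the batch PEFT table from the cache manager, updates the LRU slot
+        assignments for the batch, refreshes the per-slot weight pointers only when
+        slot assignments changed, and recomputes the per-slot counts and offsets.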
+ + Args: + scheduled_requests: The scheduled requests for the current batch + attn_metadata: Attention metadata containing batch information + peft_table: PEFT table from cache manager mapping task_id to layer-module-configs + + Returns: + LoRA parameters dictionary. + """ + assert len(scheduled_requests.context_requests) == 0, ( + "Context requests are not supported with LoRA CUDA Graph path. " + f"Have {len(scheduled_requests.context_requests)} context requests" + ) + request_list = scheduled_requests.generation_requests + + peft_table = peft_cache_manager.get_and_reset_batch_peft_table() + + # Update slot manager and get slot assignments for this batch + request_slot_ids = self.adapter_slot_manager.update_slots(request_list, peft_cache_manager) + + cuda_graph_lora_params = self.cuda_graph_lora_params + cuda_graph_lora_params.update_sorted_indices(request_slot_ids) + + # Get current slot to task mapping + slot2task = self.adapter_slot_manager.get_slot_to_task_mapping() + + # Update weight pointers if slot assignments changed + if self.adapter_slot_manager.has_slots_changed(): + cuda_graph_lora_params.update_weight_pointers(peft_table, slot2task) + self.adapter_slot_manager.reset_slots_changed() + + # Update GEMM sizes and prefix sums using batch + cuda_graph_lora_params.update_slots_params(batch_slot_ids=request_slot_ids) + + lora_params = { + "cuda_graph_params": cuda_graph_lora_params, + "host_request_types": attn_metadata.host_request_types, + "prompt_lens_cpu": attn_metadata.prompt_lens_cpu, + "num_seqs": attn_metadata.num_seqs, + "use_cuda_graph_mode": True, # Flag to indicate new mode + } + + return lora_params diff --git a/tensorrt_llm/_torch/peft/lora/cuda_graph_lora_params.py b/tensorrt_llm/_torch/peft/lora/cuda_graph_lora_params.py new file mode 100644 index 0000000000..16c3d5d3d8 --- /dev/null +++ b/tensorrt_llm/_torch/peft/lora/cuda_graph_lora_params.py @@ -0,0 +1,341 @@ +from collections import namedtuple +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch + + +@dataclass +class LoraLayerParams: + """ + Parameters for a single LoRA layer. + All tensors are persistent device tensors that can be updated outside of graph replay. + """ + + # Weight pointers + # Shape: [layer_module_num, max_lora_size] + d_b_ptrs: torch.Tensor # Lora_in weight pointers + d_b_prime_ptrs: torch.Tensor # Lora_out weight pointers + h_b_ptrs: torch.Tensor # Lora_in weight pointers in host + h_b_prime_ptrs: torch.Tensor # Lora_out weight pointers in host + + d_output_sizes: torch.Tensor + d_output_sizes_offset: torch.Tensor + h_output_sizes: torch.Tensor + h_output_sizes_offset: torch.Tensor + + +class CudaGraphLoraParams: + """ + CUDA Graph compatible LoRA parameters for all layers and batch management. + + This structure maintains persistent device tensors that can be updated outside + of CUDA Graph replay to support different LoRA combinations per batch. + """ + + LoraLayerKey = namedtuple("LoraLayerKey", ["layer_idx", "module_ids"]) + + PTR_DTYPE = torch.int64 + LD_DTYPE = torch.int64 + SIZES_DTYPE = torch.int32 + + @dataclass + class LoraLayerInfo: + module_num: int = 0 + output_sizes: List[int] | torch.Tensor | None = None + input_hidden_size: int = 0 + + def is_enabled(self) -> bool: + return self.input_hidden_size > 0 + + def __init__( + self, + max_batch_size: int, + max_lora_size: int, + max_rank: int, + layer_info: Dict[LoraLayerKey, LoraLayerInfo], + device: str = "cuda", + ): + """ + Initialize CUDA Graph compatible LoRA parameters. 
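+        All buffers allocated here are persistent device tensors with pinned host
+        mirrors; they are refreshed with non-blocking copies between replays, so a
+        captured graph always reads the current slot counts, offsets, ranks and
+        weight pointers without any reallocation.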
+ + Args: + max_batch_size: Maximum batch size for this graph + max_lora_size: Maximum number of LoRA adapters + max_rank: Maximum rank for all layers + layers_info: Layer information for each layer + device: Device to allocate tensors on + dtype: Data type for size and offset tensors + """ + self.max_batch_size = max_batch_size + self.max_lora_size = max_lora_size + self.max_rank = max_rank + self.layer_info = layer_info + self.layer_module2key = self._calculate_layer_module2key() + self.device = device + + self.layer_params: Dict[self.LoraLayerKey, LoraLayerParams] = dict() + + # sorted indices using slot ids as keys, mainly to group requests with the same slot id together in a batch + self.sorted_ids = torch.zeros(max_batch_size, dtype=torch.int64, device=device) + self.sorted_ids_host = torch.zeros_like(self.sorted_ids, device="cpu", pin_memory=True) + + # persistent values for gen-only batch with cuda graph + self.persistent_sorted_ids = self.sorted_ids + + self.slot_ids = torch.zeros(max_batch_size, dtype=torch.int64, device=device) + + self.slot_counts = torch.zeros(max_lora_size, dtype=torch.int32, device=device) + self.slot_counts_host = torch.zeros_like(self.slot_counts, device="cpu", pin_memory=True) + self.slot_offsets_full = torch.zeros(max_lora_size + 1, dtype=torch.int64, device=device) + self.slot_offsets = self.slot_offsets_full[:-1] + self.slot_offsets_full_host = torch.zeros_like( + self.slot_offsets_full, device="cpu", pin_memory=True + ) + + self.slot_ranks = torch.zeros(max_lora_size, dtype=torch.int32, device=device) + self.slot_ranks_host = torch.zeros_like(self.slot_ranks, device="cpu", pin_memory=True) + + for key, info in self.layer_info.items(): + assert ( + info.module_num > 0 + and info.output_sizes is not None + and len(info.output_sizes) == info.module_num + ) + # Allocate layer parameters + self.layer_params[key] = self._allocate_layer_params( + key, info.module_num, info.output_sizes + ) + + def _calculate_layer_module2key(self) -> Dict[Tuple[int, int], LoraLayerKey]: + layer_module2key = dict() + for key in self.layer_info.keys(): + layer_id = key.layer_idx + module_ids = key.module_ids + for module_id in module_ids: + layer_module2key[(layer_id, module_id)] = key + return layer_module2key + + def _allocate_layer_params( + self, key: LoraLayerKey, layer_module_num: int, module_output_sizes: torch.Tensor + ) -> LoraLayerParams: + """ + Create LoraLayerParams for a specific layer. 
+ + Args: + key: Key of the layer + layer_module_num: Number of modules in this layer + module_output_sizes: Output sizes for each module in this layer + + Returns: + LoraLayerParams for the specified layer + """ + # GEMM parameter tensors only need max_lora_size (no dummy slot for base model) + # Base model requests are handled separately and don't participate in GEMM operations + shape_2d = (layer_module_num, self.max_lora_size) + + output_hidden_sizes = torch.tensor(module_output_sizes, dtype=self.SIZES_DTYPE) + output_hidden_sizes_device = output_hidden_sizes.to(device="cuda") + + output_sizes_offset = self.get_offset_from_counts(output_hidden_sizes).to( + dtype=self.PTR_DTYPE + ) # [num_layer_modules] + output_sizes_offset_device = output_sizes_offset.to(device="cuda") + + return LoraLayerParams( + # Weight pointers - managed by PEFT cache manager + d_b_ptrs=torch.zeros(shape_2d, dtype=torch.int64, device=self.device), + d_b_prime_ptrs=torch.zeros(shape_2d, dtype=torch.int64, device=self.device), + h_b_ptrs=torch.zeros(shape_2d, dtype=torch.int64, pin_memory=True), + h_b_prime_ptrs=torch.zeros(shape_2d, dtype=torch.int64, pin_memory=True), + d_output_sizes=output_hidden_sizes_device, + d_output_sizes_offset=output_sizes_offset_device, + h_output_sizes=output_hidden_sizes, + h_output_sizes_offset=output_sizes_offset, + ) + + @staticmethod + def get_sorted_indices(slot_ids: List[int]) -> torch.Tensor: + """ + Get sorted indices for the given slot IDs. + """ + slot_ids = torch.tensor(slot_ids, dtype=torch.int64) + + # Compute sorted indices for gather/scatter operations + sorted_slot_ids, sorted_indices = torch.sort(slot_ids, stable=True) + return sorted_indices + + def update_sorted_indices(self, slot_ids: List[int]): + """ + Update slot IDs for the current batch and compute sorted indices. + + Args: + slot_ids: List of slot IDs for each token in the batch + actual_batch_size: Actual batch size (may be less than max_batch_size) + """ + actual_batch_size = len(slot_ids) + assert actual_batch_size <= self.max_batch_size, ( + f"Actual batch size {actual_batch_size} exceeds max {self.max_batch_size}" + ) + sorted_indices = self.get_sorted_indices(slot_ids) + + # Update sorted_ids tensor with the computed indices + assert actual_batch_size <= self.max_batch_size, ( + f"CudaGraphLoraParams: Actual batch size {actual_batch_size} exceeds max {self.max_batch_size}!" + ) + if actual_batch_size <= self.max_batch_size: + # if can fit in persistent, use it + self.sorted_ids = self.persistent_sorted_ids + sorted_ids_host = self.sorted_ids_host[:actual_batch_size] + sorted_ids_host.copy_(sorted_indices) + self.sorted_ids[:actual_batch_size].copy_(sorted_ids_host, non_blocking=True) + else: + # otherwise not an gen-only batch, use new allocated sorted_ids + self.sorted_ids = sorted_indices.to(device=self.device) + + def update_weight_pointers( + self, peft_table: Dict[int, List], slot_to_task_mapping: tuple[Optional[int], ...] + ): + """ + Update weight pointers from PEFT cache manager. 
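+        Slots whose task is unchanged keep their existing rank and pointers. Empty
+        slots and newly assigned slots are zeroed first; a zero pointer later
+        disables the corresponding (module, slot) GEMM, since the problem sizes are
+        multiplied by (b_ptrs != 0) when the grouped-GEMM parameters are built.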
+ + Args: + peft_table: PEFT table from cache manager containing weight pointers, map task id to list of layer + module configs + slot_to_task_mapping: Mapping from slot_id to task_id, tuple of None for empty slots + """ + + # get slot ranks + # assume ranks are the same for a given slot, + # input_hidden_size are the same within a layer + # output sizes are the same for all slots with the same module + def zero_out_weight_pointers(slot_id: int): + """ + Zero out all weight pointers for a given slot_id for all layers + """ + for layer_param in self.layer_params.values(): + layer_param.h_b_ptrs[:, slot_id] = 0 + layer_param.h_b_prime_ptrs[:, slot_id] = 0 + + for slot_id in range(self.max_lora_size): + task_id = slot_to_task_mapping[slot_id] + if task_id is None: # empty slot + self.slot_ranks_host[slot_id] = 0 + zero_out_weight_pointers(slot_id) + elif ( + task_id not in peft_table + ): # task has not changed in the slot, retain old rank / weight pointers + continue + else: # task might have changed in the slot, update its rank + task_configs = peft_table[task_id] + config = task_configs[0] # assume all layerModuleConfigs have the same rank + self.slot_ranks_host[slot_id] = config.adapter_size + + zero_out_weight_pointers( + slot_id + ) # in case new task in slot do not have LoRA adapter for some module in some layer + for config in task_configs: + layer_id = config.layer_id + module_id = config.module_id + key = self.layer_module2key[(layer_id, module_id)] + layer_param = self.layer_params[key] + local_module_id = key.module_ids.index(module_id) + + assert key in self.layer_params, ( + f"Layer {layer_id} not found in layer_params, assumption that all LoRA has their adapters on " + "the same layers is broken" + ) + + # Validate LoRA rank + rank = config.adapter_size + assert rank <= self.max_rank, ( + f"LoRA rank {rank} in layer {layer_id} exceeds configured max_rank {self.max_rank}. " + ) + + layer_param.h_b_ptrs[local_module_id, slot_id] = config.weights_in_pointer + layer_param.h_b_prime_ptrs[local_module_id, slot_id] = ( + config.weights_out_pointer + ) + + self.slot_ranks.copy_(self.slot_ranks_host, non_blocking=True) + + for layer_param in self.layer_params.values(): + layer_param.d_b_ptrs.copy_(layer_param.h_b_ptrs, non_blocking=True) + layer_param.d_b_prime_ptrs.copy_(layer_param.h_b_prime_ptrs, non_blocking=True) + + @staticmethod + def get_offset_from_counts( + counts: torch.Tensor, full: bool = False, out: torch.Tensor = None + ) -> torch.Tensor: + if out is None: + if full: + offset = torch.empty(counts.shape[0] + 1, dtype=torch.int64, device=counts.device) + else: + offset = torch.empty(counts.shape[0], dtype=torch.int64, device=counts.device) + else: + assert (full and out.shape[0] == counts.shape[0] + 1) or ( + (not full) and out.shape[0] == counts.shape[0] + ) + offset = out + + offset[0] = 0 + + if full: + offset[1:] = counts + else: + offset[1:] = counts[:-1] + offset[1:].cumsum_(dim=0) + return offset + + @staticmethod + def get_slot_counts(batch_slot_ids: List[int], max_lora_size: int) -> torch.Tensor: + """ + Get the number of tokens for each slot_id in the batch. + """ + slot_counts = torch.bincount( + torch.tensor(batch_slot_ids, dtype=torch.int32), minlength=max_lora_size + ) + assert slot_counts.size(0) <= max_lora_size + 1 + slot_counts = slot_counts[:max_lora_size] + return slot_counts + + def update_slots_params(self, batch_slot_ids: List[int]): + """ + Update GEMM sizes and buffer offsets based on current batch composition. 
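+        For example, with max_lora_size=4 and batch_slot_ids=[0, 3, 3, 4] (4 being
+        the base-model sentinel), slot_counts becomes [1, 0, 0, 2] and
+        slot_offsets_full becomes [0, 1, 1, 1, 3].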
+ + Args: + batch_slot_ids: Slot IDs for each token in the batch + """ + slot_counts = self.get_slot_counts(batch_slot_ids, self.max_lora_size) + self.slot_counts_host.copy_(slot_counts) + self.get_offset_from_counts(slot_counts, full=True, out=self.slot_offsets_full_host) + self.slot_counts.copy_(self.slot_counts_host, non_blocking=True) + self.slot_offsets_full.copy_(self.slot_offsets_full_host, non_blocking=True) + + def get_problem_count(self, layer_key: LoraLayerKey) -> int: + """ + Get the number of GEMM problems for a layer. + + Args: + layer_key: Key of the layer + + Returns: + Number of GEMM problems (layer_module_num * max_lora_size) + Returns 0 if layer has no LoRA modules + Note: Only actual LoRA slots are counted, not the dummy base model slot + """ + if layer_key not in self.layer_params: + return 0 # Layer has no LoRA modules + return self.layer_info[layer_key].module_num * self.max_lora_size + + def get_layer_params(self, layer_key: LoraLayerKey) -> Optional[LoraLayerParams]: + """ + Get LoRA parameters for a specific layer. + + Args: + layer_key: Key of the layer + + Returns: + LoraLayerParams for the specified layer, or None if layer has no LoRA modules + """ + return self.layer_params.get(layer_key) diff --git a/tensorrt_llm/_torch/peft/lora/layer.py b/tensorrt_llm/_torch/peft/lora/layer.py index 2c8bc5e2f5..1312f7e37a 100644 --- a/tensorrt_llm/_torch/peft/lora/layer.py +++ b/tensorrt_llm/_torch/peft/lora/layer.py @@ -1,8 +1,48 @@ +from dataclasses import dataclass from enum import IntEnum from typing import Dict, List, Optional import torch +from .cuda_graph_lora_params import CudaGraphLoraParams + + +@dataclass +class GroupedGemmParamsOutput: + in_sizes: Optional[torch.Tensor] = None + out_sizes: Optional[torch.Tensor] = None + a_offset: Optional[torch.Tensor] = None + d_offset: Optional[torch.Tensor] = None + d_prime_offset: Optional[torch.Tensor] = None + lda: Optional[torch.Tensor] = None + ldb: Optional[torch.Tensor] = None + ldd: Optional[torch.Tensor] = None + ldb_prime: Optional[torch.Tensor] = None + ldd_prime: Optional[torch.Tensor] = None + splitk_offsets: Optional[torch.Tensor] = None + reordered_input: Optional[torch.Tensor] = None + + +@dataclass +class GroupedGemmParamsInput: + x: torch.Tensor + output_buffer: torch.Tensor + intermediate_buffer: torch.Tensor + max_lora_size: int + max_rank: int + slot_counts: torch.Tensor + slot_ranks: torch.Tensor + slot_offsets_full: torch.Tensor + b_ptrs: torch.Tensor + b_prime_ptrs: torch.Tensor + sorted_ids: torch.Tensor + output_hidden_sizes: torch.Tensor + output_sizes_offset: torch.Tensor + + @property + def slot_offsets(self): + return self.slot_offsets_full[:-1] + class LoraModuleType(IntEnum): """Enum class representing different types of modules that can have LoRA adapters. 
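For reference, the per-module computation that the CUDA Graph path below implements with grouped GEMMs is equivalent to the following plain-PyTorch sketch. The per-slot weights lora_a / lora_b and their HF-style shapes ([rank_s, hidden] and [out_features, rank_s]) are illustrative assumptions; the real kernels read raw device pointers from the PEFT cache and launch one grouped-GEMM pair covering all (module, slot) problems.

    import torch

    def lora_slot_grouped_reference(x, slot_ids, lora_a, lora_b, max_lora_size, out_features):
        # x: [batch, hidden]; slot_ids[i] == max_lora_size marks a base-model row (LoRA output stays 0).
        out = torch.zeros(x.shape[0], out_features, dtype=x.dtype, device=x.device)
        ids = torch.tensor(slot_ids)
        sorted_ids = torch.argsort(ids, stable=True)       # rows grouped by slot id
        reordered = x[sorted_ids]
        counts = torch.bincount(ids, minlength=max_lora_size + 1)
        start = 0
        for s in range(max_lora_size):                     # one GEMM pair per active slot
            n = int(counts[s])
            if n > 0 and lora_a[s] is not None:
                inter = reordered[start:start + n] @ lora_a[s].T        # [n, rank_s]
                out[sorted_ids[start:start + n]] = inter @ lora_b[s].T  # restore row order
            start += n
        return out

Because rows sharing a slot form one contiguous segment of the reordered input, each (module, slot) pair maps to a single GEMM problem of shape (counts[s], rank_s, hidden) for lora_in and (counts[s], out_features, rank_s) for lora_out, which is what the size tensors built in the code below encode.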
@@ -100,57 +140,385 @@ class LoraLayer(torch.nn.Module): ) -> Optional[torch.Tensor]: if bool(lora_params): - lora_ranks = [] - lora_weight_pointers = [] - active_lora_module_ids = [] - for module_idx in self.lora_module_types: - module_idx = int(module_idx) - if module_idx in lora_params[layer_idx]: - active_lora_module_ids.append(module_idx) - lora_ranks.append( - lora_params[layer_idx][module_idx]['adapter_size']) - lora_weight_pointers.append( - lora_params[layer_idx][module_idx]['weight_pointers']) + # Check if we're using CUDA Graph mode + use_cuda_graph_mode = lora_params.get('use_cuda_graph_mode', False) - num_seqs = lora_params['num_seqs'] - - if len(active_lora_module_ids) == 0: - return None + if use_cuda_graph_mode: + return self._forward_cuda_graph_mode(x, lora_params, layer_idx) else: - lora_outputs = torch.ops.trtllm.lora_grouped_gemm( - x, - lora_params['host_request_types'][:num_seqs], - lora_ranks, - lora_weight_pointers, - lora_params['prompt_lens_cpu'][:num_seqs], - self.output_hidden_sizes, - False, # transA - True, # transB - max([r.max() for r in lora_ranks]), - 0, - True, # TODO smor- should be lora_params["remove_input_padding"], support in loraOp as well - ) - if isinstance(lora_outputs, torch.Tensor): - return lora_outputs - else: - # For multiple LoRA modules, some might not be executed in grouped gemm. - # For those modules not executed, we create zero tensors with matching dimensions. - # Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order. - lora_output = [] - for module_idx in self.lora_module_types: - if int(module_idx) in active_lora_module_ids: - lora_output.append(lora_outputs.pop(0)) - else: - lora_output.append( - torch.zeros(list(x.shape[:-1]) + [ - self.output_hidden_sizes[ - self.lora_module_types.index( - module_idx)] - ], - dtype=x.dtype, - device=x.device)) - lora_output = torch.cat(lora_output, dim=-1) - return lora_output - + return self._forward_eager_mode(x, lora_params, layer_idx) else: return None + + def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput): + device = input.x.device + bs, input_hidden_size = input.x.shape + shape_2d = (len(self.lora_module_types), input.max_lora_size + ) # [num_layer_modules, max_lora_size] + shape_3d = shape_2d + (3, ) + sum_out_sizes = sum(self.output_hidden_sizes) + + input.output_buffer.fill_(0) + input.intermediate_buffer.fill_(0) + + # reorder input + reordered_input = torch.index_select(input.x, 0, input.sorted_ids[:bs]) + + # a [bs, hidden] + lda = torch.full(shape_2d, + input_hidden_size, + dtype=CudaGraphLoraParams.LD_DTYPE, + device=device) + + # b [input_hidden_size, lora_rank] + ldb = lda + + # a_prime / d [num_layer_modules, bs, max_rank] + ldd = torch.full(shape_2d, + input.max_rank, + dtype=CudaGraphLoraParams.LD_DTYPE, + device=device) + + # b_prime [lora_rank, module_output_size] + ldb_prime = input.slot_ranks.unsqueeze(0).to( + dtype=CudaGraphLoraParams.LD_DTYPE).repeat(shape_2d[0], 1) + + # d_prime [bs, sum_of_each_module_output_sizes] + ldd_prime = torch.full(shape_2d, + sum_out_sizes, + dtype=CudaGraphLoraParams.LD_DTYPE, + device=device) + + # reordered a [bs, hidden], each module has the same offset + a_offset = input.slot_offsets * input_hidden_size + a_offset = a_offset.unsqueeze(0).repeat(shape_2d[0], 1) + + # d [num_layer_modules, bs, max_rank] + d_offset = (input.slot_offsets.unsqueeze(0) + torch.arange( + shape_2d[0], device=device, dtype=CudaGraphLoraParams.PTR_DTYPE). 
+ unsqueeze(1) * bs) * input.max_rank + + # d' [bs, sum_of_each_module_output_sizes] + bs_offset = input.slot_offsets.unsqueeze(0) # [1, max_lora_size] + bs_offset = bs_offset * sum_out_sizes + out_offset = input.output_sizes_offset.unsqueeze( + 1) # [num_layer_modules, 1] + d_prime_offset = bs_offset + out_offset + + # sizes + in_sizes = torch.empty(shape_3d, + dtype=CudaGraphLoraParams.SIZES_DTYPE, + device=device) + out_sizes = torch.empty_like(in_sizes) + + slot_counts = input.slot_counts.unsqueeze(0) # [1, max_lora_size] + ranks = input.slot_ranks.unsqueeze(0) # [1, max_lora_size] + output_hidden_sizes = input.output_hidden_sizes.unsqueeze( + 1) # [num_layer_modules, 1] + + in_sizes[:, :, 0] = slot_counts + in_sizes[:, :, 1] = ranks + in_sizes[:, :, 2] = input_hidden_size + + out_sizes[:, :, 0] = slot_counts + out_sizes[:, :, 1] = output_hidden_sizes + out_sizes[:, :, 2] = ranks + + # disable unused modules / lora with ptr being zeros + in_sizes *= (input.b_ptrs != 0).unsqueeze(-1) + out_sizes *= (input.b_prime_ptrs != 0).unsqueeze(-1) + + # splitk_offsets: [num_layer_modules, max_lora_size] + # splitk offtsets (m * n) for the first grouped gemm with (m, n, k) = (slot_counts, slot_ranks, input_hidden_size) + splitk_offsets = torch.zeros(shape_2d, + dtype=CudaGraphLoraParams.LD_DTYPE, + device=device) + + splitk_offsets.view(-1)[1:] = in_sizes.view(-1, 3)[:-1, 0] # = M + splitk_offsets.view(-1)[1:] *= in_sizes.view(-1, 3)[:-1, 1] # *= N + splitk_offsets.view(-1).cumsum_(dim=0) + + # add base addresses to offset tensors on GPU + dtype_element_size = input.x.element_size() + a_offset *= dtype_element_size + a_offset += reordered_input.data_ptr() + + d_offset *= dtype_element_size + d_offset += input.intermediate_buffer.data_ptr() + + d_prime_offset *= dtype_element_size + d_prime_offset += input.output_buffer.data_ptr() + + return GroupedGemmParamsOutput(in_sizes=in_sizes, + out_sizes=out_sizes, + a_offset=a_offset, + d_offset=d_offset, + d_prime_offset=d_prime_offset, + lda=lda, + ldb=ldb, + ldd=ldd, + ldb_prime=ldb_prime, + ldd_prime=ldd_prime, + splitk_offsets=splitk_offsets, + reordered_input=reordered_input) + + def _prepare_grouped_gemm_buffers_fused(self, + input: GroupedGemmParamsInput): + device = input.x.device + bs, input_hidden_size = input.x.shape + shape_2d = (len(self.lora_module_types), input.max_lora_size + ) # [num_layer_modules, max_lora_size] + shape_3d = shape_2d + (3, ) + sum_out_sizes = sum(self.output_hidden_sizes) + + in_sizes = torch.empty(shape_3d, + dtype=CudaGraphLoraParams.SIZES_DTYPE, + device=device) + out_sizes = torch.empty_like(in_sizes) + a_offset = torch.empty(shape_2d, + dtype=CudaGraphLoraParams.PTR_DTYPE, + device=device) + d_offset = torch.empty_like(a_offset) + d_prime_offset = torch.empty_like(a_offset) + lda = torch.empty(shape_2d, + dtype=CudaGraphLoraParams.LD_DTYPE, + device=device) + ldb = lda + ldd = torch.empty_like(lda) + ldb_prime = torch.empty_like(lda) + ldd_prime = torch.empty_like(lda) + splitk_offsets = torch.empty(shape_2d, + dtype=CudaGraphLoraParams.LD_DTYPE, + device=device) + reordered_input = torch.empty_like(input.x) + torch.ops.trtllm.lora_group_gemm_param_fill_row_reorder_fusion( + # output parameters + in_sizes, + out_sizes, + a_offset, + d_offset, + d_prime_offset, + lda, + ldd, + ldb_prime, + ldd_prime, + splitk_offsets, + reordered_input, + + # input parameters + input.max_lora_size, + input.max_rank, + sum_out_sizes, + input_hidden_size, + bs, # batch_size + input.slot_counts, + input.slot_ranks, + input.slot_offsets, 
+ input.output_hidden_sizes, + input.output_sizes_offset, + input.b_ptrs, + input.b_prime_ptrs, + input.x, + input.sorted_ids[:bs], + input.intermediate_buffer, + input.output_buffer, + input.x.dtype) + + return GroupedGemmParamsOutput(in_sizes=in_sizes, + out_sizes=out_sizes, + a_offset=a_offset, + d_offset=d_offset, + d_prime_offset=d_prime_offset, + lda=lda, + ldb=ldb, + ldd=ldd, + ldb_prime=ldb_prime, + ldd_prime=ldd_prime, + splitk_offsets=splitk_offsets, + reordered_input=reordered_input) + + def _prepare_max_sizes_cpu(self, + cuda_graph_lora_params: CudaGraphLoraParams, + layer_key: CudaGraphLoraParams.LoraLayerKey, + bs: int, input_hidden_size: int): + layer_params = cuda_graph_lora_params.get_layer_params(layer_key) + shape_2d = (len(self.lora_module_types), + cuda_graph_lora_params.max_lora_size + ) # [num_layer_modules, max_lora_size] + shape_3d = shape_2d + (3, ) + # dummy max sizes, on CPU + host_max_in_sizes = torch.empty( + shape_3d, dtype=CudaGraphLoraParams.SIZES_DTYPE + ) # m: batch_size, n: max_lora_rank, k: input_hidden_size + host_max_out_sizes = torch.empty_like( + host_max_in_sizes + ) # m: batch_size, n: max_output_hidden_size, k: max_lora_rank + host_max_in_sizes[:, :, 0] = bs + host_max_in_sizes[:, :, 1] = cuda_graph_lora_params.max_rank + host_max_in_sizes[:, :, 2] = input_hidden_size + + host_max_out_sizes[:, :, 0] = bs + host_max_out_sizes[:, :, 1] = layer_params.h_output_sizes.unsqueeze(1) + host_max_out_sizes[:, :, 2] = cuda_graph_lora_params.max_rank + + return host_max_in_sizes, host_max_out_sizes + + def _forward_cuda_graph_mode( + self, + x: torch.Tensor, + lora_params: Dict, + layer_idx: int, + ) -> Optional[torch.Tensor]: + """ + Forward pass using CUDA Graph compatible LoRA parameters. + + Args: + x: Input tensor + lora_params: CUDA Graph compatible LoRA parameters + layer_idx: Current layer index + + Returns: + LoRA output tensor or None + """ + + cuda_graph_params: CudaGraphLoraParams = lora_params.get( + 'cuda_graph_params') + # Get layer-specific parameters + layer_key = CudaGraphLoraParams.LoraLayerKey( + layer_idx=layer_idx, module_ids=tuple(self.lora_module_types)) + + if not cuda_graph_params or not cuda_graph_params.layer_info or layer_key not in cuda_graph_params.layer_info: + return None + + layer_params = cuda_graph_params.get_layer_params(layer_key) + + # Skip layers that don't have LoRA modules + if layer_params is None: + return 0 # Pass-through for layers without LoRA modules + + batch_size, hidden_size = x.shape[0], x.shape[-1] + num_layer_modules = len(self.lora_module_types) + max_rank = cuda_graph_params.max_rank + total_output_size = sum(self.output_hidden_sizes) + min_kn = min( + hidden_size, 8, max_rank + ) # TODO: hardcode to 8 for now, for alignments in kernels, might have alignment error if rank is less than 8! 
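+        # Flow: allocate this layer's output / intermediate buffers, fill the
+        # grouped-GEMM parameters with the fused kernel (which also reorders rows
+        # by slot and zero-fills both buffers), run the split-K lora_in GEMM and
+        # the lora_out GEMM, then scatter rows back to the original order using
+        # sorted_ids.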
+ + output_buffer = torch.empty(batch_size, + total_output_size, + dtype=x.dtype, + device=x.device) + + host_max_in_sizes, host_max_out_sizes = self._prepare_max_sizes_cpu( + cuda_graph_params, layer_key, batch_size, hidden_size) + + # Intermediate buffer: [num_layer_modules, batch_size, max_rank] + intermediate_buffer = torch.empty( + [num_layer_modules, batch_size, max_rank], + dtype=x.dtype, + device=x.device) + + params_fill_input = GroupedGemmParamsInput( + x=x, + output_buffer=output_buffer, + intermediate_buffer=intermediate_buffer, + max_lora_size=cuda_graph_params.max_lora_size, + max_rank=cuda_graph_params.max_rank, + slot_counts=cuda_graph_params.slot_counts, + slot_ranks=cuda_graph_params.slot_ranks, + slot_offsets_full=cuda_graph_params.slot_offsets_full, + b_ptrs=layer_params.d_b_ptrs, + b_prime_ptrs=layer_params.d_b_prime_ptrs, + sorted_ids=cuda_graph_params.sorted_ids, + output_hidden_sizes=layer_params.d_output_sizes, + output_sizes_offset=layer_params.d_output_sizes_offset) + grouped_gemm_params = self._prepare_grouped_gemm_buffers_fused( + params_fill_input) + + torch.ops.trtllm.lora_grouped_gemm_cuda_graph( + grouped_gemm_params.in_sizes, grouped_gemm_params.out_sizes, + grouped_gemm_params.a_offset, layer_params.d_b_ptrs, + grouped_gemm_params.d_offset, layer_params.d_b_prime_ptrs, + grouped_gemm_params.d_prime_offset, + cuda_graph_params.get_problem_count(layer_key), + grouped_gemm_params.lda, grouped_gemm_params.ldb, + grouped_gemm_params.ldd, grouped_gemm_params.ldb_prime, + grouped_gemm_params.ldd_prime, host_max_in_sizes, + host_max_out_sizes, grouped_gemm_params.splitk_offsets, + grouped_gemm_params.reordered_input.dtype, min_kn) + + # TODO: move to kernel + restored_output = torch.zeros_like(output_buffer) + restored_output.index_copy_(0, + cuda_graph_params.sorted_ids[:batch_size], + output_buffer) + return restored_output + + def _forward_eager_mode( + self, + x: torch.Tensor, + lora_params: Dict, + layer_idx: int, + ) -> Optional[torch.Tensor]: + """ + Eager-mode forward pass using the original LoRA implementation. + + Args: + x: Input tensor + lora_params: LoRA parameters for eager mode + layer_idx: Current layer index + + Returns: + LoRA output tensor or None + """ + lora_ranks = [] + lora_weight_pointers = [] + active_lora_module_ids = [] + + for module_idx in self.lora_module_types: + module_idx = int(module_idx) + if module_idx in lora_params[layer_idx]: + active_lora_module_ids.append(module_idx) + lora_ranks.append( + lora_params[layer_idx][module_idx]['adapter_size']) + lora_weight_pointers.append( + lora_params[layer_idx][module_idx]['weight_pointers']) + + num_seqs = lora_params['num_seqs'] + + if len(active_lora_module_ids) == 0: + return None + else: + lora_outputs = torch.ops.trtllm.lora_grouped_gemm( + x, + lora_params['host_request_types'][:num_seqs], + lora_ranks, + lora_weight_pointers, + lora_params['prompt_lens_cpu'][:num_seqs], + self.output_hidden_sizes, + False, # transA + True, # transB + max([r.max() for r in lora_ranks]), + 0, + True, # TODO smor- should be lora_params["remove_input_padding"], support in loraOp as well + ) + if isinstance(lora_outputs, torch.Tensor): + return lora_outputs + else: + # For multiple LoRA modules, some might not be executed in grouped gemm. + # For those modules not executed, we create zero tensors with matching dimensions. + # Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order. 
+ lora_output = [] + for module_idx in self.lora_module_types: + if int(module_idx) in active_lora_module_ids: + lora_output.append(lora_outputs.pop(0)) + else: + lora_output.append( + torch.zeros(list(x.shape[:-1]) + [ + self.output_hidden_sizes[ + self.lora_module_types.index(module_idx)] + ], + dtype=x.dtype, + device=x.device)) + lora_output = torch.cat(lora_output, dim=-1) + return lora_output diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index b2d31786ef..898fd1575b 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -834,6 +834,8 @@ def create_py_executor_instance( lora_config.lora_target_modules, lora_config.trtllm_modules_to_hf_modules, lora_config.swap_gate_up_proj_lora_b_weight) + if isinstance(model_engine, PyTorchModelEngine): + model_engine._init_cuda_graph_lora_manager(lora_config) resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager( max_num_sequences) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 9145509d3b..62a386d585 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -533,9 +533,6 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): self.py_decoding_iter = 0 self.is_attention_dp_dummy = False self.is_cuda_graph_dummy = False - self.py_lora_task_layer_module_configs: list[ - tensorrt_llm.bindings.internal.runtime. - TaskLayerModuleConfig] | None = None self.py_kv_transfer_start_time = None self.py_kv_transfer_timed_out = False diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index a1890da391..488079011e 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -16,6 +16,7 @@ import torch._dynamo.config import tensorrt_llm.bindings.internal.userbuffers as ub from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc, torch_dtype_to_str, trace_func) +from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig from tensorrt_llm.inputs.multimodal import (MultimodalParams, MultimodalRuntimeData) from tensorrt_llm.inputs.registry import (create_input_processor, @@ -44,6 +45,7 @@ from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids from ..models.modeling_utils import DecoderModelForCausalLM from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer, MoeLoadBalancerIterContext) +from ..peft.lora.cuda_graph_lora_manager import CudaGraphLoraManager from ..speculative import (SpecMetadata, get_num_extra_kv_tokens, get_spec_metadata, update_spec_config_from_model_config) @@ -62,7 +64,8 @@ from .layerwise_nvtx_marker import LayerwiseNvtxMarker from .llm_request import LlmRequest, get_draft_token_length from .model_loader import ModelLoader, _construct_checkpoint_loader from .resource_manager import (BaseResourceManager, KVCacheManager, - ResourceManager, ResourceManagerType) + PeftCacheManager, ResourceManager, + ResourceManagerType) from .sampler import SampleStateTensors from .scheduler import ScheduledRequests @@ -449,6 +452,9 @@ class PyTorchModelEngine(ModelEngine): ) self.cuda_graph_runner = CUDAGraphRunner(cuda_graph_runner_config) + # Initialize CUDA Graph LoRA manager if LoRA is enabled + self.cuda_graph_lora_manager: Optional[CudaGraphLoraManager] = None + # Setup the local cache indirection buffer only once and reuse it. 
# This way it can also be used for CUDA graphs. if self.use_beam_search: @@ -493,6 +499,26 @@ class PyTorchModelEngine(ModelEngine): dtype=torch_dtype_to_str(self.model.config.torch_dtype), swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight) + def _init_cuda_graph_lora_manager(self, lora_config: LoraConfig): + """Initialize CUDA Graph LoRA manager with model configuration.""" + # Get model configuration + if self.cuda_graph_runner.enabled: + max_lora_size = lora_config.max_loras or 8 # Default fallback + max_batch_size = self.batch_size # Use engine's max batch size + + self.cuda_graph_lora_manager = CudaGraphLoraManager( + max_lora_size=max_lora_size, + max_batch_size=max_batch_size, + max_lora_rank=lora_config.max_lora_rank, + model=self.model, + lora_model_config=self.lora_model_config, + device='cuda') + + logger.info( + f"Initialized CUDA Graph LoRA manager, " + f"max {max_lora_size} adapters, max rank {lora_config.max_lora_rank}" + ) + def set_guided_decoder(self, guided_decoder: CapturableGuidedDecoder) -> bool: if hasattr(self.model, "set_guided_decoder"): @@ -1936,7 +1962,8 @@ class PyTorchModelEngine(ModelEngine): cache_indirection_buffer: Optional[torch.Tensor] = None, num_accepted_tokens_device: Optional[torch.Tensor] = None, req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None, - resource_manager: Optional[ResourceManager] = None): + resource_manager: Optional[ResourceManager] = None, + maybe_graph: bool = False): """ Prepare inputs for Pytorch Model. """ @@ -2673,8 +2700,10 @@ class PyTorchModelEngine(ModelEngine): attn_metadata.prepare() + peft_cache_manager = resource_manager and resource_manager.get_resource_manager( + ResourceManagerType.PEFT_CACHE_MANAGER) lora_params = self._get_lora_params_from_requests( - scheduled_requests, attn_metadata) + scheduled_requests, attn_metadata, peft_cache_manager, maybe_graph) attn_all_rank_num_tokens = self._get_all_rank_num_tokens(attn_metadata) padded_num_tokens, can_run_piecewise_cuda_graph, attn_all_rank_num_tokens = self._get_padding_params( @@ -3142,10 +3171,41 @@ class PyTorchModelEngine(ModelEngine): 'inputs_embeds': None }, gather_ids if is_spec_decode else None - def _get_lora_params_from_requests(self, - scheduled_requests: ScheduledRequests, - attn_metadata: AttentionMetadata): + def _get_lora_params_from_requests( + self, + scheduled_requests: ScheduledRequests, + attn_metadata: AttentionMetadata, + peft_cache_manager: Optional[PeftCacheManager] = None, + maybe_graph: bool = False): ''' + Get LoRA parameters from scheduled requests. + + Uses CUDA Graph compatible mode in decode only batch, otherwise falls back to eager mode. 
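+        The CUDA Graph path is taken only when the CUDA Graph LoRA manager has been
+        initialized and the batch may be captured or replayed (maybe_graph=True);
+        otherwise the task-keyed PEFT table is expanded into per-layer, per-module
+        parameters in eager mode.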
+ + Returns: + Dictionary containing LoRA parameters, or None if no LoRA requests + ''' + use_cuda_graph_mode = self.cuda_graph_lora_manager is not None and maybe_graph + + if use_cuda_graph_mode: + return self.cuda_graph_lora_manager.prepare_cuda_graph_lora_params( + scheduled_requests, attn_metadata, peft_cache_manager) + else: + if self.cuda_graph_lora_manager is not None: + self.cuda_graph_lora_manager.adapter_slot_manager.remove_evicted_slots_in_cpp( + peft_cache_manager) + peft_table = peft_cache_manager.get_and_reset_batch_peft_table( + ) if peft_cache_manager is not None else None + return peft_table and self._get_eager_lora_params_from_requests( + scheduled_requests, attn_metadata, peft_table) + + def _get_eager_lora_params_from_requests( + self, scheduled_requests: ScheduledRequests, + attn_metadata: AttentionMetadata, + peft_table: Dict[int, list[TaskLayerModuleConfig]]): + ''' + Eager mode LoRA parameter preparation logic. + lora_params: dict { layer_id: dict @@ -3165,10 +3225,12 @@ class PyTorchModelEngine(ModelEngine): # trace all requests to get the union set of the lora params for request in request_list: - if request.py_lora_task_layer_module_configs is None: + if request.lora_task_id is None: continue - for module in request.py_lora_task_layer_module_configs: + layer_module_configs = peft_table[request.lora_task_id] + + for module in layer_module_configs: module_id = module.module_id layer_id = module.layer_id @@ -3195,7 +3257,7 @@ class PyTorchModelEngine(ModelEngine): for request in request_list: # Need to set default values for this case - if request.py_lora_task_layer_module_configs is None: + if request.lora_task_id is None: for layer_id in lora_params: for module_id in lora_params[layer_id]: current_lora_params = lora_params[layer_id][module_id] @@ -3245,7 +3307,8 @@ class PyTorchModelEngine(ModelEngine): cache_indirection_buffer: Optional[torch.Tensor] = None, num_accepted_tokens_device: Optional[torch.Tensor] = None, req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None, - resource_manager: Optional[ResourceManager] = None): + resource_manager: Optional[ResourceManager] = None, + maybe_graph: bool = False): if self.mapping is not None and 'cp_type' in self.mapping.cp_config: cp_type = self.mapping.cp_config['cp_type'] if CpType.STAR == cp_type: @@ -3258,12 +3321,11 @@ class PyTorchModelEngine(ModelEngine): raise NotImplementedError( f"Unsupported cp_type {getattr(cp_type, 'name', cp_type)}.") - return self._prepare_tp_inputs(scheduled_requests, kv_cache_manager, - attn_metadata, spec_metadata, - new_tensors_device, - cache_indirection_buffer, - num_accepted_tokens_device, - req_id_to_old_request, resource_manager) + return self._prepare_tp_inputs( + scheduled_requests, kv_cache_manager, attn_metadata, spec_metadata, + new_tensors_device, cache_indirection_buffer, + num_accepted_tokens_device, req_id_to_old_request, resource_manager, + maybe_graph) @torch.inference_mode() @with_model_extra_attrs(lambda self: self.model.extra_attrs) @@ -3347,12 +3409,11 @@ class PyTorchModelEngine(ModelEngine): spec_metadata = self.spec_metadata else: spec_metadata = None - inputs, gather_ids = self._prepare_inputs( padded_requests, kv_cache_manager, attn_metadata, spec_metadata, new_tensors_device, cache_indirection_buffer, num_accepted_tokens_device, req_id_to_old_request, - resource_manager) + resource_manager, can_run_graph) with with_shared_pool(self.cuda_graph_runner.get_graph_pool()): if not can_run_graph: diff --git 
a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 0d739c807c..bfdfb39af7 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -11,6 +11,7 @@ import torch import tensorrt_llm import tensorrt_llm.bindings from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp +from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig, PybindMirror) from tensorrt_llm.lora_helper import LoraConfig @@ -1554,6 +1555,9 @@ class PeftCacheManager(BaseResourceManager): model_config=ModelConfigPython.from_model_config_cpp(model_config), cpp_peft_cache_manager=self.impl) + self._batch_peft_table: Optional[Dict[int, list[ + TaskLayerModuleConfig]]] = None # task_id -> layer-module-configs mapping for the current batch + def get_lora_manager(self) -> LoraManager: return self._lora_manager @@ -1600,18 +1604,9 @@ class PeftCacheManager(BaseResourceManager): for req in context_batch: self.add_request_peft(req) - py_lora_task_layer_module_configs = self.impl.ensure_batch( + self._batch_peft_table, _ = self.impl.ensure_batch_map_task_id( context_batch, generation_batch, False) - for req in context_batch: - req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[ - req. - py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None - for req in generation_batch: - req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[ - req. - py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None - def update_resources(self, scheduled_batch: ScheduledRequests): pass @@ -1620,3 +1615,12 @@ class PeftCacheManager(BaseResourceManager): def shutdown(self): pass + + def get_and_reset_batch_peft_table( + self) -> Dict[int, list[TaskLayerModuleConfig]]: + batch_peft_table = self._batch_peft_table + self._batch_peft_table = None + return batch_peft_table + + def is_task_cached_device(self, task_id: int) -> bool: + return self.impl.is_task_cached_device(task_id) diff --git a/tests/unittest/_torch/thop/parallel/test_custom_ops.py b/tests/unittest/_torch/thop/parallel/test_custom_ops.py index f75dc4fec1..5d65b83a75 100644 --- a/tests/unittest/_torch/thop/parallel/test_custom_ops.py +++ b/tests/unittest/_torch/thop/parallel/test_custom_ops.py @@ -59,6 +59,8 @@ def test_register_fake(custom_ops): # TODO: add fake impl for these ops in follow-up PRs. 
to_fix = { "trtllm::lora_grouped_gemm", + "trtllm::lora_grouped_gemm_cuda_graph", + "trtllm::lora_group_gemm_param_fill_row_reorder_fusion", "trtllm::mtp_relaxed_acceptance_op", "trtllm::mtp_update_hidden_states_op", "trtllm::mtp_prepare_drafter_inputs_op", diff --git a/tests/unittest/llmapi/lora_test_utils.py b/tests/unittest/llmapi/lora_test_utils.py index a123df495b..2c4b978254 100644 --- a/tests/unittest/llmapi/lora_test_utils.py +++ b/tests/unittest/llmapi/lora_test_utils.py @@ -1,18 +1,28 @@ import json import tarfile import tempfile +from dataclasses import asdict, dataclass from pathlib import Path -from typing import List, OrderedDict, Type +from typing import List, Optional, OrderedDict, Tuple, Type, Union +import pytest import torch from utils.llm_data import llm_models_root from utils.util import duplicate_list_to_length, flatten_list, similar from tensorrt_llm import SamplingParams +from tensorrt_llm._torch.peft.lora.cuda_graph_lora_params import \ + CudaGraphLoraParams +from tensorrt_llm._torch.peft.lora.layer import (GroupedGemmParamsInput, + GroupedGemmParamsOutput, + LoraLayer) from tensorrt_llm.executor.request import LoRARequest from tensorrt_llm.llmapi.llm import BaseLLM +from tensorrt_llm.llmapi.llm_args import CudaGraphConfig from tensorrt_llm.lora_helper import LoraConfig +from .test_utils import DelayedAssert + _RU_LORA_ADAPTER_PROMPTS = [ "Назови главную площадь в центре Москвы.", "Напиши полное предложение, описывающее, что в музее не хватает женских скульптур. Используй фразу \"не хватает\".", @@ -284,3 +294,233 @@ def create_mock_nemo_lora_checkpoint( tar.add(config_path, arcname="model_config.yaml") return nemo_path + + +@dataclass +class CUDAGraphLoRATestParams: + batch_slot_ids: List[int] + input_hidden_size: int + slot_ranks: List[int] + max_lora_rank: int + output_hidden_sizes: List[int] + layer_module_mask: Optional[Union[torch.Tensor, bool]] + dtype: torch.dtype + seed: int + + def __post_init__(self): + assert self.layer_module_mask is None or isinstance( + self.layer_module_mask, + bool) or self.layer_module_mask.shape == (self.module_count, + self.slot_count) + assert all(0 <= idx <= self.slot_count for idx in self.batch_slot_ids) + assert all(0 <= rank <= self.max_lora_rank for rank in self.slot_ranks) + if isinstance(self.layer_module_mask, torch.Tensor): + self.layer_module_mask = self.layer_module_mask.to(dtype=torch.bool) + elif self.layer_module_mask is not None: + self.layer_module_mask = bool(self.layer_module_mask) + else: + self.layer_module_mask = True + + @property + def module_count(self): + return len(self.output_hidden_sizes) + + @property + def slot_count(self): + return len(self.slot_ranks) + + @property + def batch_size(self): + return len(self.batch_slot_ids) + + @property + def sum_output_hidden_size(self): + return sum(self.output_hidden_sizes) + + +def create_grouped_gemm_params_filler_input( + test_params: Optional[CUDAGraphLoRATestParams] = None +) -> Tuple[GroupedGemmParamsInput, LoraLayer]: + if test_params is None: + test_params = CUDAGraphLoRATestParams( + batch_slot_ids=[0, 3, 3, 4, 5, 8], + input_hidden_size=4096, + slot_ranks=[8, 12, 4, 3] * 2, + max_lora_rank=64, + output_hidden_sizes=[4096, 4096], + layer_module_mask=None, + dtype=torch.bfloat16, + seed=42, + ) + + with torch.random.fork_rng(): + torch.manual_seed(test_params.seed) + shape_2d = (test_params.module_count, test_params.slot_count) + + x = torch.randn(test_params.batch_size, + test_params.input_hidden_size, + dtype=test_params.dtype, + device="cuda") + 
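+        # The output and intermediate buffers are deliberately filled with random
+        # data: the fused param-fill kernel is expected to zero them, which
+        # compare_grouped_gemm_params asserts further below.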
output_buffer = torch.randn(test_params.batch_size, + test_params.sum_output_hidden_size, + dtype=test_params.dtype, + device="cuda") + b_ptrs = torch.randint(1, + 1000000, + shape_2d, + dtype=CudaGraphLoraParams.PTR_DTYPE) + b_prime_ptrs = torch.randint(1, + 1000000, + shape_2d, + dtype=CudaGraphLoraParams.PTR_DTYPE) + + b_ptrs *= test_params.layer_module_mask + b_prime_ptrs *= test_params.layer_module_mask + + b_ptrs = b_ptrs.to(device="cuda") + b_prime_ptrs = b_prime_ptrs.to(device="cuda") + slot_ranks = torch.tensor(test_params.slot_ranks, + dtype=CudaGraphLoraParams.SIZES_DTYPE, + device="cuda") + + intermediate_buffer = torch.randn(test_params.module_count, + test_params.batch_size, + test_params.max_lora_rank, + dtype=test_params.dtype, + device="cuda") + slot_counts = CudaGraphLoraParams.get_slot_counts( + test_params.batch_slot_ids, test_params.slot_count) + slot_offsets_full = CudaGraphLoraParams.get_offset_from_counts( + slot_counts, full=True) + sorted_ids = CudaGraphLoraParams.get_sorted_indices( + test_params.batch_slot_ids) + + slot_offsets_full = slot_offsets_full.to(device="cuda", + dtype=torch.int64) + slot_counts = slot_counts.to(device="cuda", dtype=torch.int32) + sorted_ids = sorted_ids.to(device="cuda", dtype=torch.int64) + + output_hidden_sizes = torch.tensor( + test_params.output_hidden_sizes, + dtype=CudaGraphLoraParams.SIZES_DTYPE, + device="cuda") + output_sizes_offset = CudaGraphLoraParams.get_offset_from_counts( + output_hidden_sizes).to(dtype=CudaGraphLoraParams.PTR_DTYPE, + device="cuda") + + layer = LoraLayer([0] * test_params.module_count, + test_params.output_hidden_sizes) + inputs = GroupedGemmParamsInput( + x=x, + output_buffer=output_buffer, + intermediate_buffer=intermediate_buffer, + max_lora_size=test_params.slot_count, + max_rank=test_params.max_lora_rank, + slot_counts=slot_counts, + slot_ranks=slot_ranks, + slot_offsets_full=slot_offsets_full, + sorted_ids=sorted_ids, + b_ptrs=b_ptrs, + b_prime_ptrs=b_prime_ptrs, + output_hidden_sizes=output_hidden_sizes, + output_sizes_offset=output_sizes_offset, + ) + return inputs, layer + + +def compare_grouped_gemm_params( + params: GroupedGemmParamsOutput, + ref: GroupedGemmParamsOutput, + params_input: GroupedGemmParamsInput, + params_to_store_msg: List[str] | None = ['splitk_offsets'], + params_exclude_msg: List[str] | None = None, +): + assert not (params_to_store_msg and params_exclude_msg) + + bs, input_hidden_size = params.reordered_input.shape + asserter = DelayedAssert() + params_dict = asdict(params) + ref_dict = asdict(ref) + + if not params_to_store_msg: + params_to_store_msg = set(params_dict.keys()) + if params_exclude_msg: + for name in params_exclude_msg: + params_to_store_msg.discard(name) + + def get_msg(name: str, v: torch.Tensor, ref_v: torch.Tensor): + is_get_msg = any(p in name or name in p for p in params_to_store_msg) + header = f"\n\n{name=}\n" + return f"{header} {v=}\n {ref_v=}\n diff:\n{v - ref_v}" if is_get_msg else header + + for name in params_dict.keys(): + v = params_dict[name] + ref_v = ref_dict[name] + if name not in ("reordered_input", "a_offset"): + asserter.add( + v.allclose(ref_v), + get_msg(name, v, ref_v), + ) + + # Test a_offset separately + offset = params.a_offset - params.reordered_input.data_ptr() + ref_offset = ref.a_offset - ref.reordered_input.data_ptr() + asserter.add( + (offset == ref_offset).all(), + # 'a_offset_fused', + get_msg("a_offset", offset, ref_offset)) + + # Test reordered_input separately + valid_row = 
params_input.slot_offsets_full[-1].cpu().item() + valid_rows = params.reordered_input[:valid_row] + ref_valid_rows = ref.reordered_input[:valid_row] + asserter.add( + valid_rows.allclose(ref_valid_rows), + get_msg(f"valid part({valid_row=}, {bs=}) of reordered_input", + valid_rows, ref_valid_rows)) + + # check intermediate buffer and output buffer are all zeros + asserter.add( + torch.all(params_input.intermediate_buffer == 0), + get_msg("intermediate buffer", params_input.intermediate_buffer, 0)) + asserter.add(torch.all(params_input.output_buffer == 0), + get_msg("output buffer", params_input.output_buffer, 0)) + + if valid_row < bs: + invalid_rows = params.reordered_input[valid_row:] + ref_invalid_rows = ref.reordered_input[valid_row:] + asserter.add( + torch.all(invalid_rows == 0), + get_msg("invalid part of reordered_input", invalid_rows, + ref_invalid_rows)) + else: + asserter.add( + True, + f"valid_row is full {valid_row=} v. bs: {params_dict['reordered_input'].shape[0]=}" + ) + asserter.assert_all() + + +def compare_cuda_graph_lora_params_filler(test_params: CUDAGraphLoRATestParams): + grouped_gemm_params_filler_input, layer = create_grouped_gemm_params_filler_input( + test_params) + output_fused = layer._prepare_grouped_gemm_buffers_fused( + grouped_gemm_params_filler_input) + + assert torch.all( + grouped_gemm_params_filler_input.intermediate_buffer == 0 + ), f"intermediate_buffer is not all zeros: {grouped_gemm_params_filler_input.intermediate_buffer}; non zero / zeros: {(grouped_gemm_params_filler_input.intermediate_buffer != 0).sum()} / {grouped_gemm_params_filler_input.intermediate_buffer.numel()}" + assert torch.all( + grouped_gemm_params_filler_input.output_buffer == 0 + ), f"output_buffer is not all zeros: {grouped_gemm_params_filler_input.output_buffer}; non zero / zeros: {(grouped_gemm_params_filler_input.output_buffer != 0).sum()} / {grouped_gemm_params_filler_input.output_buffer.numel()}" + + output_pytorch = layer.prepare_grouped_gemm_buffers( + grouped_gemm_params_filler_input) + compare_grouped_gemm_params(output_fused, + output_pytorch, + grouped_gemm_params_filler_input, + params_to_store_msg=None) + + +test_lora_with_and_without_cuda_graph = pytest.mark.parametrize( + "cuda_graph_config", [CudaGraphConfig(max_batch_size=10), None]) diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index 68cdc62ba7..2229de60e8 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -9,7 +9,8 @@ from tensorrt_llm.sampling_params import SamplingParams from .lora_test_utils import ( check_llama_7b_multi_lora_from_request_test_harness, - check_phi3_lora_fused_modules_output_tp2_identical_to_tp1) + check_phi3_lora_fused_modules_output_tp2_identical_to_tp1, + test_lora_with_and_without_cuda_graph) from .test_llm import (_test_llm_capture_request_error, llama_model_path, llm_get_stats_test_harness, llm_return_logprobs_test_harness, @@ -46,9 +47,10 @@ def test_llama_7b_lora_tp2(): kv_cache_config=global_kv_cache_config) -@pytest.mark.gpu2 -@pytest.mark.skip(reason="https://nvbugs/5682551") -def test_llama_7b_multi_lora_tp2(): +@pytest.mark.gpu4 +@skip_ray # https://nvbugs/5682551 +@test_lora_with_and_without_cuda_graph +def test_llama_7b_multi_lora_tp4(cuda_graph_config): # For LoRA checkpoints without finetuned embedding and lm_head, we can either: # (1) specify lora_target_modules, or # (2) provide a lora_dir to infer the lora_target_modules. 
@@ -59,21 +61,18 @@ def test_llama_7b_multi_lora_tp2(): check_llama_7b_multi_lora_from_request_test_harness( LLM, lora_config=lora_config, - tensor_parallel_size=2, + tensor_parallel_size=4, kv_cache_config=global_kv_cache_config, - # Disable CUDA graph - # TODO: remove this once we have a proper fix for CUDA graph in LoRA - cuda_graph_config=None) + cuda_graph_config=cuda_graph_config) @skip_ray # https://nvbugs/5727075 @pytest.mark.gpu2 -def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None: +@test_lora_with_and_without_cuda_graph +def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1( + cuda_graph_config) -> None: check_phi3_lora_fused_modules_output_tp2_identical_to_tp1( - LLM, - # Disable CUDA graph - # TODO: remove this once we have a proper fix for CUDA graph in LoRA - cuda_graph_config=None) + LLM, cuda_graph_config=cuda_graph_config) @skip_ray diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 86f48d3126..e5ee03da65 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -20,7 +20,8 @@ from tensorrt_llm.sampling_params import SamplingParams from .lora_test_utils import ( check_llama_7b_multi_lora_from_request_test_harness, check_llama_7b_multi_unique_lora_adapters_from_request, - create_mock_nemo_lora_checkpoint) + create_mock_nemo_lora_checkpoint, compare_cuda_graph_lora_params_filler, + CUDAGraphLoRATestParams, test_lora_with_and_without_cuda_graph) from .test_llm import (_test_llm_capture_request_error, get_model_path, global_kvcache_config, llama_model_path, llm_get_stats_async_test_harness, @@ -42,6 +43,7 @@ import torch from peft import LoraConfig as PeftLoraConfig from peft import get_peft_model from transformers import AutoModelForCausalLM, AutoTokenizer +from dataclasses import replace # isort: on @@ -283,19 +285,76 @@ def test_embedding_bias_with_torch_sampler_strategies(): ) +def test_lora_cuda_graph_params_filling_kernel_special_cases(): + torch.cuda.set_device(0) + + # test all requests have the same LoRA id case + test_params = CUDAGraphLoRATestParams( + batch_slot_ids=[0] * 10, + input_hidden_size=4096, + slot_ranks=[64] * 10, + max_lora_rank=64, + output_hidden_sizes=[123], + layer_module_mask=None, + dtype=torch.bfloat16, + seed=42, + ) + compare_cuda_graph_lora_params_filler(test_params) + + # test no LoRA in a batch case + test_params2 = replace(test_params, + batch_slot_ids=[len(test_params.slot_ranks)] * 10) + compare_cuda_graph_lora_params_filler(test_params2) + + # test all having three modules case + test_params3 = replace(test_params, output_hidden_sizes=[123, 456, 789]) + compare_cuda_graph_lora_params_filler(test_params3) + + # test some layer module have invalid weight pointers case + mask = torch.full((test_params3.module_count, test_params3.slot_count), + True, + dtype=torch.bool) + mask[0, 0] = False + mask[1, 7] = False + mask[2, 3] = False + test_params4 = replace(test_params3, layer_module_mask=mask) + compare_cuda_graph_lora_params_filler(test_params4) + + # test mixed slot ids case + test_params5 = CUDAGraphLoRATestParams( + batch_slot_ids=[6, 2, 0, 1, 1, 1, 5, 6], + input_hidden_size=512, + slot_ranks=[8, 12, 4] * 2, + max_lora_rank=15, + output_hidden_sizes=[123, 456, 789], + layer_module_mask=None, + dtype=torch.bfloat16, + seed=42, + ) + compare_cuda_graph_lora_params_filler(test_params5) + + # test mixed slot with invalid weight pointers + mask = torch.full((test_params5.module_count, 
test_params5.slot_count), + True, + dtype=torch.bool) + mask[1, 3] = False + mask[2, 5] = False + mask[-1, -4] = False + test_params6 = replace(test_params5, layer_module_mask=mask) + compare_cuda_graph_lora_params_filler(test_params6) + + def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None: lora_config = LoraConfig( lora_dir=[f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"], max_lora_rank=8, max_loras=2, max_cpu_loras=2) - llm = LLM( - model=f"{llm_models_root()}/llama-models/llama-7b-hf", - lora_config=lora_config, - # Disable CUDA graph - # TODO: remove this once we have a proper fix for CUDA graph in LoRA - cuda_graph_config=None, - **llm_kwargs) + if "cuda_graph_config" not in llm_kwargs: + llm_kwargs["cuda_graph_config"] = None + llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf", + lora_config=lora_config, + **llm_kwargs) try: prompts = [ "美国的首都在哪里? \n答案:", @@ -318,22 +377,21 @@ def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None: @skip_gpu_memory_less_than_40gb @pytest.mark.part0 -def test_llama_7b_lora(): - llama_7b_lora_from_dir_test_harness() +@test_lora_with_and_without_cuda_graph +def test_llama_7b_lora(cuda_graph_config): + llama_7b_lora_from_dir_test_harness(cuda_graph_config=cuda_graph_config) @skip_gpu_memory_less_than_40gb -def test_llama_7b_lora_default_modules() -> None: +@test_lora_with_and_without_cuda_graph +def test_llama_7b_lora_default_modules(cuda_graph_config) -> None: lora_config = LoraConfig(max_lora_rank=64, max_loras=2, max_cpu_loras=2) hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf" - llm = LLM( - model=hf_model_dir, - lora_config=lora_config, - # Disable CUDA graph - # TODO: remove this once we have a proper fix for CUDA graph in LoRA - cuda_graph_config=None) + llm = LLM(model=hf_model_dir, + lora_config=lora_config, + cuda_graph_config=cuda_graph_config) hf_lora_dir = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1" try: @@ -359,7 +417,8 @@ def test_llama_7b_lora_default_modules() -> None: def _check_llama_7b_multi_lora_evict_load_new_adapters( lora_adapter_count_per_call: list[int], max_loras: int, - max_cpu_loras: int, repeat_calls: int, repeats_per_call: int): + max_cpu_loras: int, repeat_calls: int, repeats_per_call: int, + **llm_kwargs): # For LoRA checkpoints without finetuned embedding and lm_head, we can either: # (1) specify lora_target_modules, or # (2) provide a lora_dir to infer the lora_target_modules. @@ -373,15 +432,14 @@ def _check_llama_7b_multi_lora_evict_load_new_adapters( repeats_per_call, LLM, lora_config=lora_config, - # Disable CUDA graph - # TODO: remove this once we have a proper fix for CUDA graph in LoRA - cuda_graph_config=None) + **llm_kwargs) @skip_gpu_memory_less_than_40gb @skip_ray # https://nvbugs/5682551 @pytest.mark.part3 -def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache(): +@test_lora_with_and_without_cuda_graph +def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache(cuda_graph_config): """Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single llm.generate call, that's repeated twice. 
""" # noqa: D205 @@ -390,12 +448,15 @@ def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache(): max_loras=1, max_cpu_loras=2, repeat_calls=2, - repeats_per_call=3) + repeats_per_call=3, + cuda_graph_config=cuda_graph_config) @skip_gpu_memory_less_than_40gb @pytest.mark.part1 -def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache(): +@test_lora_with_and_without_cuda_graph +def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache( + cuda_graph_config): """Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU cache size < LoRA CPU cache size. """ # noqa: D205 @@ -404,25 +465,29 @@ def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache(): max_loras=1, max_cpu_loras=3, repeat_calls=1, - repeats_per_call=1) + repeats_per_call=1, + cuda_graph_config=cuda_graph_config) @skip_gpu_memory_less_than_40gb @pytest.mark.part0 -def test_llama_7b_multi_lora_read_from_cache_after_insert(): +@test_lora_with_and_without_cuda_graph +def test_llama_7b_multi_lora_read_from_cache_after_insert(cuda_graph_config): """Test that loading and then using the same adapters loaded in cache works.""" _check_llama_7b_multi_lora_evict_load_new_adapters( lora_adapter_count_per_call=[3], max_loras=3, max_cpu_loras=3, repeat_calls=2, - repeats_per_call=1) + repeats_per_call=1, + cuda_graph_config=cuda_graph_config) @skip_gpu_memory_less_than_40gb @pytest.mark.part3 +@test_lora_with_and_without_cuda_graph def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_cache( -): + cuda_graph_config): """Test eviction, reloading new adapters and reloading previously evicted adapters from the LoRA CPU cache & GPU cache over multiple llm.generate call repeated twice (two calls with the same requests): At the end of the 1st llm.generate call: @@ -439,12 +504,14 @@ def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_ca max_loras=2, max_cpu_loras=2, repeat_calls=2, - repeats_per_call=1) + repeats_per_call=1, + cuda_graph_config=cuda_graph_config) @skip_gpu_memory_less_than_40gb @pytest.mark.part2 -def test_llama_7b_peft_cache_config_affects_peft_cache_size(): +@test_lora_with_and_without_cuda_graph +def test_llama_7b_peft_cache_config_affects_peft_cache_size(cuda_graph_config): """Tests that LLM arg of peft_cache_config affects the peft cache sizes. 
NOTE: The caller can't get the actual LoRA cache sizes, so we instead we @@ -464,9 +531,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): lora_config=lora_config_no_cache_size_values, peft_cache_config=PeftCacheConfig( host_cache_size=1), # size in bytes - # Disable CUDA graph - # TODO: remove this once we have a proper fix for CUDA graph in LoRA - cuda_graph_config=None) + cuda_graph_config=cuda_graph_config) # Test that too small PeftCacheConfig.device_cache_percent causes failure with pytest.raises(RuntimeError): @@ -474,15 +539,14 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size(): LLM, lora_config=lora_config_no_cache_size_values, peft_cache_config=PeftCacheConfig(device_cache_percent=0.0000001), - # Disable CUDA graph - # TODO: remove this once we have a proper fix for CUDA graph in LoRA - cuda_graph_config=None) + cuda_graph_config=cuda_graph_config) @skip_ray # https://nvbugs/5682551 @skip_gpu_memory_less_than_40gb @pytest.mark.part1 -def test_llama_7b_lora_config_overrides_peft_cache_config(): +@test_lora_with_and_without_cuda_graph +def test_llama_7b_lora_config_overrides_peft_cache_config(cuda_graph_config): """Tests that cache size args in lora_config LLM arg override the cache size parameters in peft_cache_config LLM arg. """ # noqa: D205 @@ -496,9 +560,7 @@ def test_llama_7b_lora_config_overrides_peft_cache_config(): peft_cache_config=PeftCacheConfig( host_cache_size=1, # size in bytes device_cache_percent=0.0000001), - # Disable CUDA graph - # TODO: remove this once we have a proper fix for CUDA graph in LoRA - cuda_graph_config=None) + cuda_graph_config=cuda_graph_config) # TODO smor: currently Nemotron-Super-49B-v1 with LoRA memory consumption is overly high @@ -506,7 +568,8 @@ def test_llama_7b_lora_config_overrides_peft_cache_config(): @pytest.mark.skip(reason="https://nvbugs/5448464") @skip_gpu_memory_less_than_138gb @pytest.mark.part1 -def test_nemotron_nas_lora() -> None: +@test_lora_with_and_without_cuda_graph +def test_nemotron_nas_lora(cuda_graph_config) -> None: lora_config = LoraConfig(lora_dir=[ f"{llm_models_root()}/nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1-lora-adapter_r64" ], @@ -518,7 +581,7 @@ def test_nemotron_nas_lora() -> None: model= f"{llm_models_root()}/nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1", lora_config=lora_config, - ) + cuda_graph_config=cuda_graph_config) prompts = [ "Hello, how are you?", @@ -539,7 +602,8 @@ def test_nemotron_nas_lora() -> None: @skip_gpu_memory_less_than_80gb @pytest.mark.part0 -def test_llama_3_1_8b_fp8_with_bf16_lora() -> None: +@test_lora_with_and_without_cuda_graph +def test_llama_3_1_8b_fp8_with_bf16_lora(cuda_graph_config) -> None: skip_fp8_pre_ada(use_fp8=True) model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8" lora_dir = f"{llm_models_root()}/lora/llama-3-chinese-8b-instruct-v2-lora" @@ -552,12 +616,9 @@ def test_llama_3_1_8b_fp8_with_bf16_lora() -> None: max_cpu_loras=2) lora_req = LoRARequest("lora-chinese", 0, lora_dir) - llm = LLM( - model_dir, - lora_config=lora_config, - # Disable CUDA graph - # TODO: remove this once we have a proper fix for CUDA graph in LoRA - cuda_graph_config=None) + llm = LLM(model_dir, + lora_config=lora_config, + cuda_graph_config=cuda_graph_config) try: output = llm.generate(prompt, @@ -570,7 +631,8 @@ def test_llama_3_1_8b_fp8_with_bf16_lora() -> None: @skip_gpu_memory_less_than_80gb @pytest.mark.part2 -def test_bielik_11b_v2_2_instruct_multi_lora() -> None: +@test_lora_with_and_without_cuda_graph +def 
test_bielik_11b_v2_2_instruct_multi_lora(cuda_graph_config) -> None:
     model_dir = f"{llm_models_root()}/Bielik-11B-v2.2-Instruct"
 
     target_modules = ['attn_q', 'attn_k', 'attn_v']
@@ -600,12 +662,9 @@
                                     max_lora_rank=8,
                                     max_loras=2,
                                     max_cpu_loras=2)
-    llm = LLM(
-        model_dir,
-        lora_config=trtllm_lora_config,
-        # Disable CUDA graph
-        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
-        cuda_graph_config=None)
+    llm = LLM(model_dir,
+              lora_config=trtllm_lora_config,
+              cuda_graph_config=cuda_graph_config)
 
     prompts = [
         "Kim był Mikołaj Kopernik i z czego zasłynął?",
@@ -624,7 +683,8 @@
 
 
 @pytest.mark.part2
-def test_gemma3_1b_instruct_multi_lora() -> None:
+@test_lora_with_and_without_cuda_graph
+def test_gemma3_1b_instruct_multi_lora(cuda_graph_config) -> None:
     model_dir = f"{llm_models_root()}/gemma/gemma-3-1b-it"
 
     target_modules = ['attn_q', 'attn_k', 'attn_v']
@@ -662,7 +722,8 @@
     )
     llm = LLM(model_dir,
               lora_config=trtllm_lora_config,
-              kv_cache_config=kv_cache_config)
+              kv_cache_config=kv_cache_config,
+              cuda_graph_config=cuda_graph_config)
 
     prompts = [
         "Is it ok to fill diesel in a petrol car?",
@@ -745,7 +806,8 @@ def test_nemo_lora_unsupported_modules_validation(tmp_path):
 
 @force_ampere
 @pytest.mark.part1
-def test_gqa_nemo_lora(tmp_path):
+@test_lora_with_and_without_cuda_graph
+def test_gqa_nemo_lora(tmp_path, cuda_graph_config):
     """
     Test NeMo-format LoRA checkpoint loading and GQA support in TinyLlama.
 
@@ -790,6 +852,7 @@ def test_gqa_nemo_lora(tmp_path):
         model=model_path,
         lora_config=lora_config,
         kv_cache_config=global_kvcache_config,
+        cuda_graph_config=cuda_graph_config,
     )
 
     try:
diff --git a/tests/unittest/llmapi/test_utils.py b/tests/unittest/llmapi/test_utils.py
index 5488f7c7ba..24718b064d 100644
--- a/tests/unittest/llmapi/test_utils.py
+++ b/tests/unittest/llmapi/test_utils.py
@@ -32,3 +32,33 @@ def test_generate_api_docs_as_docstring():
     doc = generate_api_docs_as_docstring(LlmArgs)
     assert ":tag:`beta`" in doc, "the label is not generated"
     print(doc)
+
+
+class DelayedAssert:
+
+    def __init__(self, store_stack: bool = False):
+        self.assertions = []
+        self.store_stack = store_stack
+
+    def add(self, result: bool, msg: str):
+        import traceback
+        self.assertions.append(
+            (bool(result), str(msg), traceback.format_stack()))
+
+    def get_msg(self):
+        ret = []
+        for result, msg, stack in self.assertions:
+            ret.append('\n'.join([
+                f'Assert result: {result}', msg,
+                ''.join(stack) if self.store_stack else ''
+            ]))
+        ret = '\n-----------------------------------------\n'.join(ret)
+        ret = 'Some assertions failed:\n' + ret
+        return ret
+
+    def clear(self):
+        self.assertions.clear()
+
+    def assert_all(self):
+        assert all(ret[0] for ret in self.assertions), self.get_msg()
+        self.clear()
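
Usage sketch for the DelayedAssert helper added to test_utils.py above. This snippet is illustrative only and is not part of the patch; it assumes a pytest-style test module that can use the same relative import as lora_test_utils.py, and it only demonstrates the add()/assert_all() flow of recording several checks and raising one combined failure at the end. The helper function name _compare_tensors_example is hypothetical.

# Illustrative only: a minimal sketch of how DelayedAssert is intended to be used.
import torch

from .test_utils import DelayedAssert  # same relative import as in lora_test_utils.py


def _compare_tensors_example(a: torch.Tensor, b: torch.Tensor) -> None:
    asserter = DelayedAssert()
    # Record every check instead of failing on the first mismatch.
    asserter.add(a.shape == b.shape, f"shape mismatch: {a.shape=} vs {b.shape=}")
    asserter.add(a.dtype == b.dtype, f"dtype mismatch: {a.dtype=} vs {b.dtype=}")
    if a.shape == b.shape:
        asserter.add(torch.allclose(a, b),
                     f"values differ, max abs diff: {(a - b).abs().max()}")
    # Raises a single AssertionError listing all recorded checks if any failed,
    # then clears the recorded assertions for reuse.
    asserter.assert_all()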