[https://nvbugs/5322131][feat] Multi-LoRA serving with CUDA Graph (#8279)

Signed-off-by: Jiayu Chang <jiayuc@nvidia.com>
Jiayu Chang 2026-01-22 21:01:18 +08:00 committed by GitHub
parent cdb9ffd0ab
commit 1dc49b266e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 2766 additions and 172 deletions

View File

@ -57,7 +57,10 @@ class BasePeftCacheManager
public:
using LlmRequestPtr = std::shared_ptr<LlmRequest>;
using RequestVector = std::vector<LlmRequestPtr>;
using PeftTable = std::map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
using PeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
using TaskPeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
using TaskIdToReqIds = std::unordered_map<uint64_t, std::vector<uint64_t>>;
using EnsureBatchTaskResult = std::tuple<TaskPeftTable, TaskIdToReqIds>;
virtual ~BasePeftCacheManager() = default;
@ -99,6 +102,8 @@ public:
class PeftCacheManager : public BasePeftCacheManager
{
public:
using EnsureBatchTaskResult = BasePeftCacheManager::EnsureBatchTaskResult;
PeftCacheManager(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
@ -109,12 +114,17 @@ public:
PeftTable ensureBatch(RequestVector const& contextRequests, RequestVector const& generationRequests,
bool resetGpuCache = false) override;
EnsureBatchTaskResult ensureBatchMapTaskId(
RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache = false);
[[nodiscard]] bool isTaskCached(uint64_t taskId) const;
[[nodiscard]] bool isTaskDone(uint64_t taskId) const;
[[nodiscard]] bool isTaskDoneDevice(uint64_t taskId) const;
[[nodiscard]] bool isTaskCachedDevice(uint64_t const taskId) const;
void resetDeviceCache() override;
void markRequestDone(LlmRequest const& llmReq, bool pause = false) override;
@ -159,7 +169,7 @@ private:
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToReqIds;
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToPausedReqIds;
std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>> getTaskMaps(
std::tuple<std::unordered_map<uint64_t, std::future<void>>, TaskIdToReqIds> getTaskMaps(
RequestVector const& contextRequests, RequestVector const& generationRequests);
runtime::ModelConfig mModelConfig;

View File

@ -373,11 +373,11 @@ void PeftCacheManager::addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bo
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>>
std::tuple<std::unordered_map<uint64_t, std::future<void>>, BasePeftCacheManager::TaskIdToReqIds>
PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVector const& generationRequests)
{
std::map<uint64_t, std::vector<uint64_t>> taskIdToReqIds;
std::map<uint64_t, std::future<void>> taskIdToFuture;
TaskIdToReqIds taskIdToReqIds;
std::unordered_map<uint64_t, std::future<void>> taskIdToFuture;
std::lock_guard<std::mutex> futuresLock(mPutFuturesMutex);
for (auto const& requests : {contextRequests, generationRequests})
{
@ -415,7 +415,7 @@ PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVecto
return {std::move(taskIdToFuture), taskIdToReqIds};
}
PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
PeftCacheManager::EnsureBatchTaskResult PeftCacheManager::ensureBatchMapTaskId(
RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@ -426,7 +426,7 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
auto [taskIdToFuture_, taskIdToReqIds] = getTaskMaps(contextRequests, generationRequests);
auto taskIdToFuture = std::move(taskIdToFuture_); // captured structured bindings are a C++20 extension
std::map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
std::unordered_map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
for (auto& [taskId, taskFuture] : taskIdToFuture)
{
auto fn = [&taskIdToFuture, taskId = taskId, this]() -> std::vector<runtime::LoraCache::TaskLayerModuleConfig>
@ -457,18 +457,31 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
ensureFutures.try_emplace(taskId, std::move(f));
}
PeftTable peftTable{};
TaskPeftTable peftTable{};
for (auto const& [taskId, reqIds] : taskIdToReqIds)
{
auto&& f = ensureFutures.at(taskId);
auto const values = f.get();
for (auto const& reqId : reqIds)
{
peftTable.try_emplace(reqId, values);
}
peftTable.try_emplace(taskId, values);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
return peftTable;
return {std::move(peftTable), std::move(taskIdToReqIds)};
}
PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
{
auto [taskTable, taskIdToReqIds] = ensureBatchMapTaskId(contextRequests, generationRequests, resetGpuCache);
PeftTable requestTable{};
for (auto const& [taskId, values] : taskTable)
{
auto const& reqIds = taskIdToReqIds.at(taskId);
for (auto const reqId : reqIds)
{
requestTable.try_emplace(reqId, values);
}
}
return requestTable;
}
bool PeftCacheManager::isTaskCached(uint64_t taskId) const
@ -486,6 +499,11 @@ bool PeftCacheManager::isTaskDoneDevice(uint64_t taskId) const
return mDeviceLoraCache->isDone(taskId);
}
bool PeftCacheManager::isTaskCachedDevice(uint64_t const taskId) const
{
return mDeviceLoraCache->has(taskId);
}
void PeftCacheManager::updateTaskState(uint64_t taskId, uint64_t reqId, bool terminate, bool pause)
{
if (!terminate)
@ -645,3 +663,5 @@ SizeType32 NoOpPeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> l
return 0;
}
} // namespace tensorrt_llm::batch_manager
// TODO: merge C++ LoRA caching status with Py Slot manager
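The refactor above makes ensureBatch a thin wrapper: ensureBatchMapTaskId returns the PEFT table keyed by task id together with the task-to-request map, and ensureBatch fans that table out per request. A minimal Python sketch of the same expansion, using plain dicts as stand-ins for the C++ maps (illustration only, not part of this commit):

def expand_task_table(task_table, task_id_to_req_ids):
    # Fan the task-keyed PEFT table out to a request-keyed one,
    # mirroring PeftCacheManager::ensureBatch above.
    request_table = {}
    for task_id, layer_configs in task_table.items():
        for req_id in task_id_to_req_ids[task_id]:
            # All requests that share a LoRA task share the same layer/module configs.
            request_table[req_id] = layer_configs
    return request_table

# Hypothetical ids: requests 101 and 102 share LoRA task 7, request 103 uses task 9.
task_table = {7: ["task7_layer_cfgs"], 9: ["task9_layer_cfgs"]}
task_id_to_req_ids = {7: [101, 102], 9: [103]}
assert expand_task_table(task_table, task_id_to_req_ids) == {
    101: ["task7_layer_cfgs"], 102: ["task7_layer_cfgs"], 103: ["task9_layer_cfgs"]}

Keying the shared table by task id avoids duplicating the per-layer configs per request and gives callers a task-level view that matches the Python-side slot management added in this change.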

View File

@ -0,0 +1,372 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cuda_graph_grouped_gemm.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <ATen/ATen.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h"
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h"
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
/**
* Template for CUDA Graph compatible grouped GEMM that directly uses GPU tensors
*/
template <int M1, int N1, int K1, int M2, int N2, int K2, typename cutlassType, int kAlignmentAB, int kAlignmentC,
int kStages>
void cudaGraphGroupedGemmTemplate(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream)
{
using ElementA = cutlassType;
using ElementB = cutlassType;
using ElementOutput = cutlassType;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<ElementA, LayoutA,
cutlass::ComplexTransform::kNone, kAlignmentAB, ElementB, LayoutB, cutlass::ComplexTransform::kNone,
kAlignmentAB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<M1, N1, K1>, cutlass::gemm::GemmShape<M2, N2, K2>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<ElementOutput, kAlignmentC, ElementAccumulator,
ElementAccumulator>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, kStages,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>::GemmKernel;
using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;
float alpha = 1.0f;
float beta = 0.0f;
typename Gemm::EpilogueOutputOp::Params epilogueOp(alpha, beta);
auto ptrA = reinterpret_cast<ElementA**>(ptrAGpu);
auto ptrB = reinterpret_cast<ElementB**>(ptrBGpu);
auto ptrC = reinterpret_cast<ElementOutput**>(ptrCGpu);
auto ptrD = reinterpret_cast<ElementOutput**>(ptrDGpu);
Gemm gemmOp;
int threadblockCount = Gemm::sufficient(nullptr, problemCount);
typename Gemm::Arguments args(problemSizesPtr, // GPU problem sizes
problemCount, // Problem count
threadblockCount, // Threadblock count
epilogueOp, // Epilogue operation
ptrA, // GPU pointer array A
ptrB, // GPU pointer array B
ptrC, // GPU pointer array C (can be nullptr)
ptrD, // GPU pointer array D
ldaGpu, // Precomputed leading dimension A (on GPU)
ldbGpu, // Precomputed leading dimension B (on GPU)
ldcGpu, // Precomputed leading dimension C (on GPU)
lddGpu, // Precomputed leading dimension D (on GPU)
hostMaxProblemSizesPtr);
static_assert(Gemm::BaseKernel::ProblemVisitor::kRequiresPrecomputation == false,
"Grouped GEMM with CUDA Graph cannot use precompution.");
{
cutlass::Status status = gemmOp.can_implement(args);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
"Grouped GEMM cannot be implemented with the given arguments, Error: %s",
cutlass::cutlassGetStatusString(status));
}
at::Tensor workspace;
void* gemmWorkspace = nullptr;
size_t const requiredWorkspace = gemmOp.get_workspace_size(args);
if (requiredWorkspace > 0)
{
auto const workspaceTensorOptions = at::TensorOptions().dtype(at::kByte).device(at::kCUDA);
workspace = at::empty({static_cast<int64_t>(requiredWorkspace)}, workspaceTensorOptions);
gemmWorkspace = workspace.data_ptr();
}
cutlass::Status status = gemmOp.initialize(args, gemmWorkspace);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to initialize grouped GEMM");
status = gemmOp.run(stream);
sync_check_cuda_error(stream);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to execute grouped GEMM");
}
template <int M1, int N1, int K1, int M2, int N2, int K2, int kAlignmentAB, int kAlignmentC, int kStages>
void cudaGraphGroupedGemmType(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
nvinfer1::DataType dataType, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream)
{
if (dataType == nvinfer1::DataType::kHALF)
{
cudaGraphGroupedGemmTemplate<M1, N1, K1, M2, N2, K2, cutlass::half_t, kAlignmentAB, kAlignmentC, kStages>(
problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu,
hostMaxProblemSizesPtr, stream);
}
#ifdef ENABLE_BF16
else if (dataType == nvinfer1::DataType::kBF16)
{
cudaGraphGroupedGemmTemplate<M1, N1, K1, M2, N2, K2, cutlass::bfloat16_t, kAlignmentAB, kAlignmentC, kStages>(
problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu,
hostMaxProblemSizesPtr, stream);
}
#endif
else
{
TLLM_CHECK_WITH_INFO(false, "Unsupported data type for CUDA Graph grouped GEMM");
}
}
void cudaGraphGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, void** ptrBGpu,
void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, bool isLoraIn,
nvinfer1::DataType dataType, int minKN, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream)
{
if (isLoraIn)
{
if (minKN >= 8)
{
cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
}
else if (minKN >= 4)
{
cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 4, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
}
else if (minKN >= 2)
{
cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 2, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
}
else if (minKN >= 1)
{
cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 1, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
}
}
else
{
if (minKN >= 8)
{
cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
}
else if (minKN >= 4)
{
cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 4, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
}
else if (minKN >= 2)
{
cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 2, 8, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
}
else
{
cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 1, 8, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
}
}
}
/**
* Template for CUDA Graph compatible split-K grouped GEMM
*/
template <int M1, int N1, int K1, int M2, int N2, int K2, typename cutlassType, int kAlignmentAB, int kAlignmentC,
int kStages>
void cudaGraphSplitKGroupedGemmTemplate(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
int splitKSlices, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, int64_t* splitKOffsetsGpu, cudaStream_t stream)
{
using ElementA = cutlassType;
using ElementB = cutlassType;
using ElementOutput = cutlassType;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
using GemmKernel = typename cutlass::gemm::kernel::DefaultSplitkGemmGrouped<ElementA, LayoutA,
cutlass::ComplexTransform::kNone, kAlignmentAB, ElementB, LayoutB, cutlass::ComplexTransform::kNone,
kAlignmentAB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<M1, N1, K1>, cutlass::gemm::GemmShape<M2, N2, K2>, cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<ElementOutput, kAlignmentC, ElementAccumulator,
ElementAccumulator>,
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, kStages,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>::GemmKernel;
using Gemm = cutlass::gemm::device::SplitkGemmGrouped<GemmKernel>;
float alpha = 1.0f;
float beta = 0.0f;
typename Gemm::EpilogueOutputOp::Params epilogueOp(alpha, beta);
auto ptrA = reinterpret_cast<ElementA**>(ptrAGpu);
auto ptrB = reinterpret_cast<ElementB**>(ptrBGpu);
auto ptrC = reinterpret_cast<ElementOutput**>(ptrCGpu);
auto ptrD = reinterpret_cast<ElementOutput**>(ptrDGpu);
Gemm gemmOp;
int threadblockCount = Gemm::sufficient(nullptr, problemCount);
// Setup arguments for split-K grouped GEMM - using precomputed leading dimensions from GPU tensors
typename Gemm::Arguments args(problemSizesPtr, // GPU problem sizes
problemCount, // Problem count
threadblockCount, // Threadblock count
epilogueOp, // Epilogue operation
ptrA, // GPU pointer array A
ptrB, // GPU pointer array B
ptrC, // GPU pointer array C
ptrD, // GPU pointer array D
ldaGpu, // Precomputed leading dimension A (on GPU)
ldbGpu, // Precomputed leading dimension B (on GPU)
ldcGpu, // Precomputed leading dimension C (on GPU)
lddGpu, // Precomputed leading dimension D (on GPU)
hostMaxProblemSizesPtr, // Host problem sizes
splitKSlices, // Split-K factor
splitKOffsetsGpu);
{
cutlass::Status status = gemmOp.can_implement(args);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
"Split-K grouped GEMM cannot be implemented with the given arguments. Problem count: %d, Split-K slices: "
"%d, Error: %s",
problemCount, splitKSlices, cutlass::cutlassGetStatusString(status));
}
at::Tensor workspace;
void* gemmWorkspace = nullptr;
size_t const requiredWorkspace = gemmOp.get_workspace_size(args);
if (requiredWorkspace > 0)
{
workspace = at::empty(
{static_cast<int64_t>(requiredWorkspace)}, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
gemmWorkspace = workspace.data_ptr();
}
cutlass::Status status = gemmOp.initialize(args, gemmWorkspace);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
"Failed to initialize split-K grouped GEMM. Problem count: %d, Split-K slices: %d, Error: %s", problemCount,
splitKSlices, cutlass::cutlassGetStatusString(status));
status = gemmOp.run(stream);
sync_check_cuda_error(stream);
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
"Failed to execute split-K grouped GEMM. Problem count: %d, Split-K slices: %d, Error: %s", problemCount,
splitKSlices, cutlass::cutlassGetStatusString(status));
}
template <int M1, int N1, int K1, int M2, int N2, int K2, int kAlignmentAB, int kAlignmentC, int kStages>
void cudaGraphSplitKGroupedGemmType(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
nvinfer1::DataType dataType, int splitKSlices, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr,
int64_t* splitKOffsetsGpu, cudaStream_t stream)
{
if (dataType == nvinfer1::DataType::kHALF)
{
cudaGraphSplitKGroupedGemmTemplate<M1, N1, K1, M2, N2, K2, cutlass::half_t, kAlignmentAB, kAlignmentC, kStages>(
problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu,
splitKSlices, hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
}
#ifdef ENABLE_BF16
else if (dataType == nvinfer1::DataType::kBF16)
{
cudaGraphSplitKGroupedGemmTemplate<M1, N1, K1, M2, N2, K2, cutlass::bfloat16_t, kAlignmentAB, kAlignmentC,
kStages>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu,
splitKSlices, hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
}
#endif
else
{
TLLM_CHECK_WITH_INFO(false, "Unsupported data type for CUDA Graph split-K grouped GEMM");
}
}
void cudaGraphSplitKGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
bool isLoraIn, nvinfer1::DataType dataType, int splitKSlices, int minKN,
cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, int64_t* splitKOffsetsGpu, cudaStream_t stream)
{
if (isLoraIn)
{
if (minKN >= 8)
{
cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu,
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
}
else if (minKN >= 4)
{
cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 4, 4>(problemSizesPtr, problemCount, ptrAGpu,
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
}
else if (minKN >= 2)
{
cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 2, 2>(problemSizesPtr, problemCount, ptrAGpu,
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
}
else if (minKN >= 1)
{
cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 1, 2>(problemSizesPtr, problemCount, ptrAGpu,
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
}
}
else
{
if (minKN >= 8)
{
cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu,
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
}
else if (minKN >= 4)
{
cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 4, 8, 4>(problemSizesPtr, problemCount, ptrAGpu,
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
}
else if (minKN >= 2)
{
cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 2, 8, 2>(problemSizesPtr, problemCount, ptrAGpu,
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
}
else
{
cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 1, 8, 2>(problemSizesPtr, problemCount, ptrAGpu,
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
}
}
}
} // namespace kernels
TRTLLM_NAMESPACE_END
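The dispatch ladders in cudaGraphGroupedGemm and cudaGraphSplitKGroupedGemm encode a small configuration table: the threadblock and warp tile shapes depend only on whether this is the LoRA-in or LoRA-out GEMM, while minKN selects the operand alignments and pipeline stage count. A Python sketch of that table (illustration only, not part of this commit):

def select_grouped_gemm_config(is_lora_in: bool, min_kn: int):
    # Mirrors the template-argument ladders above:
    # <M1, N1, K1, M2, N2, K2, kAlignmentAB, kAlignmentC, kStages>.
    if is_lora_in:
        threadblock_tile, warp_tile = (16, 32, 64), (16, 32, 64)
        # LoRA-in: A/B alignment stays 8, epilogue (C/D) alignment tracks minKN.
        align_ab = 8
        align_c = 8 if min_kn >= 8 else 4 if min_kn >= 4 else 2 if min_kn >= 2 else 1
    else:
        threadblock_tile, warp_tile = (32, 128, 32), (32, 32, 32)
        # LoRA-out: epilogue alignment stays 8, A/B alignment tracks minKN.
        align_ab = 8 if min_kn >= 8 else 4 if min_kn >= 4 else 2 if min_kn >= 2 else 1
        align_c = 8
    stages = 4 if min_kn >= 4 else 2
    return threadblock_tile, warp_tile, align_ab, align_c, stages

# e.g. the LoRA-out GEMM with minKN == 4 runs the 32x128x32 tile with A/B alignment 4 and 4 stages.
print(select_grouped_gemm_config(False, 4))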

View File

@ -0,0 +1,63 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/gemm_coord.h"
#include "tensorrt_llm/common/config.h"
#include <NvInferRuntime.h>
#include <cuda_runtime.h>
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
/**
* @brief CUDA Graph compatible wrapper for grouped GEMM operations.
*
* This function accepts GPU pointers directly without any workspace for parameters,
* making it fully compatible with CUDA Graph capture and replay.
*
* @param problemSizesPtr GPU pointer to array of cutlass::gemm::GemmCoord
* @param problemCount Number of GEMM problems
* @param ptrAGpu GPU pointer to array of A matrix pointers
* @param ptrBGpu GPU pointer to array of B matrix pointers
* @param ptrCGpu GPU pointer to array of C matrix pointers (can be nullptr)
* @param ptrDGpu GPU pointer to array of D matrix pointers
* @param ldaGpu GPU pointer to precomputed leading dimensions for A matrices
* @param ldbGpu GPU pointer to precomputed leading dimensions for B matrices
* @param ldcGpu GPU pointer to precomputed leading dimensions for C matrices
* @param lddGpu GPU pointer to precomputed leading dimensions for D matrices
* @param isLoraIn Whether this is for LoRA input transformation
* @param dataType Data type of the matrices
* @param minKN Minimum K*N value for kernel selection
* @param hostMaxProblemSizesPtr Host pointer to the maximum problem sizes
* @param stream CUDA stream
*/
void cudaGraphGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, void** ptrBGpu,
void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, bool isLoraIn,
nvinfer1::DataType dataType, int minKN, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream);
/**
* @brief CUDA Graph compatible wrapper for split-K grouped GEMM operations.
*
* Similar to cudaGraphGroupedGemm but uses split-K algorithm for better
* performance with certain problem sizes. No parameter workspace needed.
*/
void cudaGraphSplitKGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
bool isLoraIn, nvinfer1::DataType dataType, int splitKSlices, int minKN,
cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, int64_t* splitKOffsetsGpu, cudaStream_t stream);
} // namespace kernels
TRTLLM_NAMESPACE_END

View File

@ -299,7 +299,7 @@ int LoraImpl::run(int64_t numTokens, int64_t numReqs, void const* input, int32_t
+ (loraModuleIdx * numTokens * mMaxLowRank + handled_token_num * mMaxLowRank) * typeSize));
auto const N2 = mOutHiddenSizes[loraModuleIdx];
cutlass::gemm::GemmCoord problem_2(M, N2, N);
cutlass::gemm::GemmCoord problem_2(M, N2, N); // token_num, module_output_size, lora_rank
problem_sizes_2.push_back(problem_2);
ptrA_2.push_back(static_cast<void*>(static_cast<char*>(lowRankWorkSpace)
+ (loraModuleIdx * numTokens * mMaxLowRank + handled_token_num * mMaxLowRank) * typeSize));

View File

@ -0,0 +1,408 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "loraGroupGEMMParamFillRowReorderFusion.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include <algorithm>
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
namespace
{
template <class T>
__forceinline__ constexpr T (&as_singleton_array(T& obj))[1]
{
return reinterpret_cast<T(&)[1]>(obj);
}
enum ParamIndex
{
IN_SIZES_INDEX,
OUT_SIZES_INDEX,
LDA_INDEX,
LDD_INDEX,
LDB_PRIME_INDEX,
LDD_PRIME_INDEX,
A_PTRS_INDEX,
D_PTRS_INDEX,
D_PRIME_PTRS_INDEX,
SPLITK_OFFSETS_INDEX,
PARAM_COUNT,
};
int constexpr VECTOR_LOAD_WIDTH = 16;
} // namespace
/**
* Fused kernel for LoRA group GEMM parameter filling, row gathering, and zero filling.
* Must be launched with at least PARAM_COUNT blocks, and the total number of threads must be enough to reorder `input`
* and zero-fill the intermediate and output buffers. Specifically, n_threads >= max(input_bytes_to_reorder,
* intermediate_bytes_to_fill, output_bytes_to_fill) / VECTOR_LOAD_WIDTH.
*
* Template parameters:
* - BlockDim: Number of threads per block (1D, >= max_lora_count * module_count, >= 256, divisible by 32)
* - MODULE_COUNT: Number of modules per layer
*/
template <int BlockDim, int MODULE_COUNT>
__global__ void loraGroupGEMMParamFillRowReorderFusionKernel(
// Output parameters
int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs, int64_t* d_ptrs, int64_t* d_prime_ptrs, int64_t* lda,
int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime, int64_t* splitk_offsets, uint8_t* reordered_input,
// Input parameters
int32_t max_lora_count, int32_t max_lora_rank, int32_t sum_output_hidden_size, int32_t input_hidden_size,
int64_t dtype_element_size, int64_t batch_size, int64_t a_base, int64_t d_base, int64_t d_prime_base,
int32_t const* slot_counts, int32_t const* slot_ranks, int64_t const* slot_offsets, int32_t const* module_out_sizes,
int64_t const* module_out_prefix, int64_t const* b_ptrs, int64_t const* b_prime_ptrs, uint8_t const* input,
int64_t const* sorted_ids)
{
int const linearIdx = threadIdx.x;
int const blockLinearIdx = blockIdx.x + blockIdx.y * gridDim.x;
int constexpr THREADS_PER_BLOCK = BlockDim;
// Calculate lora_id and module_id from linearIdx
int const lora_id = linearIdx % max_lora_count;
int const module_id = linearIdx / max_lora_count;
using BlockLoad = cub::BlockLoad<int32_t, BlockDim, 1, cub::BLOCK_LOAD_DIRECT>;
using BlockStore = cub::BlockStore<int32_t, BlockDim, 1, cub::BLOCK_STORE_DIRECT>;
using BlockLoad64 = cub::BlockLoad<int64_t, BlockDim, 1, cub::BLOCK_LOAD_DIRECT>;
using BlockStore64 = cub::BlockStore<int64_t, BlockDim, 1, cub::BLOCK_STORE_DIRECT>;
using BlockScan = cub::BlockScan<int64_t, BlockDim, cub::BLOCK_SCAN_WARP_SCANS>;
using BlockStore3 = cub::BlockStore<int32_t, BlockDim, 3, cub::BLOCK_STORE_WARP_TRANSPOSE>;
__shared__ union
{
typename BlockStore3::TempStorage storage3;
typename BlockScan::TempStorage scan;
} large_shared;
__shared__ int32_t
row_count_to_gather; // rows at index >= row_count_to_gather belong to non-LoRA requests and are zero-filled
switch (blockLinearIdx)
{
case IN_SIZES_INDEX:
{
int32_t slot_count = slot_counts[lora_id];
int32_t rank = slot_ranks[lora_id];
int64_t b_ptr = b_ptrs[linearIdx % (max_lora_count * MODULE_COUNT)];
int32_t row[3] = {0};
if (b_ptr != 0)
{
row[0] = slot_count;
row[1] = rank;
row[2] = input_hidden_size;
}
BlockStore3(large_shared.storage3).Store(in_sizes, row, max_lora_count * MODULE_COUNT * 3);
}
break;
case OUT_SIZES_INDEX:
{
int32_t slot_count = slot_counts[lora_id];
int32_t output_hidden_size = module_out_sizes[module_id];
int32_t rank = slot_ranks[lora_id];
int64_t b_ptr = b_ptrs[linearIdx % (max_lora_count * MODULE_COUNT)];
int32_t row[3] = {0};
if (b_ptr != 0)
{
row[0] = slot_count;
row[1] = output_hidden_size;
row[2] = rank;
}
BlockStore3(large_shared.storage3).Store(out_sizes, row, max_lora_count * MODULE_COUNT * 3);
}
break;
case LDA_INDEX:
{
int64_t input_hidden_size_64[1] = {static_cast<int64_t>(input_hidden_size)};
BlockStore64().Store(lda, input_hidden_size_64, max_lora_count * MODULE_COUNT);
}
break;
case LDD_INDEX:
{
int64_t max_lora_rank_64[1] = {static_cast<int64_t>(max_lora_rank)};
BlockStore64().Store(ldd, max_lora_rank_64, max_lora_count * MODULE_COUNT);
}
break;
case LDB_PRIME_INDEX:
{
int64_t rank = slot_ranks[lora_id];
BlockStore64().Store(ldb_prime, as_singleton_array(rank), max_lora_count * MODULE_COUNT);
}
break;
case LDD_PRIME_INDEX:
{
int64_t sum_output_hidden_size_64[1] = {static_cast<int64_t>(sum_output_hidden_size)};
BlockStore64().Store(ldd_prime, sum_output_hidden_size_64, max_lora_count * MODULE_COUNT);
}
break;
case A_PTRS_INDEX:
{
int64_t slot_offset = 0;
BlockLoad64().Load(slot_offsets, as_singleton_array(slot_offset), max_lora_count + 1);
if (linearIdx == max_lora_count)
{
row_count_to_gather = static_cast<int32_t>(slot_offset);
}
slot_offset *= input_hidden_size;
slot_offset *= dtype_element_size;
slot_offset += a_base;
for (int i = 0; i < MODULE_COUNT; i++)
{
BlockStore64().Store(a_ptrs + i * max_lora_count, as_singleton_array(slot_offset), max_lora_count);
}
}
break;
case D_PTRS_INDEX:
{
int64_t slot_offset = 0;
BlockLoad64().Load(slot_offsets, as_singleton_array(slot_offset), max_lora_count + 1);
for (int i = 0; i < MODULE_COUNT; i++)
{
int64_t offset = slot_offset;
offset += i * batch_size;
offset *= max_lora_rank;
offset *= dtype_element_size;
offset += d_base;
BlockStore64().Store(d_ptrs + i * max_lora_count, as_singleton_array(offset), max_lora_count);
}
if (linearIdx == max_lora_count)
{
row_count_to_gather = static_cast<int32_t>(slot_offset);
}
}
break;
case D_PRIME_PTRS_INDEX:
{
int64_t slot_offset = 0;
BlockLoad64().Load(slot_offsets, as_singleton_array(slot_offset), max_lora_count + 1);
if (linearIdx == max_lora_count)
{
row_count_to_gather = static_cast<int32_t>(slot_offset);
}
slot_offset *= sum_output_hidden_size;
for (int i = 0; i < MODULE_COUNT; i++)
{
int64_t offset = slot_offset;
offset += module_out_prefix[i];
offset *= dtype_element_size;
offset += d_prime_base;
BlockStore64().Store(d_prime_ptrs + i * max_lora_count, as_singleton_array(offset), max_lora_count);
}
}
break;
case SPLITK_OFFSETS_INDEX:
{
int64_t slot_count = slot_counts[lora_id];
int64_t rank = slot_ranks[lora_id];
int64_t b_ptr = b_ptrs[linearIdx % (max_lora_count * MODULE_COUNT)];
int64_t splitk_offset = (b_ptr == 0) ? 0 : (slot_count * rank);
BlockScan(large_shared.scan).ExclusiveSum(splitk_offset, splitk_offset);
BlockStore64().Store(splitk_offsets, as_singleton_array(splitk_offset), max_lora_count * MODULE_COUNT);
}
break;
}
// Set row_count_to_gather for non-pointer blocks
switch (blockLinearIdx)
{
case A_PTRS_INDEX:
case D_PTRS_INDEX:
case D_PRIME_PTRS_INDEX: break;
default:
if (linearIdx == 0)
{
row_count_to_gather = static_cast<int32_t>(slot_offsets[max_lora_count]);
}
}
int constexpr ITEM_PER_THREAD = VECTOR_LOAD_WIDTH;
using BlockStoreRow = cub::BlockStore<uint8_t, BlockDim, ITEM_PER_THREAD, cub::BLOCK_STORE_VECTORIZE>;
{
// Write zero to intermediate buffer and output buffer
auto intermediate_cast = reinterpret_cast<uint8_t*>(d_base);
auto model_output_cast = reinterpret_cast<uint8_t*>(d_prime_base);
int intermediate_size = MODULE_COUNT * batch_size * max_lora_rank * dtype_element_size;
int output_size = batch_size * sum_output_hidden_size * dtype_element_size;
uint8_t all_zeroes[ITEM_PER_THREAD] = {0};
int const blockOffset = THREADS_PER_BLOCK * ITEM_PER_THREAD * blockLinearIdx;
BlockStoreRow().Store(intermediate_cast + blockOffset, all_zeroes, intermediate_size - blockOffset);
BlockStoreRow().Store(model_output_cast + blockOffset, all_zeroes, output_size - blockOffset);
}
__syncthreads();
// Row gather
if (blockIdx.y < batch_size)
{
using BlockLoadRow = cub::BlockLoad<uint8_t, BlockDim, ITEM_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE>;
auto const row_size = input_hidden_size * dtype_element_size;
auto output_cast = reinterpret_cast<uint8_t*>(reordered_input);
uint8_t tile[ITEM_PER_THREAD] = {0};
int constexpr x_stride = THREADS_PER_BLOCK * ITEM_PER_THREAD;
int const y_stride = row_size;
int tail = row_size - blockIdx.x * x_stride;
if (blockIdx.y < row_count_to_gather)
{
auto const input_cast = reinterpret_cast<uint8_t const*>(input);
auto const src_row = sorted_ids[blockIdx.y];
BlockLoadRow().Load(input_cast + blockIdx.x * x_stride + src_row * y_stride, tile, tail);
}
BlockStoreRow().Store(output_cast + blockIdx.x * x_stride + blockIdx.y * y_stride, tile, tail);
}
}
/**
* Launch function that instantiates the appropriate kernel based on module count.
*/
template <int BlockDim>
void launchKernelWithModuleCount(int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs, int64_t* d_ptrs,
int64_t* d_prime_ptrs, int64_t* lda, int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime, int64_t* splitk_offsets,
void* reordered_input, int32_t max_lora_count, int32_t max_lora_rank, int32_t sum_output_hidden_size,
int32_t input_hidden_size, int64_t dtype_element_size, int64_t batch_size, int64_t a_base, int64_t d_base,
int64_t d_prime_base, int32_t const* slot_counts, int32_t const* slot_ranks, int64_t const* slot_offsets,
int32_t const* module_out_sizes, int64_t const* module_out_prefix, int64_t const* b_ptrs,
int64_t const* b_prime_ptrs, void const* input, int64_t const* sorted_ids, int32_t module_count,
cudaStream_t stream)
{
int constexpr THREADS_PER_BLOCK = BlockDim;
// Grid dimensions for row gather
int constexpr ITEMS_PER_BLOCK = THREADS_PER_BLOCK * VECTOR_LOAD_WIDTH;
int const gridDimX = common::ceilDiv(input_hidden_size * dtype_element_size, ITEMS_PER_BLOCK);
int gridDimY = std::max(
static_cast<int>(common::ceilDiv(static_cast<int>(PARAM_COUNT), gridDimX)), static_cast<int>(batch_size));
// calculate threads needed for writing zeros to intermediate buffer and output buffer
int const itemsPerRow = ITEMS_PER_BLOCK * gridDimX;
gridDimY = std::max(gridDimY,
common::ceilDiv(static_cast<int>(module_count * batch_size * max_lora_rank * dtype_element_size), itemsPerRow));
gridDimY = std::max(gridDimY,
common::ceilDiv(static_cast<int>(batch_size * sum_output_hidden_size * dtype_element_size), itemsPerRow));
dim3 grid(gridDimX, gridDimY);
dim3 block(BlockDim);
auto* reordered_input_cast = reinterpret_cast<uint8_t*>(reordered_input);
auto const* input_cast = reinterpret_cast<uint8_t const*>(input);
// Dispatch based on module count
switch (module_count)
{
case 1:
loraGroupGEMMParamFillRowReorderFusionKernel<BlockDim, 1><<<grid, block, 0, stream>>>(in_sizes, out_sizes,
a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast,
max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size,
a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix,
b_ptrs, b_prime_ptrs, input_cast, sorted_ids);
break;
case 2:
loraGroupGEMMParamFillRowReorderFusionKernel<BlockDim, 2><<<grid, block, 0, stream>>>(in_sizes, out_sizes,
a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast,
max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size,
a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix,
b_ptrs, b_prime_ptrs, input_cast, sorted_ids);
break;
case 3:
loraGroupGEMMParamFillRowReorderFusionKernel<BlockDim, 3><<<grid, block, 0, stream>>>(in_sizes, out_sizes,
a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast,
max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size,
a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix,
b_ptrs, b_prime_ptrs, input_cast, sorted_ids);
break;
case 4:
loraGroupGEMMParamFillRowReorderFusionKernel<BlockDim, 4><<<grid, block, 0, stream>>>(in_sizes, out_sizes,
a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast,
max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size,
a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix,
b_ptrs, b_prime_ptrs, input_cast, sorted_ids);
break;
default: TLLM_CHECK_WITH_INFO(false, "Unsupported module_count: %d (max 4)", module_count);
}
}
void launchLoraGroupGEMMParamFillRowReorderFusion(int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs,
int64_t* d_ptrs, int64_t* d_prime_ptrs, int64_t* lda, int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime,
int64_t* splitk_offsets, void* reordered_input, int32_t max_lora_count, int32_t max_lora_rank,
int32_t sum_output_hidden_size, int32_t input_hidden_size, int64_t dtype_element_size, int64_t batch_size,
int64_t a_base, int64_t d_base, int64_t d_prime_base, int32_t const* slot_counts, int32_t const* slot_ranks,
int64_t const* slot_offsets, int32_t const* module_out_sizes, int64_t const* module_out_prefix,
int64_t const* b_ptrs, int64_t const* b_prime_ptrs, void const* input, int64_t const* sorted_ids,
int32_t module_count, nvinfer1::DataType dtype, cudaStream_t stream)
{
// Determine block dimensions (1D)
// Requirements: 1) >= max_lora_count * module_count 2) >= 256 3) divisible by 32
int constexpr MIN_THREADS = 256;
int constexpr WARP_SIZE = 32;
int const min_threads_needed = max_lora_count * module_count;
int const threads_per_block = std::max(MIN_THREADS, common::ceilDiv(min_threads_needed, WARP_SIZE) * WARP_SIZE);
if (threads_per_block == 256)
{
launchKernelWithModuleCount<256>(in_sizes, out_sizes, a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime,
ldd_prime, splitk_offsets, reordered_input, max_lora_count, max_lora_rank, sum_output_hidden_size,
input_hidden_size, dtype_element_size, batch_size, a_base, d_base, d_prime_base, slot_counts, slot_ranks,
slot_offsets, module_out_sizes, module_out_prefix, b_ptrs, b_prime_ptrs, input, sorted_ids, module_count,
stream);
}
else if (threads_per_block == 288)
{
launchKernelWithModuleCount<288>(in_sizes, out_sizes, a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime,
ldd_prime, splitk_offsets, reordered_input, max_lora_count, max_lora_rank, sum_output_hidden_size,
input_hidden_size, dtype_element_size, batch_size, a_base, d_base, d_prime_base, slot_counts, slot_ranks,
slot_offsets, module_out_sizes, module_out_prefix, b_ptrs, b_prime_ptrs, input, sorted_ids, module_count,
stream);
}
else if (threads_per_block == 320)
{
launchKernelWithModuleCount<320>(in_sizes, out_sizes, a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime,
ldd_prime, splitk_offsets, reordered_input, max_lora_count, max_lora_rank, sum_output_hidden_size,
input_hidden_size, dtype_element_size, batch_size, a_base, d_base, d_prime_base, slot_counts, slot_ranks,
slot_offsets, module_out_sizes, module_out_prefix, b_ptrs, b_prime_ptrs, input, sorted_ids, module_count,
stream);
}
else
{
TLLM_CHECK_WITH_INFO(false,
"Unsupported threads_per_block: %d (calculated from max_lora_count=%d * module_count=%d)",
threads_per_block, max_lora_count, module_count);
}
sync_check_cuda_error(stream);
}
} // namespace kernels
TRTLLM_NAMESPACE_END
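The launch configuration above packs three kinds of work into one grid: PARAM_COUNT parameter-filling blocks, zero-filling of the intermediate and output buffers, and the per-row gather of the reordered input. A Python sketch of the block/grid sizing rules stated in the comments (illustration only, not part of this commit; the shapes at the end are made up):

from math import ceil

PARAM_COUNT = 10        # parameter blocks filled by the kernel (see ParamIndex)
VECTOR_LOAD_WIDTH = 16  # bytes handled per thread per store

def launch_config(max_lora_count, module_count, max_lora_rank, input_hidden_size,
                  sum_output_hidden_size, batch_size, dtype_element_size):
    # Block: >= max_lora_count * module_count threads, >= 256, rounded up to a multiple of 32.
    threads = max(256, ceil(max_lora_count * module_count / 32) * 32)
    items_per_block = threads * VECTOR_LOAD_WIDTH
    # X covers one input row in bytes; Y must cover the parameter blocks, the rows to gather,
    # and the bytes to zero-fill in the intermediate and output buffers.
    grid_x = ceil(input_hidden_size * dtype_element_size / items_per_block)
    items_per_row = items_per_block * grid_x
    grid_y = max(ceil(PARAM_COUNT / grid_x), batch_size)
    grid_y = max(grid_y, ceil(module_count * batch_size * max_lora_rank * dtype_element_size / items_per_row))
    grid_y = max(grid_y, ceil(batch_size * sum_output_hidden_size * dtype_element_size / items_per_row))
    return threads, (grid_x, grid_y)

# Made-up shapes: 8 adapters, 3 modules, rank 64, hidden 4096, summed output dims 11264, batch 64, fp16.
print(launch_config(8, 3, 64, 4096, 11264, 64, 2))  # -> (256, (2, 176))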

View File

@ -0,0 +1,77 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/config.h"
#include <NvInferRuntime.h>
#include <cstdint>
#include <cuda_runtime.h>
TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
/**
* @brief Fused kernel that fills group GEMM parameters, performs row reordering, and zero-fills the intermediate and
* output buffers for CUDA-graph-compatible LoRA.
*
* @param in_sizes Output: [module_count, max_lora_count, 3] problem sizes for first GEMM
* @param out_sizes Output: [module_count, max_lora_count, 3] problem sizes for second GEMM
* @param a_ptrs Output: [module_count, max_lora_count] input matrix pointers
* @param d_ptrs Output: [module_count, max_lora_count] intermediate output pointers
* @param d_prime_ptrs Output: [module_count, max_lora_count] final output pointers
* @param lda Output: [module_count, max_lora_count] leading dimensions for A matrices
* @param ldd Output: [module_count, max_lora_count] leading dimensions for D matrices
* @param ldb_prime Output: [module_count, max_lora_count] leading dimensions for B' matrices
* @param ldd_prime Output: [module_count, max_lora_count] leading dimensions for D' matrices
* @param splitk_offsets Output: [module_count, max_lora_count] split-K work offsets
* @param reordered_input Output: [batch_size, input_hidden_size] reordered input matrix
* @param max_lora_count Maximum number of LoRA adapters
* @param max_lora_rank Maximum rank of LoRA adapters
* @param sum_output_hidden_size Sum of output hidden sizes across modules
* @param input_hidden_size Input hidden dimension
* @param dtype_element_size Size of data type in bytes
* @param batch_size Batch size
* @param a_base Base pointer for input matrices
* @param d_base Base pointer for intermediate output matrices
* @param d_prime_base Base pointer for final output matrices
* @param slot_counts Input: [max_lora_count] number of requests per LoRA slot
* @param slot_ranks Input: [max_lora_count] rank of each LoRA adapter
* @param slot_offsets Input: [max_lora_count + 1] cumulative offsets (last element = total rows)
* @param module_out_sizes Input: [module_count] output hidden size per module
* @param module_out_prefix Input: [module_count] prefix sum of output hidden sizes
* @param b_ptrs Input: [module_count, max_lora_count] weight pointers for first GEMM
* @param b_prime_ptrs Input: [module_count, max_lora_count] weight pointers for second GEMM
* @param input Input: [batch_size, input_hidden_size] original input matrix
* @param sorted_ids Input: [batch_size] indices for row reordering
* @param module_count Number of modules per layer
* @param dtype Data type of matrices
* @param stream CUDA stream
*/
void launchLoraGroupGEMMParamFillRowReorderFusion(int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs,
int64_t* d_ptrs, int64_t* d_prime_ptrs, int64_t* lda, int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime,
int64_t* splitk_offsets, void* reordered_input, int32_t max_lora_count, int32_t max_lora_rank,
int32_t sum_output_hidden_size, int32_t input_hidden_size, int64_t dtype_element_size, int64_t batch_size,
int64_t a_base, int64_t d_base, int64_t d_prime_base, int32_t const* slot_counts, int32_t const* slot_ranks,
int64_t const* slot_offsets, int32_t const* module_out_sizes, int64_t const* module_out_prefix,
int64_t const* b_ptrs, int64_t const* b_prime_ptrs, void const* input, int64_t const* sorted_ids,
int32_t module_count, nvinfer1::DataType dtype, cudaStream_t stream);
} // namespace kernels
TRTLLM_NAMESPACE_END
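Because every argument above is a fixed-shape device buffer, a caller can allocate the parameter tensors once and reuse them across CUDA-graph replays. A torch sketch of that pre-allocation, following the shapes and dtypes in the doc comment (illustration only; the helper name and defaults are made up, and a CUDA device is assumed):

import torch

def make_lora_param_buffers(module_count, max_lora_count, batch_size, input_hidden_size,
                            dtype=torch.float16, device="cuda"):
    # Buffers written by launchLoraGroupGEMMParamFillRowReorderFusion; keeping them at fixed
    # sizes is what allows the fill kernel and the grouped GEMMs to be graph-captured.
    i32 = dict(dtype=torch.int32, device=device)
    i64 = dict(dtype=torch.int64, device=device)
    bufs = {
        "in_sizes": torch.zeros(module_count, max_lora_count, 3, **i32),
        "out_sizes": torch.zeros(module_count, max_lora_count, 3, **i32),
        "reordered_input": torch.zeros(batch_size, input_hidden_size, dtype=dtype, device=device),
    }
    for name in ("a_ptrs", "d_ptrs", "d_prime_ptrs", "lda", "ldd",
                 "ldb_prime", "ldd_prime", "splitk_offsets"):
        bufs[name] = torch.zeros(module_count, max_lora_count, **i64)
    return bufs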

View File

@ -34,6 +34,7 @@
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/tuple.h>
#include <nanobind/stl/unique_ptr.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>
#include <nanobind/trampoline.h>
#include <torch/extension.h>
@ -563,6 +564,11 @@ void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m)
nb::arg("config"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"),
nb::call_guard<nb::gil_scoped_release>())
.def("is_task_cached", &tb::PeftCacheManager::isTaskCached, nb::arg("taskId"),
nb::call_guard<nb::gil_scoped_release>())
.def("is_task_cached_device", &tb::PeftCacheManager::isTaskCachedDevice, nb::arg("taskId"),
nb::call_guard<nb::gil_scoped_release>())
.def("ensure_batch_map_task_id", &tb::PeftCacheManager::ensureBatchMapTaskId, nb::arg("context_requests"),
nb::arg("generation_requests"), nb::arg("reset_gpu_cache") = false,
nb::call_guard<nb::gil_scoped_release>());
nb::class_<tb::NoOpPeftCacheManager, tb::BasePeftCacheManager>(m, "NoOpPeftCacheManager")

View File

@ -558,6 +558,11 @@ void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"),
py::call_guard<py::gil_scoped_release>())
.def("is_task_cached", &tb::PeftCacheManager::isTaskCached, py::arg("taskId"),
py::call_guard<py::gil_scoped_release>())
.def("is_task_cached_device", &tb::PeftCacheManager::isTaskCachedDevice, py::arg("taskId"),
py::call_guard<py::gil_scoped_release>())
.def("ensure_batch_map_task_id", &tb::PeftCacheManager::ensureBatchMapTaskId, py::arg("context_requests"),
py::arg("generation_requests"), py::arg("reset_gpu_cache") = false,
py::call_guard<py::gil_scoped_release>());
py::classh<tb::NoOpPeftCacheManager, tb::BasePeftCacheManager>(m, "NoOpPeftCacheManager")

View File

@ -42,7 +42,7 @@ public:
using LoraReqTensors = std::tuple<LoraWeightsTensorPtr, LoraConfigTensorPtr>;
using TaskIdType = std::int64_t;
using PeftValues = std::vector<runtime::LoraCache::TaskLayerModuleConfig>;
using PeftTable = std::map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
using PeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
explicit LoraManager() {}

View File

@ -18,13 +18,16 @@
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/cuda_graph_grouped_gemm.h"
#include "tensorrt_llm/kernels/lora/lora.h"
#include "tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.h"
#include "tensorrt_llm/kernels/selectiveScan/selectiveScan.h"
#include "tensorrt_llm/thop/thUtils.h"
namespace th = torch;
namespace tk = tensorrt_llm::kernels;
using tensorrt_llm::common::fmtstr;
namespace tc = tensorrt_llm::common;
TRTLLM_NAMESPACE_BEGIN
@ -174,6 +177,171 @@ std::vector<th::Tensor> lora_grouped_gemm(th::Tensor const& input, th::Tensor co
return output_torch;
}
void lora_grouped_gemm_cuda_graph(th::Tensor const& lora_in_sizes, // [layer_module_num, max_lora_size, 3]
th::Tensor const& lora_out_sizes, // [layer_module_num, max_lora_size, 3]
th::Tensor const& a_offsets, // [layer_module_num, max_lora_size]
th::Tensor const& b_ptrs, // [layer_module_num, max_lora_size]
th::Tensor const& d_offsets, // [layer_module_num, max_lora_size]
th::Tensor const& b_prime_ptrs, // [layer_module_num, max_lora_size]
th::Tensor const& d_prime_offsets, // [layer_module_num, max_lora_size]
int64_t problem_count,
th::Tensor const& lda, // Leading dimensions for A matrices [layer_module_num, max_lora_size]
th::Tensor const& ldb, // Leading dimensions for B matrices [layer_module_num, max_lora_size]
th::Tensor const& ldd, // Leading dimensions for C matrices [layer_module_num, max_lora_size] (unused)
th::Tensor const& ldb_prime, th::Tensor const& ldd_prime, th::Tensor const& host_max_in_sizes,
th::Tensor const& host_max_out_sizes, th::Tensor const& splitk_offsets, c10::ScalarType dtype, int64_t minKN,
int64_t splitKSlices = 16)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto stream = at::cuda::getCurrentCUDAStream().stream();
sync_check_cuda_error(stream);
auto* a_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(a_offsets.data_ptr()));
auto* d_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(d_offsets.data_ptr()));
auto* a_prime_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(d_offsets.data_ptr())); // output of the first GEMM is the A operand of the second
auto* d_prime_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(d_prime_offsets.data_ptr()));
auto* problem_sizes_1_ptr = reinterpret_cast<cutlass::gemm::GemmCoord*>(lora_in_sizes.data_ptr());
auto* problem_sizes_2_ptr = reinterpret_cast<cutlass::gemm::GemmCoord*>(lora_out_sizes.data_ptr());
auto* host_max_in_sizes_ptr = reinterpret_cast<cutlass::gemm::GemmCoord*>(host_max_in_sizes.data_ptr());
auto* host_max_out_sizes_ptr = reinterpret_cast<cutlass::gemm::GemmCoord*>(host_max_out_sizes.data_ptr());
auto* b_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(b_ptrs.data_ptr()));
auto* b_prime_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(b_prime_ptrs.data_ptr()));
auto* lda_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(lda.data_ptr()));
auto* ldb_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(ldb.data_ptr()));
auto* ldd_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(ldd.data_ptr()));
auto* ldb_prime_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(ldb_prime.data_ptr()));
auto* ldd_prime_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(ldd_prime.data_ptr()));
auto* splitk_offsets_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(splitk_offsets.data_ptr()));
// Get data type
nvinfer1::DataType loraRuntimeDataType;
switch (dtype)
{
case torch::kFloat16: loraRuntimeDataType = nvinfer1::DataType::kHALF; break;
case torch::kBFloat16: loraRuntimeDataType = nvinfer1::DataType::kBF16; break;
default: TORCH_CHECK(false, "Invalid dtype, only supports float16, bfloat16, got ", c10::toString(dtype));
}
int const minKnInt = std::max(1, static_cast<int>(minKN));
// Call CUDA Graph compatible grouped GEMM for lora_in (split-K)
if (problem_count > 0)
{
TLLM_LOG_TRACE("Start Grouped GEMM for LoRA in.");
tk::cudaGraphSplitKGroupedGemm(problem_sizes_1_ptr, problem_count, a_ptrs_gpu, b_ptrs_gpu,
d_ptrs_gpu, // ptrC (no bias)
d_ptrs_gpu, lda_gpu, ldb_gpu, ldd_gpu, ldd_gpu, // Precomputed leading dimensions
true, // isLoraIn
loraRuntimeDataType,
static_cast<int>(splitKSlices), // splitKSlices
minKnInt, // minKN
host_max_in_sizes_ptr, splitk_offsets_gpu, stream);
sync_check_cuda_error(stream);
// Call CUDA Graph compatible grouped GEMM for lora_out
TLLM_LOG_TRACE("Start Grouped GEMM for LoRA out.");
tk::cudaGraphGroupedGemm(problem_sizes_2_ptr, problem_count, a_prime_ptrs_gpu, b_prime_ptrs_gpu,
d_prime_ptrs_gpu, // ptrC (no bias)
d_prime_ptrs_gpu, ldd_gpu, ldb_prime_gpu, ldd_prime_gpu, ldd_prime_gpu, // Precomputed leading dimensions
false, // isLoraIn
loraRuntimeDataType,
minKnInt, // minKN
host_max_out_sizes_ptr, stream);
sync_check_cuda_error(stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void lora_group_gemm_param_fill_row_reorder_fusion(th::Tensor const& in_sizes, // [module_count, max_lora_count, 3]
th::Tensor const& out_sizes, // [module_count, max_lora_count, 3]
th::Tensor const& a_ptrs, // [module_count, max_lora_count]
th::Tensor const& d_ptrs, // [module_count, max_lora_count]
th::Tensor const& d_prime_ptrs, // [module_count, max_lora_count]
th::Tensor const& lda, // [module_count, max_lora_count]
th::Tensor const& ldd, // [module_count, max_lora_count]
th::Tensor const& ldb_prime, // [module_count, max_lora_count]
th::Tensor const& ldd_prime, // [module_count, max_lora_count]
th::Tensor const& splitk_offsets, // [module_count, max_lora_count]
th::Tensor const& reordered_input, // [batch_size, input_hidden_size]
int64_t max_lora_count, int64_t max_lora_rank, int64_t sum_output_hidden_size, int64_t input_hidden_size,
int64_t batch_size,
th::Tensor const& slot_counts, // [max_lora_count]
th::Tensor const& slot_ranks, // [max_lora_count]
th::Tensor const& slot_offsets, // [max_lora_count + 1]
th::Tensor const& module_out_sizes, // [module_count]
th::Tensor const& module_out_prefix, // [module_count]
th::Tensor const& b_ptrs, // [module_count, max_lora_count]
th::Tensor const& b_prime_ptrs, // [module_count, max_lora_count]
th::Tensor const& input, // [batch_size, input_hidden_size]
th::Tensor const& sorted_ids, // [batch_size]
th::Tensor const& intermediate_buffer, // [batch_size, max_lora_rank]
th::Tensor const& output_buffer, // [batch_size, sum_output_hidden_size]
c10::ScalarType dtype)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto stream = at::cuda::getCurrentCUDAStream().stream();
// Validate inputs
TORCH_CHECK(in_sizes.device().is_cuda(), "in_sizes must be a CUDA tensor");
TORCH_CHECK(out_sizes.device().is_cuda(), "out_sizes must be a CUDA tensor");
TORCH_CHECK(reordered_input.device().is_cuda(), "reordered_input must be a CUDA tensor");
TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
// Get module count from tensor shapes
int32_t const module_count = static_cast<int32_t>(in_sizes.size(0));
// Get data type info
nvinfer1::DataType loraRuntimeDataType;
switch (dtype)
{
case torch::kFloat16: loraRuntimeDataType = nvinfer1::DataType::kHALF; break;
case torch::kBFloat16: loraRuntimeDataType = nvinfer1::DataType::kBF16; break;
default: TORCH_CHECK(false, "Invalid dtype, only supports float16, bfloat16, got ", c10::toString(dtype));
}
int64_t const dtype_element_size = input.element_size();
int64_t const a_base_ptr = reinterpret_cast<int64_t>(reordered_input.data_ptr());
int64_t const d_base_ptr = reinterpret_cast<int64_t>(intermediate_buffer.data_ptr());
int64_t const d_prime_base_ptr = reinterpret_cast<int64_t>(output_buffer.data_ptr());
tk::launchLoraGroupGEMMParamFillRowReorderFusion(reinterpret_cast<int32_t*>(const_cast<void*>(in_sizes.data_ptr())),
reinterpret_cast<int32_t*>(const_cast<void*>(out_sizes.data_ptr())),
reinterpret_cast<int64_t*>(const_cast<void*>(a_ptrs.data_ptr())),
reinterpret_cast<int64_t*>(const_cast<void*>(d_ptrs.data_ptr())),
reinterpret_cast<int64_t*>(const_cast<void*>(d_prime_ptrs.data_ptr())),
reinterpret_cast<int64_t*>(const_cast<void*>(lda.data_ptr())),
reinterpret_cast<int64_t*>(const_cast<void*>(ldd.data_ptr())),
reinterpret_cast<int64_t*>(const_cast<void*>(ldb_prime.data_ptr())),
reinterpret_cast<int64_t*>(const_cast<void*>(ldd_prime.data_ptr())),
reinterpret_cast<int64_t*>(const_cast<void*>(splitk_offsets.data_ptr())),
const_cast<void*>(reordered_input.data_ptr()), static_cast<int32_t>(max_lora_count),
static_cast<int32_t>(max_lora_rank), static_cast<int32_t>(sum_output_hidden_size),
static_cast<int32_t>(input_hidden_size), dtype_element_size, batch_size, a_base_ptr, d_base_ptr,
d_prime_base_ptr, reinterpret_cast<int32_t const*>(slot_counts.data_ptr()),
reinterpret_cast<int32_t const*>(slot_ranks.data_ptr()),
reinterpret_cast<int64_t const*>(slot_offsets.data_ptr()),
reinterpret_cast<int32_t const*>(module_out_sizes.data_ptr()),
reinterpret_cast<int64_t const*>(module_out_prefix.data_ptr()),
reinterpret_cast<int64_t const*>(b_ptrs.data_ptr()), reinterpret_cast<int64_t const*>(b_prime_ptrs.data_ptr()),
input.data_ptr(), reinterpret_cast<int64_t const*>(sorted_ids.data_ptr()), module_count, loraRuntimeDataType,
stream);
sync_check_cuda_error(stream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
} // namespace torch_ext
TRTLLM_NAMESPACE_END
@ -192,9 +360,65 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
"int max_low_rank, "
"int weight_index, "
"bool isRemoveInputPadding) -> Tensor[]");
m.def(
"lora_grouped_gemm_cuda_graph("
"Tensor lora_in_sizes, "
"Tensor lora_out_sizes, "
"Tensor a_offsets, "
"Tensor b_ptrs, "
"Tensor d_offsets, "
"Tensor b_prime_ptrs, "
"Tensor d_prime_offsets, "
"int problem_count, "
"Tensor lda, "
"Tensor ldb, "
"Tensor ldd, "
"Tensor ldb_prime, "
"Tensor ldd_prime, "
"Tensor host_max_in_sizes, "
"Tensor host_max_out_sizes, "
"Tensor splitk_offsets, "
"ScalarType dtype, "
"int minKN, "
"int splitKSlices=16) -> ()");
m.def(
"lora_group_gemm_param_fill_row_reorder_fusion("
"Tensor in_sizes, "
"Tensor out_sizes, "
"Tensor a_ptrs, "
"Tensor d_ptrs, "
"Tensor d_prime_ptrs, "
"Tensor lda, "
"Tensor ldd, "
"Tensor ldb_prime, "
"Tensor ldd_prime, "
"Tensor splitk_offsets, "
"Tensor reordered_input, "
"int max_lora_count, "
"int max_lora_rank, "
"int sum_output_hidden_size, "
"int input_hidden_size, "
"int batch_size, "
"Tensor slot_counts, "
"Tensor slot_ranks, "
"Tensor slot_offsets, "
"Tensor module_out_sizes, "
"Tensor module_out_prefix, "
"Tensor b_ptrs, "
"Tensor b_prime_ptrs, "
"Tensor input, "
"Tensor sorted_ids, "
"Tensor intermediate_buffer, "
"Tensor output_buffer, "
"ScalarType dtype) -> ()");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("lora_grouped_gemm", &tensorrt_llm::torch_ext::lora_grouped_gemm);
m.impl("lora_grouped_gemm_cuda_graph", &tensorrt_llm::torch_ext::lora_grouped_gemm_cuda_graph);
m.impl("lora_group_gemm_param_fill_row_reorder_fusion",
&tensorrt_llm::torch_ext::lora_group_gemm_param_fill_row_reorder_fusion);
}

View File

@ -317,9 +317,6 @@ class Attention(nn.Module):
self.fused_qkv_lora = LoraLayer([LoraModuleType.ATTENTION_QKV],
[self.q_size + 2 * self.kv_size])
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])
# Whether to fuse RoPE into the attention OP.
# If true, RoPE will be applied in self.attn.forward.
# If false, RoPE will be applied in self.apply_rope.

View File

@ -0,0 +1,130 @@
"""
AdapterSlotManager for managing the slots that store LoRA task IDs.
"""
from collections import OrderedDict
from typing import List, Optional
from ...pyexecutor.resource_manager import PeftCacheManager
from ...pyexecutor.scheduler import RequestList
class AdapterSlotManager:
"""
Manages max_num_adapters ordered slots for distinct task_ids to enable CUDA Graph compatibility.
Each slot can hold one adapter (task_id) and maintains a consistent ordering that allows
the CUDA Graph to be captured with fixed buffer layouts.
"""
def __init__(self, max_num_adapters: int):
"""
Initialize the AdapterSlotManager.
Args:
max_num_adapters: Maximum number of LoRA adapters that can be active simultaneously
"""
self.max_num_adapters = max_num_adapters
# Slot management
self.slot2task: List[Optional[int]] = [None] * max_num_adapters
self.task2slot: OrderedDict[int, int] = OrderedDict() # represent LRU order
# State tracking
self.slots_changed = False
def find_free_slot(self) -> int:
"""
Find the first free slot and return its slot_id. Assumes a free slot exists (raises ValueError otherwise).
"""
return self.slot2task.index(None)
def remove_task(self, task_id: int) -> Optional[int]:
"""
Remove a task_id from its slot. Return its slot_id if it was present, otherwise return None.
"""
slot_id = self.task2slot.pop(task_id, None)
if slot_id is not None:
self.slots_changed = True
self.slot2task[slot_id] = None
return slot_id
def get_or_assign_task(self, task_id: int) -> tuple[int, Optional[int]]:
"""
Assign a task_id to a slot and do LRU eviction if necessary.
If already in any slot, update LRU order.
Return: pair (assigned slot_id, evicted task_id)
"""
evicted_task = None
if task_id in self.task2slot:
self.task2slot.move_to_end(task_id)
else:
self.slots_changed = True
if len(self.task2slot) < self.max_num_adapters:
free_slot = self.find_free_slot()
self.slot2task[free_slot] = task_id
self.task2slot[task_id] = free_slot
else:
# evict lru
evicted_task, evicted_slot = self.task2slot.popitem(last=False)
self.slot2task[evicted_slot] = task_id
self.task2slot[task_id] = evicted_slot
return self.task2slot[task_id], evicted_task
def remove_evicted_slots_in_cpp(self, peft_cache_manager: PeftCacheManager):
"""
Validate slots by removing tasks that are not cached in PeftCacheManager.
"""
for task_id in self.slot2task:
if task_id is not None:
if not peft_cache_manager.is_task_cached_device(task_id):
self.remove_task(task_id)
def update_slots(
self, requests: RequestList, peft_cache_manager: PeftCacheManager
) -> list[int]:
"""
Assign slots to all requests in a scheduled batch and return their slot ids.
Args:
requests: The scheduled requests for the current batch
peft_cache_manager: PeftCacheManager used to drop tasks that were evicted from the device cache
Returns:
List of slot_ids, one per request, with slot_id=max_num_adapters for base model requests
"""
# remove tasks that were evicted from the device cache by the C++ PeftCacheManager
self.remove_evicted_slots_in_cpp(peft_cache_manager)
# check if total number of unique tasks in the requests is not larger than max_num_adapters
tasks = [request.lora_task_id for request in requests]
unique_tasks = {t for t in tasks if t is not None}
assert len(unique_tasks) <= self.max_num_adapters, (
f"Batch with more unique LoRA adapters ({len(unique_tasks)}) than max_num_adapters={self.max_num_adapters} "
"is not supported"
)
# assign slots to tasks
for i, task in enumerate(tasks):
if task is None:
tasks[i] = self.max_num_adapters
else:
tasks[i], evicted_task = self.get_or_assign_task(task)
return tasks
def get_slot_to_task_mapping(self) -> tuple[Optional[int], ...]:
"""
Get current slot to task mapping.
Returns:
Tuple mapping slot_id to task_id (or None if slot is empty)
"""
return tuple(self.slot2task)
def has_slots_changed(self) -> bool:
"""Check if slot assignments have changed since last check."""
return self.slots_changed
def reset_slots_changed(self):
"""Reset the slots_changed flag."""
self.slots_changed = False
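A minimal usage sketch of the LRU slot assignment above, illustrative only. It assumes the tensorrt_llm package with this new module is importable (the import path follows the sibling lora modules referenced elsewhere in this change); the task IDs are made up.

from tensorrt_llm._torch.peft.lora.adapter_slot_manager import AdapterSlotManager

mgr = AdapterSlotManager(max_num_adapters=2)
slot_a, _ = mgr.get_or_assign_task(101)        # slot 0, nothing evicted
slot_b, _ = mgr.get_or_assign_task(202)        # slot 1, nothing evicted
mgr.get_or_assign_task(101)                    # refreshes LRU order; 202 is now least recent
slot_c, evicted = mgr.get_or_assign_task(303)  # evicts 202 and reuses its slot
assert evicted == 202 and slot_c == slot_b
assert mgr.get_slot_to_task_mapping() == (101, 303)
assert mgr.has_slots_changed()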

View File

@ -0,0 +1,175 @@
from typing import Dict, Optional
import torch
from ...._utils import nvtx_range
from ....logger import logger
from ....lora_manager import LoraManager, LoraModelConfig
from ...attention_backend.interface import AttentionMetadata
from ...pyexecutor.resource_manager import PeftCacheManager
from ...pyexecutor.scheduler import ScheduledRequests
from .adapter_slot_manager import AdapterSlotManager
from .cuda_graph_lora_params import CudaGraphLoraParams
from .layer import LoraLayer
class CudaGraphLoraManager:
"""
Manager that coordinates adapter slots and CUDA Graph compatible LoRA parameters.
This class bridges the gap between the current LoRA implementation and the new
CUDA Graph compatible design by managing adapter slots and preparing persistent
device tensors for group GEMM operations.
"""
def __init__(
self,
max_lora_size: int,
max_batch_size: int,
max_lora_rank: int,
model: torch.nn.Module,
lora_model_config: Optional[LoraModelConfig],
device: str = "cuda",
):
"""
Initialize the CUDA Graph LoRA manager.
Args:
max_lora_size: Maximum number of LoRA adapters that can be active
max_batch_size: Maximum batch size for CUDA graphs
max_lora_rank: Maximum LoRA rank across all layers
model: Model to get layerwise LoRA info
lora_model_config: LoRA model configuration
device: Device to allocate tensors on
"""
self.max_lora_size = max_lora_size
self.max_batch_size = max_batch_size
self.max_lora_rank = max_lora_rank
self.device = device
self.adapter_slot_manager = AdapterSlotManager(max_lora_size)
self.lora_model_config = lora_model_config
lora_target_modules = lora_model_config.lora_target_modules
self.target_modules_ids: Optional[tuple[int, ...]] = (
tuple(map(LoraManager.LORA_MODULE_IDS.__getitem__, lora_target_modules))
if bool(lora_target_modules)
else None
)
if not self.target_modules_ids:
logger.debug(
"No LoRA target modules provided in LoRA config, using all modules in PyTorch Module!"
)
# Single CudaGraphLoraParams instance for all batch sizes. Its shape will be allocated to accommodate the max
# batch size.
self.cuda_graph_lora_params: Optional[CudaGraphLoraParams] = None
self.layer_info: (
Dict[CudaGraphLoraParams.LoraLayerKey, CudaGraphLoraParams.LoraLayerInfo] | None
) = None
# Initialize layer_info from model.
self._initialize_from_model(model)
self.cuda_graph_lora_params = CudaGraphLoraParams(
max_batch_size=self.max_batch_size,
max_lora_size=self.max_lora_size,
max_rank=self.max_lora_rank,
layer_info=self.layer_info,
device=self.device,
)
def _initialize_from_model(self, model: torch.nn.Module):
"""
Initialize LoRALayerInfo from model.
"""
self.layer_info = dict()
def get_layer_idx(
model: torch.nn.Module, lora_module: LoraLayer, lora_module_name: str
) -> Optional[int]:
"""Find the layer index of the given LoRA module in the model."""
module = lora_module
name = lora_module_name
while module is not None and (
(not hasattr(module, "layer_idx")) or module.layer_idx is None
):
name = name.rsplit(".", 1)
name = name[0] if len(name) > 1 else None
if name is not None:
module = model.get_submodule(name)
else:
module = None
if hasattr(module, "layer_idx") and module.layer_idx is not None:
return module.layer_idx
return None
# Ignore LoRA layers without at least one of the target modules.
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
layer_idx = get_layer_idx(model, module, name)
# if target_modules_ids is None, by default enable all modules
if self.target_modules_ids and not any(
module_id in self.target_modules_ids for module_id in module.lora_module_types
):
logger.debug(f"Layer {name} does not have any of the target modules, skipping")
continue
layer_key = CudaGraphLoraParams.LoraLayerKey(
layer_idx=layer_idx, module_ids=tuple(module.lora_module_types)
)
assert layer_key not in self.layer_info, f"Layer {layer_key} already exists"
self.layer_info[layer_key] = CudaGraphLoraParams.LoraLayerInfo(
module_num=len(module.lora_module_types),
output_sizes=module.output_hidden_sizes,
)
@nvtx_range("prepare_cuda_graph_lora_params")
def prepare_cuda_graph_lora_params(
self,
scheduled_requests: "ScheduledRequests",
attn_metadata: "AttentionMetadata",
peft_cache_manager: PeftCacheManager,
) -> Optional[Dict]:
"""
Prepare LoRA parameters from scheduled requests.
Args:
scheduled_requests: The scheduled requests for the current batch
attn_metadata: Attention metadata containing batch information
peft_cache_manager: PeftCacheManager providing the per-batch PEFT table mapping task_id to layer-module-configs
Returns:
LoRA parameters dictionary.
"""
assert len(scheduled_requests.context_requests) == 0, (
"Context requests are not supported with LoRA CUDA Graph path. "
f"Have {len(scheduled_requests.context_requests)} context requests"
)
request_list = scheduled_requests.generation_requests
peft_table = peft_cache_manager.get_and_reset_batch_peft_table()
# Update slot manager and get slot assignments for this batch
request_slot_ids = self.adapter_slot_manager.update_slots(request_list, peft_cache_manager)
cuda_graph_lora_params = self.cuda_graph_lora_params
cuda_graph_lora_params.update_sorted_indices(request_slot_ids)
# Get current slot to task mapping
slot2task = self.adapter_slot_manager.get_slot_to_task_mapping()
# Update weight pointers if slot assignments changed
if self.adapter_slot_manager.has_slots_changed():
cuda_graph_lora_params.update_weight_pointers(peft_table, slot2task)
self.adapter_slot_manager.reset_slots_changed()
# Update GEMM sizes and prefix sums using batch
cuda_graph_lora_params.update_slots_params(batch_slot_ids=request_slot_ids)
lora_params = {
"cuda_graph_params": cuda_graph_lora_params,
"host_request_types": attn_metadata.host_request_types,
"prompt_lens_cpu": attn_metadata.prompt_lens_cpu,
"num_seqs": attn_metadata.num_seqs,
"use_cuda_graph_mode": True, # Flag to indicate new mode
}
return lora_params
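The layer-index resolution inside _initialize_from_model walks up the dotted module path until an ancestor exposes a non-None layer_idx. Below is a standalone toy sketch of that traversal; _ToyBlock, _ToyModel and _resolve_layer_idx are hypothetical names, and only torch is required.

from torch import nn

class _ToyBlock(nn.Module):
    def __init__(self, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.proj = nn.Linear(8, 8)  # stands in for a LoraLayer leaf without layer_idx

class _ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([_ToyBlock(i) for i in range(2)])

def _resolve_layer_idx(model: nn.Module, name: str):
    """Walk up the dotted module path until an ancestor defines layer_idx."""
    module = model.get_submodule(name)
    while module is not None and getattr(module, "layer_idx", None) is None:
        name = name.rsplit(".", 1)[0] if "." in name else None
        module = model.get_submodule(name) if name is not None else None
    return None if module is None else module.layer_idx

assert _resolve_layer_idx(_ToyModel(), "layers.1.proj") == 1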

View File

@ -0,0 +1,341 @@
from collections import namedtuple
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
@dataclass
class LoraLayerParams:
"""
Parameters for a single LoRA layer.
Device tensors are persistent and can be updated outside of graph replay; the h_* tensors are their pinned-host staging counterparts.
"""
# Weight pointers
# Shape: [layer_module_num, max_lora_size]
d_b_ptrs: torch.Tensor # Lora_in weight pointers
d_b_prime_ptrs: torch.Tensor # Lora_out weight pointers
h_b_ptrs: torch.Tensor # Lora_in weight pointers in host
h_b_prime_ptrs: torch.Tensor # Lora_out weight pointers in host
d_output_sizes: torch.Tensor
d_output_sizes_offset: torch.Tensor
h_output_sizes: torch.Tensor
h_output_sizes_offset: torch.Tensor
class CudaGraphLoraParams:
"""
CUDA Graph compatible LoRA parameters for all layers and batch management.
This structure maintains persistent device tensors that can be updated outside
of CUDA Graph replay to support different LoRA combinations per batch.
"""
LoraLayerKey = namedtuple("LoraLayerKey", ["layer_idx", "module_ids"])
PTR_DTYPE = torch.int64
LD_DTYPE = torch.int64
SIZES_DTYPE = torch.int32
@dataclass
class LoraLayerInfo:
module_num: int = 0
output_sizes: List[int] | torch.Tensor | None = None
input_hidden_size: int = 0
def is_enabled(self) -> bool:
return self.input_hidden_size > 0
def __init__(
self,
max_batch_size: int,
max_lora_size: int,
max_rank: int,
layer_info: Dict[LoraLayerKey, LoraLayerInfo],
device: str = "cuda",
):
"""
Initialize CUDA Graph compatible LoRA parameters.
Args:
max_batch_size: Maximum batch size for this graph
max_lora_size: Maximum number of LoRA adapters
max_rank: Maximum rank for all layers
layer_info: Layer information for each layer
device: Device to allocate tensors on
"""
self.max_batch_size = max_batch_size
self.max_lora_size = max_lora_size
self.max_rank = max_rank
self.layer_info = layer_info
self.layer_module2key = self._calculate_layer_module2key()
self.device = device
self.layer_params: Dict[self.LoraLayerKey, LoraLayerParams] = dict()
# sorted indices using slot ids as keys, mainly to group requests with the same slot id together in a batch
self.sorted_ids = torch.zeros(max_batch_size, dtype=torch.int64, device=device)
self.sorted_ids_host = torch.zeros_like(self.sorted_ids, device="cpu", pin_memory=True)
# persistent values for gen-only batch with cuda graph
self.persistent_sorted_ids = self.sorted_ids
self.slot_ids = torch.zeros(max_batch_size, dtype=torch.int64, device=device)
self.slot_counts = torch.zeros(max_lora_size, dtype=torch.int32, device=device)
self.slot_counts_host = torch.zeros_like(self.slot_counts, device="cpu", pin_memory=True)
self.slot_offsets_full = torch.zeros(max_lora_size + 1, dtype=torch.int64, device=device)
self.slot_offsets = self.slot_offsets_full[:-1]
self.slot_offsets_full_host = torch.zeros_like(
self.slot_offsets_full, device="cpu", pin_memory=True
)
self.slot_ranks = torch.zeros(max_lora_size, dtype=torch.int32, device=device)
self.slot_ranks_host = torch.zeros_like(self.slot_ranks, device="cpu", pin_memory=True)
for key, info in self.layer_info.items():
assert (
info.module_num > 0
and info.output_sizes is not None
and len(info.output_sizes) == info.module_num
)
# Allocate layer parameters
self.layer_params[key] = self._allocate_layer_params(
key, info.module_num, info.output_sizes
)
def _calculate_layer_module2key(self) -> Dict[Tuple[int, int], LoraLayerKey]:
layer_module2key = dict()
for key in self.layer_info.keys():
layer_id = key.layer_idx
module_ids = key.module_ids
for module_id in module_ids:
layer_module2key[(layer_id, module_id)] = key
return layer_module2key
def _allocate_layer_params(
self, key: LoraLayerKey, layer_module_num: int, module_output_sizes: torch.Tensor
) -> LoraLayerParams:
"""
Create LoraLayerParams for a specific layer.
Args:
key: Key of the layer
layer_module_num: Number of modules in this layer
module_output_sizes: Output sizes for each module in this layer
Returns:
LoraLayerParams for the specified layer
"""
# GEMM parameter tensors only need max_lora_size (no dummy slot for base model)
# Base model requests are handled separately and don't participate in GEMM operations
shape_2d = (layer_module_num, self.max_lora_size)
output_hidden_sizes = torch.tensor(module_output_sizes, dtype=self.SIZES_DTYPE)
output_hidden_sizes_device = output_hidden_sizes.to(device="cuda")
output_sizes_offset = self.get_offset_from_counts(output_hidden_sizes).to(
dtype=self.PTR_DTYPE
) # [num_layer_modules]
output_sizes_offset_device = output_sizes_offset.to(device="cuda")
return LoraLayerParams(
# Weight pointers - managed by PEFT cache manager
d_b_ptrs=torch.zeros(shape_2d, dtype=torch.int64, device=self.device),
d_b_prime_ptrs=torch.zeros(shape_2d, dtype=torch.int64, device=self.device),
h_b_ptrs=torch.zeros(shape_2d, dtype=torch.int64, pin_memory=True),
h_b_prime_ptrs=torch.zeros(shape_2d, dtype=torch.int64, pin_memory=True),
d_output_sizes=output_hidden_sizes_device,
d_output_sizes_offset=output_sizes_offset_device,
h_output_sizes=output_hidden_sizes,
h_output_sizes_offset=output_sizes_offset,
)
@staticmethod
def get_sorted_indices(slot_ids: List[int]) -> torch.Tensor:
"""
Get sorted indices for the given slot IDs.
"""
slot_ids = torch.tensor(slot_ids, dtype=torch.int64)
# Compute sorted indices for gather/scatter operations
sorted_slot_ids, sorted_indices = torch.sort(slot_ids, stable=True)
return sorted_indices
def update_sorted_indices(self, slot_ids: List[int]):
"""
Update slot IDs for the current batch and compute sorted indices.
Args:
slot_ids: List of slot IDs, one per request in the batch
"""
actual_batch_size = len(slot_ids)
assert actual_batch_size <= self.max_batch_size, (
f"Actual batch size {actual_batch_size} exceeds max {self.max_batch_size}"
)
sorted_indices = self.get_sorted_indices(slot_ids)
# Update sorted_ids tensor with the computed indices
if actual_batch_size <= self.max_batch_size:
# if can fit in persistent, use it
self.sorted_ids = self.persistent_sorted_ids
sorted_ids_host = self.sorted_ids_host[:actual_batch_size]
sorted_ids_host.copy_(sorted_indices)
self.sorted_ids[:actual_batch_size].copy_(sorted_ids_host, non_blocking=True)
else:
# otherwise not a gen-only batch, use a newly allocated sorted_ids
self.sorted_ids = sorted_indices.to(device=self.device)
def update_weight_pointers(
self, peft_table: Dict[int, List], slot_to_task_mapping: tuple[Optional[int], ...]
):
"""
Update weight pointers from PEFT cache manager.
Args:
peft_table: PEFT table from cache manager containing weight pointers, map task id to list of layer
module configs
slot_to_task_mapping: Mapping from slot_id to task_id, tuple of None for empty slots
"""
# Get slot ranks. Assumptions: the rank is the same for a given slot,
# input_hidden_size is the same within a layer, and
# the output sizes are the same for all slots of a given module.
def zero_out_weight_pointers(slot_id: int):
"""
Zero out all weight pointers for a given slot_id for all layers
"""
for layer_param in self.layer_params.values():
layer_param.h_b_ptrs[:, slot_id] = 0
layer_param.h_b_prime_ptrs[:, slot_id] = 0
for slot_id in range(self.max_lora_size):
task_id = slot_to_task_mapping[slot_id]
if task_id is None: # empty slot
self.slot_ranks_host[slot_id] = 0
zero_out_weight_pointers(slot_id)
elif (
task_id not in peft_table
): # task has not changed in the slot, retain old rank / weight pointers
continue
else: # task might have changed in the slot, update its rank
task_configs = peft_table[task_id]
config = task_configs[0] # assume all layerModuleConfigs have the same rank
self.slot_ranks_host[slot_id] = config.adapter_size
zero_out_weight_pointers(
slot_id
) # in case the new task in this slot does not have a LoRA adapter for some module in some layer
for config in task_configs:
layer_id = config.layer_id
module_id = config.module_id
key = self.layer_module2key[(layer_id, module_id)]
assert key in self.layer_params, (
f"Layer {layer_id} not found in layer_params; the assumption that all LoRAs have their "
"adapters on the same layers is broken"
)
layer_param = self.layer_params[key]
local_module_id = key.module_ids.index(module_id)
# Validate LoRA rank
rank = config.adapter_size
assert rank <= self.max_rank, (
f"LoRA rank {rank} in layer {layer_id} exceeds configured max_rank {self.max_rank}. "
)
layer_param.h_b_ptrs[local_module_id, slot_id] = config.weights_in_pointer
layer_param.h_b_prime_ptrs[local_module_id, slot_id] = (
config.weights_out_pointer
)
self.slot_ranks.copy_(self.slot_ranks_host, non_blocking=True)
for layer_param in self.layer_params.values():
layer_param.d_b_ptrs.copy_(layer_param.h_b_ptrs, non_blocking=True)
layer_param.d_b_prime_ptrs.copy_(layer_param.h_b_prime_ptrs, non_blocking=True)
@staticmethod
def get_offset_from_counts(
counts: torch.Tensor, full: bool = False, out: Optional[torch.Tensor] = None
) -> torch.Tensor:
if out is None:
if full:
offset = torch.empty(counts.shape[0] + 1, dtype=torch.int64, device=counts.device)
else:
offset = torch.empty(counts.shape[0], dtype=torch.int64, device=counts.device)
else:
assert (full and out.shape[0] == counts.shape[0] + 1) or (
(not full) and out.shape[0] == counts.shape[0]
)
offset = out
offset[0] = 0
if full:
offset[1:] = counts
else:
offset[1:] = counts[:-1]
offset[1:].cumsum_(dim=0)
return offset
@staticmethod
def get_slot_counts(batch_slot_ids: List[int], max_lora_size: int) -> torch.Tensor:
"""
Get the number of tokens for each slot_id in the batch.
"""
slot_counts = torch.bincount(
torch.tensor(batch_slot_ids, dtype=torch.int32), minlength=max_lora_size
)
assert slot_counts.size(0) <= max_lora_size + 1
slot_counts = slot_counts[:max_lora_size]
return slot_counts
def update_slots_params(self, batch_slot_ids: List[int]):
"""
Update GEMM sizes and buffer offsets based on current batch composition.
Args:
batch_slot_ids: Slot IDs for each token in the batch
"""
slot_counts = self.get_slot_counts(batch_slot_ids, self.max_lora_size)
self.slot_counts_host.copy_(slot_counts)
self.get_offset_from_counts(slot_counts, full=True, out=self.slot_offsets_full_host)
self.slot_counts.copy_(self.slot_counts_host, non_blocking=True)
self.slot_offsets_full.copy_(self.slot_offsets_full_host, non_blocking=True)
def get_problem_count(self, layer_key: LoraLayerKey) -> int:
"""
Get the number of GEMM problems for a layer.
Args:
layer_key: Key of the layer
Returns:
Number of GEMM problems (layer_module_num * max_lora_size)
Returns 0 if layer has no LoRA modules
Note: Only actual LoRA slots are counted, not the dummy base model slot
"""
if layer_key not in self.layer_params:
return 0 # Layer has no LoRA modules
return self.layer_info[layer_key].module_num * self.max_lora_size
def get_layer_params(self, layer_key: LoraLayerKey) -> Optional[LoraLayerParams]:
"""
Get LoRA parameters for a specific layer.
Args:
layer_key: Key of the layer
Returns:
LoraLayerParams for the specified layer, or None if layer has no LoRA modules
"""
return self.layer_params.get(layer_key)
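A small worked example of the batching math above (bincount, exclusive prefix sum, stable sort), assuming tensorrt_llm with this new module is importable; the slot ids are made up and the expected values in the comments follow directly from the static methods defined above.

from tensorrt_llm._torch.peft.lora.cuda_graph_lora_params import CudaGraphLoraParams

batch_slot_ids = [1, 0, 1, 2, 0, 1]   # one slot id per generation request
max_lora_size = 4

counts = CudaGraphLoraParams.get_slot_counts(batch_slot_ids, max_lora_size)
# counts -> tensor([2, 3, 1, 0]): rows per slot

offsets = CudaGraphLoraParams.get_offset_from_counts(counts, full=True)
# offsets -> tensor([0, 2, 5, 6, 6]): row where each slot starts, plus the total

sorted_ids = CudaGraphLoraParams.get_sorted_indices(batch_slot_ids)
# sorted_ids -> tensor([1, 4, 0, 2, 5, 3]): gather order that groups rows by slot id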

View File

@ -1,8 +1,48 @@
from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, List, Optional
import torch
from .cuda_graph_lora_params import CudaGraphLoraParams
@dataclass
class GroupedGemmParamsOutput:
in_sizes: Optional[torch.Tensor] = None
out_sizes: Optional[torch.Tensor] = None
a_offset: Optional[torch.Tensor] = None
d_offset: Optional[torch.Tensor] = None
d_prime_offset: Optional[torch.Tensor] = None
lda: Optional[torch.Tensor] = None
ldb: Optional[torch.Tensor] = None
ldd: Optional[torch.Tensor] = None
ldb_prime: Optional[torch.Tensor] = None
ldd_prime: Optional[torch.Tensor] = None
splitk_offsets: Optional[torch.Tensor] = None
reordered_input: Optional[torch.Tensor] = None
@dataclass
class GroupedGemmParamsInput:
x: torch.Tensor
output_buffer: torch.Tensor
intermediate_buffer: torch.Tensor
max_lora_size: int
max_rank: int
slot_counts: torch.Tensor
slot_ranks: torch.Tensor
slot_offsets_full: torch.Tensor
b_ptrs: torch.Tensor
b_prime_ptrs: torch.Tensor
sorted_ids: torch.Tensor
output_hidden_sizes: torch.Tensor
output_sizes_offset: torch.Tensor
@property
def slot_offsets(self):
return self.slot_offsets_full[:-1]
class LoraModuleType(IntEnum):
"""Enum class representing different types of modules that can have LoRA adapters.
@ -100,57 +140,385 @@ class LoraLayer(torch.nn.Module):
) -> Optional[torch.Tensor]:
if bool(lora_params):
lora_ranks = []
lora_weight_pointers = []
active_lora_module_ids = []
for module_idx in self.lora_module_types:
module_idx = int(module_idx)
if module_idx in lora_params[layer_idx]:
active_lora_module_ids.append(module_idx)
lora_ranks.append(
lora_params[layer_idx][module_idx]['adapter_size'])
lora_weight_pointers.append(
lora_params[layer_idx][module_idx]['weight_pointers'])
# Check if we're using CUDA Graph mode
use_cuda_graph_mode = lora_params.get('use_cuda_graph_mode', False)
num_seqs = lora_params['num_seqs']
if len(active_lora_module_ids) == 0:
return None
if use_cuda_graph_mode:
return self._forward_cuda_graph_mode(x, lora_params, layer_idx)
else:
lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
x,
lora_params['host_request_types'][:num_seqs],
lora_ranks,
lora_weight_pointers,
lora_params['prompt_lens_cpu'][:num_seqs],
self.output_hidden_sizes,
False, # transA
True, # transB
max([r.max() for r in lora_ranks]),
0,
True, # TODO smor- should be lora_params["remove_input_padding"], support in loraOp as well
)
if isinstance(lora_outputs, torch.Tensor):
return lora_outputs
else:
# For multiple LoRA modules, some might not be executed in grouped gemm.
# For those modules not executed, we create zero tensors with matching dimensions.
# Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order.
lora_output = []
for module_idx in self.lora_module_types:
if int(module_idx) in active_lora_module_ids:
lora_output.append(lora_outputs.pop(0))
else:
lora_output.append(
torch.zeros(list(x.shape[:-1]) + [
self.output_hidden_sizes[
self.lora_module_types.index(
module_idx)]
],
dtype=x.dtype,
device=x.device))
lora_output = torch.cat(lora_output, dim=-1)
return lora_output
return self._forward_eager_mode(x, lora_params, layer_idx)
else:
return None
def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
device = input.x.device
bs, input_hidden_size = input.x.shape
shape_2d = (len(self.lora_module_types), input.max_lora_size
) # [num_layer_modules, max_lora_size]
shape_3d = shape_2d + (3, )
sum_out_sizes = sum(self.output_hidden_sizes)
input.output_buffer.fill_(0)
input.intermediate_buffer.fill_(0)
# reorder input
reordered_input = torch.index_select(input.x, 0, input.sorted_ids[:bs])
# a [bs, hidden]
lda = torch.full(shape_2d,
input_hidden_size,
dtype=CudaGraphLoraParams.LD_DTYPE,
device=device)
# b [input_hidden_size, lora_rank]
ldb = lda
# a_prime / d [num_layer_modules, bs, max_rank]
ldd = torch.full(shape_2d,
input.max_rank,
dtype=CudaGraphLoraParams.LD_DTYPE,
device=device)
# b_prime [lora_rank, module_output_size]
ldb_prime = input.slot_ranks.unsqueeze(0).to(
dtype=CudaGraphLoraParams.LD_DTYPE).repeat(shape_2d[0], 1)
# d_prime [bs, sum_of_each_module_output_sizes]
ldd_prime = torch.full(shape_2d,
sum_out_sizes,
dtype=CudaGraphLoraParams.LD_DTYPE,
device=device)
# reordered a [bs, hidden], each module has the same offset
a_offset = input.slot_offsets * input_hidden_size
a_offset = a_offset.unsqueeze(0).repeat(shape_2d[0], 1)
# d [num_layer_modules, bs, max_rank]
d_offset = (input.slot_offsets.unsqueeze(0) + torch.arange(
shape_2d[0], device=device, dtype=CudaGraphLoraParams.PTR_DTYPE).
unsqueeze(1) * bs) * input.max_rank
# d' [bs, sum_of_each_module_output_sizes]
bs_offset = input.slot_offsets.unsqueeze(0) # [1, max_lora_size]
bs_offset = bs_offset * sum_out_sizes
out_offset = input.output_sizes_offset.unsqueeze(
1) # [num_layer_modules, 1]
d_prime_offset = bs_offset + out_offset
# sizes
in_sizes = torch.empty(shape_3d,
dtype=CudaGraphLoraParams.SIZES_DTYPE,
device=device)
out_sizes = torch.empty_like(in_sizes)
slot_counts = input.slot_counts.unsqueeze(0) # [1, max_lora_size]
ranks = input.slot_ranks.unsqueeze(0) # [1, max_lora_size]
output_hidden_sizes = input.output_hidden_sizes.unsqueeze(
1) # [num_layer_modules, 1]
in_sizes[:, :, 0] = slot_counts
in_sizes[:, :, 1] = ranks
in_sizes[:, :, 2] = input_hidden_size
out_sizes[:, :, 0] = slot_counts
out_sizes[:, :, 1] = output_hidden_sizes
out_sizes[:, :, 2] = ranks
# disable unused modules / lora with ptr being zeros
in_sizes *= (input.b_ptrs != 0).unsqueeze(-1)
out_sizes *= (input.b_prime_ptrs != 0).unsqueeze(-1)
# splitk_offsets: [num_layer_modules, max_lora_size]
# splitk offsets (m * n) for the first grouped GEMM with (m, n, k) = (slot_counts, slot_ranks, input_hidden_size)
splitk_offsets = torch.zeros(shape_2d,
dtype=CudaGraphLoraParams.LD_DTYPE,
device=device)
splitk_offsets.view(-1)[1:] = in_sizes.view(-1, 3)[:-1, 0] # = M
splitk_offsets.view(-1)[1:] *= in_sizes.view(-1, 3)[:-1, 1] # *= N
splitk_offsets.view(-1).cumsum_(dim=0)
# add base addresses to offset tensors on GPU
dtype_element_size = input.x.element_size()
a_offset *= dtype_element_size
a_offset += reordered_input.data_ptr()
d_offset *= dtype_element_size
d_offset += input.intermediate_buffer.data_ptr()
d_prime_offset *= dtype_element_size
d_prime_offset += input.output_buffer.data_ptr()
return GroupedGemmParamsOutput(in_sizes=in_sizes,
out_sizes=out_sizes,
a_offset=a_offset,
d_offset=d_offset,
d_prime_offset=d_prime_offset,
lda=lda,
ldb=ldb,
ldd=ldd,
ldb_prime=ldb_prime,
ldd_prime=ldd_prime,
splitk_offsets=splitk_offsets,
reordered_input=reordered_input)
def _prepare_grouped_gemm_buffers_fused(self,
input: GroupedGemmParamsInput):
device = input.x.device
bs, input_hidden_size = input.x.shape
shape_2d = (len(self.lora_module_types), input.max_lora_size
) # [num_layer_modules, max_lora_size]
shape_3d = shape_2d + (3, )
sum_out_sizes = sum(self.output_hidden_sizes)
in_sizes = torch.empty(shape_3d,
dtype=CudaGraphLoraParams.SIZES_DTYPE,
device=device)
out_sizes = torch.empty_like(in_sizes)
a_offset = torch.empty(shape_2d,
dtype=CudaGraphLoraParams.PTR_DTYPE,
device=device)
d_offset = torch.empty_like(a_offset)
d_prime_offset = torch.empty_like(a_offset)
lda = torch.empty(shape_2d,
dtype=CudaGraphLoraParams.LD_DTYPE,
device=device)
ldb = lda
ldd = torch.empty_like(lda)
ldb_prime = torch.empty_like(lda)
ldd_prime = torch.empty_like(lda)
splitk_offsets = torch.empty(shape_2d,
dtype=CudaGraphLoraParams.LD_DTYPE,
device=device)
reordered_input = torch.empty_like(input.x)
torch.ops.trtllm.lora_group_gemm_param_fill_row_reorder_fusion(
# output parameters
in_sizes,
out_sizes,
a_offset,
d_offset,
d_prime_offset,
lda,
ldd,
ldb_prime,
ldd_prime,
splitk_offsets,
reordered_input,
# input parameters
input.max_lora_size,
input.max_rank,
sum_out_sizes,
input_hidden_size,
bs, # batch_size
input.slot_counts,
input.slot_ranks,
input.slot_offsets,
input.output_hidden_sizes,
input.output_sizes_offset,
input.b_ptrs,
input.b_prime_ptrs,
input.x,
input.sorted_ids[:bs],
input.intermediate_buffer,
input.output_buffer,
input.x.dtype)
return GroupedGemmParamsOutput(in_sizes=in_sizes,
out_sizes=out_sizes,
a_offset=a_offset,
d_offset=d_offset,
d_prime_offset=d_prime_offset,
lda=lda,
ldb=ldb,
ldd=ldd,
ldb_prime=ldb_prime,
ldd_prime=ldd_prime,
splitk_offsets=splitk_offsets,
reordered_input=reordered_input)
def _prepare_max_sizes_cpu(self,
cuda_graph_lora_params: CudaGraphLoraParams,
layer_key: CudaGraphLoraParams.LoraLayerKey,
bs: int, input_hidden_size: int):
layer_params = cuda_graph_lora_params.get_layer_params(layer_key)
shape_2d = (len(self.lora_module_types),
cuda_graph_lora_params.max_lora_size
) # [num_layer_modules, max_lora_size]
shape_3d = shape_2d + (3, )
# dummy max sizes, on CPU
host_max_in_sizes = torch.empty(
shape_3d, dtype=CudaGraphLoraParams.SIZES_DTYPE
) # m: batch_size, n: max_lora_rank, k: input_hidden_size
host_max_out_sizes = torch.empty_like(
host_max_in_sizes
) # m: batch_size, n: max_output_hidden_size, k: max_lora_rank
host_max_in_sizes[:, :, 0] = bs
host_max_in_sizes[:, :, 1] = cuda_graph_lora_params.max_rank
host_max_in_sizes[:, :, 2] = input_hidden_size
host_max_out_sizes[:, :, 0] = bs
host_max_out_sizes[:, :, 1] = layer_params.h_output_sizes.unsqueeze(1)
host_max_out_sizes[:, :, 2] = cuda_graph_lora_params.max_rank
return host_max_in_sizes, host_max_out_sizes
def _forward_cuda_graph_mode(
self,
x: torch.Tensor,
lora_params: Dict,
layer_idx: int,
) -> Optional[torch.Tensor]:
"""
Forward pass using CUDA Graph compatible LoRA parameters.
Args:
x: Input tensor
lora_params: CUDA Graph compatible LoRA parameters
layer_idx: Current layer index
Returns:
LoRA output tensor or None
"""
cuda_graph_params: CudaGraphLoraParams = lora_params.get(
'cuda_graph_params')
# Get layer-specific parameters
layer_key = CudaGraphLoraParams.LoraLayerKey(
layer_idx=layer_idx, module_ids=tuple(self.lora_module_types))
if not cuda_graph_params or not cuda_graph_params.layer_info or layer_key not in cuda_graph_params.layer_info:
return None
layer_params = cuda_graph_params.get_layer_params(layer_key)
# Skip layers that don't have LoRA modules
if layer_params is None:
return 0 # Pass-through for layers without LoRA modules
batch_size, hidden_size = x.shape[0], x.shape[-1]
num_layer_modules = len(self.lora_module_types)
max_rank = cuda_graph_params.max_rank
total_output_size = sum(self.output_hidden_sizes)
min_kn = min(
hidden_size, 8, max_rank
) # TODO: hardcoded to 8 for now for kernel alignment; may hit an alignment error if the rank is less than 8!
output_buffer = torch.empty(batch_size,
total_output_size,
dtype=x.dtype,
device=x.device)
host_max_in_sizes, host_max_out_sizes = self._prepare_max_sizes_cpu(
cuda_graph_params, layer_key, batch_size, hidden_size)
# Intermediate buffer: [num_layer_modules, batch_size, max_rank]
intermediate_buffer = torch.empty(
[num_layer_modules, batch_size, max_rank],
dtype=x.dtype,
device=x.device)
params_fill_input = GroupedGemmParamsInput(
x=x,
output_buffer=output_buffer,
intermediate_buffer=intermediate_buffer,
max_lora_size=cuda_graph_params.max_lora_size,
max_rank=cuda_graph_params.max_rank,
slot_counts=cuda_graph_params.slot_counts,
slot_ranks=cuda_graph_params.slot_ranks,
slot_offsets_full=cuda_graph_params.slot_offsets_full,
b_ptrs=layer_params.d_b_ptrs,
b_prime_ptrs=layer_params.d_b_prime_ptrs,
sorted_ids=cuda_graph_params.sorted_ids,
output_hidden_sizes=layer_params.d_output_sizes,
output_sizes_offset=layer_params.d_output_sizes_offset)
grouped_gemm_params = self._prepare_grouped_gemm_buffers_fused(
params_fill_input)
torch.ops.trtllm.lora_grouped_gemm_cuda_graph(
grouped_gemm_params.in_sizes, grouped_gemm_params.out_sizes,
grouped_gemm_params.a_offset, layer_params.d_b_ptrs,
grouped_gemm_params.d_offset, layer_params.d_b_prime_ptrs,
grouped_gemm_params.d_prime_offset,
cuda_graph_params.get_problem_count(layer_key),
grouped_gemm_params.lda, grouped_gemm_params.ldb,
grouped_gemm_params.ldd, grouped_gemm_params.ldb_prime,
grouped_gemm_params.ldd_prime, host_max_in_sizes,
host_max_out_sizes, grouped_gemm_params.splitk_offsets,
grouped_gemm_params.reordered_input.dtype, min_kn)
# TODO: move to kernel
restored_output = torch.zeros_like(output_buffer)
restored_output.index_copy_(0,
cuda_graph_params.sorted_ids[:batch_size],
output_buffer)
return restored_output
def _forward_eager_mode(
self,
x: torch.Tensor,
lora_params: Dict,
layer_idx: int,
) -> Optional[torch.Tensor]:
"""
Eager-mode forward pass using the original LoRA implementation.
Args:
x: Input tensor
lora_params: LoRA parameters for eager mode
layer_idx: Current layer index
Returns:
LoRA output tensor or None
"""
lora_ranks = []
lora_weight_pointers = []
active_lora_module_ids = []
for module_idx in self.lora_module_types:
module_idx = int(module_idx)
if module_idx in lora_params[layer_idx]:
active_lora_module_ids.append(module_idx)
lora_ranks.append(
lora_params[layer_idx][module_idx]['adapter_size'])
lora_weight_pointers.append(
lora_params[layer_idx][module_idx]['weight_pointers'])
num_seqs = lora_params['num_seqs']
if len(active_lora_module_ids) == 0:
return None
else:
lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
x,
lora_params['host_request_types'][:num_seqs],
lora_ranks,
lora_weight_pointers,
lora_params['prompt_lens_cpu'][:num_seqs],
self.output_hidden_sizes,
False, # transA
True, # transB
max([r.max() for r in lora_ranks]),
0,
True, # TODO smor- should be lora_params["remove_input_padding"], support in loraOp as well
)
if isinstance(lora_outputs, torch.Tensor):
return lora_outputs
else:
# For multiple LoRA modules, some might not be executed in grouped gemm.
# For those modules not executed, we create zero tensors with matching dimensions.
# Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order.
lora_output = []
for module_idx in self.lora_module_types:
if int(module_idx) in active_lora_module_ids:
lora_output.append(lora_outputs.pop(0))
else:
lora_output.append(
torch.zeros(list(x.shape[:-1]) + [
self.output_hidden_sizes[
self.lora_module_types.index(module_idx)]
],
dtype=x.dtype,
device=x.device))
lora_output = torch.cat(lora_output, dim=-1)
return lora_output
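The element offsets built in prepare_grouped_gemm_buffers are easier to follow with small numbers. This standalone sketch (plain PyTorch; all sizes are invented for illustration) reproduces the a / d / d' offset formulas in units of elements; the actual code then multiplies by the element size and adds the corresponding base pointers.

import torch

num_modules, max_lora_size, bs = 2, 3, 6
input_hidden_size, max_rank = 8, 4
output_hidden_sizes = torch.tensor([8, 16], dtype=torch.int64)
sum_out_sizes = int(output_hidden_sizes.sum())              # 24
slot_offsets = torch.tensor([0, 2, 5], dtype=torch.int64)   # row where each slot starts

# a: reordered input [bs, hidden]; every module reads the same rows.
a_offset = (slot_offsets * input_hidden_size).unsqueeze(0).repeat(num_modules, 1)
# -> [[0, 16, 40], [0, 16, 40]]

# d: intermediate buffer [num_modules, bs, max_rank]; modules are bs * max_rank apart.
d_offset = (slot_offsets.unsqueeze(0)
            + torch.arange(num_modules, dtype=torch.int64).unsqueeze(1) * bs) * max_rank
# -> [[0, 8, 20], [24, 32, 44]]

# d': output buffer [bs, sum_out_sizes]; module m starts at its output prefix sum.
output_sizes_offset = torch.tensor([0, 8], dtype=torch.int64)  # exclusive prefix of output sizes
d_prime_offset = slot_offsets.unsqueeze(0) * sum_out_sizes + output_sizes_offset.unsqueeze(1)
# -> [[0, 48, 120], [8, 56, 128]]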

View File

@ -834,6 +834,8 @@ def create_py_executor_instance(
lora_config.lora_target_modules,
lora_config.trtllm_modules_to_hf_modules,
lora_config.swap_gate_up_proj_lora_b_weight)
if isinstance(model_engine, PyTorchModelEngine):
model_engine._init_cuda_graph_lora_manager(lora_config)
resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
max_num_sequences)

View File

@ -533,9 +533,6 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.py_decoding_iter = 0
self.is_attention_dp_dummy = False
self.is_cuda_graph_dummy = False
self.py_lora_task_layer_module_configs: list[
tensorrt_llm.bindings.internal.runtime.
TaskLayerModuleConfig] | None = None
self.py_kv_transfer_start_time = None
self.py_kv_transfer_timed_out = False

View File

@ -16,6 +16,7 @@ import torch._dynamo.config
import tensorrt_llm.bindings.internal.userbuffers as ub
from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
torch_dtype_to_str, trace_func)
from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig
from tensorrt_llm.inputs.multimodal import (MultimodalParams,
MultimodalRuntimeData)
from tensorrt_llm.inputs.registry import (create_input_processor,
@ -44,6 +45,7 @@ from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer,
MoeLoadBalancerIterContext)
from ..peft.lora.cuda_graph_lora_manager import CudaGraphLoraManager
from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
get_spec_metadata,
update_spec_config_from_model_config)
@ -62,7 +64,8 @@ from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .llm_request import LlmRequest, get_draft_token_length
from .model_loader import ModelLoader, _construct_checkpoint_loader
from .resource_manager import (BaseResourceManager, KVCacheManager,
ResourceManager, ResourceManagerType)
PeftCacheManager, ResourceManager,
ResourceManagerType)
from .sampler import SampleStateTensors
from .scheduler import ScheduledRequests
@ -449,6 +452,9 @@ class PyTorchModelEngine(ModelEngine):
)
self.cuda_graph_runner = CUDAGraphRunner(cuda_graph_runner_config)
# Initialize CUDA Graph LoRA manager if LoRA is enabled
self.cuda_graph_lora_manager: Optional[CudaGraphLoraManager] = None
# Setup the local cache indirection buffer only once and reuse it.
# This way it can also be used for CUDA graphs.
if self.use_beam_search:
@ -493,6 +499,26 @@ class PyTorchModelEngine(ModelEngine):
dtype=torch_dtype_to_str(self.model.config.torch_dtype),
swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight)
def _init_cuda_graph_lora_manager(self, lora_config: LoraConfig):
"""Initialize CUDA Graph LoRA manager with model configuration."""
# Get model configuration
if self.cuda_graph_runner.enabled:
max_lora_size = lora_config.max_loras or 8 # Default fallback
max_batch_size = self.batch_size # Use engine's max batch size
self.cuda_graph_lora_manager = CudaGraphLoraManager(
max_lora_size=max_lora_size,
max_batch_size=max_batch_size,
max_lora_rank=lora_config.max_lora_rank,
model=self.model,
lora_model_config=self.lora_model_config,
device='cuda')
logger.info(
f"Initialized CUDA Graph LoRA manager, "
f"max {max_lora_size} adapters, max rank {lora_config.max_lora_rank}"
)
def set_guided_decoder(self,
guided_decoder: CapturableGuidedDecoder) -> bool:
if hasattr(self.model, "set_guided_decoder"):
@ -1936,7 +1962,8 @@ class PyTorchModelEngine(ModelEngine):
cache_indirection_buffer: Optional[torch.Tensor] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None,
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None,
resource_manager: Optional[ResourceManager] = None):
resource_manager: Optional[ResourceManager] = None,
maybe_graph: bool = False):
"""
Prepare inputs for Pytorch Model.
"""
@ -2673,8 +2700,10 @@ class PyTorchModelEngine(ModelEngine):
attn_metadata.prepare()
peft_cache_manager = resource_manager and resource_manager.get_resource_manager(
ResourceManagerType.PEFT_CACHE_MANAGER)
lora_params = self._get_lora_params_from_requests(
scheduled_requests, attn_metadata)
scheduled_requests, attn_metadata, peft_cache_manager, maybe_graph)
attn_all_rank_num_tokens = self._get_all_rank_num_tokens(attn_metadata)
padded_num_tokens, can_run_piecewise_cuda_graph, attn_all_rank_num_tokens = self._get_padding_params(
@ -3142,10 +3171,41 @@ class PyTorchModelEngine(ModelEngine):
'inputs_embeds': None
}, gather_ids if is_spec_decode else None
def _get_lora_params_from_requests(self,
scheduled_requests: ScheduledRequests,
attn_metadata: AttentionMetadata):
def _get_lora_params_from_requests(
self,
scheduled_requests: ScheduledRequests,
attn_metadata: AttentionMetadata,
peft_cache_manager: Optional[PeftCacheManager] = None,
maybe_graph: bool = False):
'''
Get LoRA parameters from scheduled requests.
Uses the CUDA Graph compatible mode for decode-only batches, otherwise falls back to eager mode.
Returns:
Dictionary containing LoRA parameters, or None if no LoRA requests
'''
use_cuda_graph_mode = self.cuda_graph_lora_manager is not None and maybe_graph
if use_cuda_graph_mode:
return self.cuda_graph_lora_manager.prepare_cuda_graph_lora_params(
scheduled_requests, attn_metadata, peft_cache_manager)
else:
if self.cuda_graph_lora_manager is not None:
self.cuda_graph_lora_manager.adapter_slot_manager.remove_evicted_slots_in_cpp(
peft_cache_manager)
peft_table = peft_cache_manager.get_and_reset_batch_peft_table(
) if peft_cache_manager is not None else None
return peft_table and self._get_eager_lora_params_from_requests(
scheduled_requests, attn_metadata, peft_table)
def _get_eager_lora_params_from_requests(
self, scheduled_requests: ScheduledRequests,
attn_metadata: AttentionMetadata,
peft_table: Dict[int, list[TaskLayerModuleConfig]]):
'''
Eager mode LoRA parameter preparation logic.
lora_params: dict
{
layer_id: dict
@ -3165,10 +3225,12 @@ class PyTorchModelEngine(ModelEngine):
# trace all requests to get the union set of the lora params
for request in request_list:
if request.py_lora_task_layer_module_configs is None:
if request.lora_task_id is None:
continue
for module in request.py_lora_task_layer_module_configs:
layer_module_configs = peft_table[request.lora_task_id]
for module in layer_module_configs:
module_id = module.module_id
layer_id = module.layer_id
@ -3195,7 +3257,7 @@ class PyTorchModelEngine(ModelEngine):
for request in request_list:
# Need to set default values for this case
if request.py_lora_task_layer_module_configs is None:
if request.lora_task_id is None:
for layer_id in lora_params:
for module_id in lora_params[layer_id]:
current_lora_params = lora_params[layer_id][module_id]
@ -3245,7 +3307,8 @@ class PyTorchModelEngine(ModelEngine):
cache_indirection_buffer: Optional[torch.Tensor] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None,
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None,
resource_manager: Optional[ResourceManager] = None):
resource_manager: Optional[ResourceManager] = None,
maybe_graph: bool = False):
if self.mapping is not None and 'cp_type' in self.mapping.cp_config:
cp_type = self.mapping.cp_config['cp_type']
if CpType.STAR == cp_type:
@ -3258,12 +3321,11 @@ class PyTorchModelEngine(ModelEngine):
raise NotImplementedError(
f"Unsupported cp_type {getattr(cp_type, 'name', cp_type)}.")
return self._prepare_tp_inputs(scheduled_requests, kv_cache_manager,
attn_metadata, spec_metadata,
new_tensors_device,
cache_indirection_buffer,
num_accepted_tokens_device,
req_id_to_old_request, resource_manager)
return self._prepare_tp_inputs(
scheduled_requests, kv_cache_manager, attn_metadata, spec_metadata,
new_tensors_device, cache_indirection_buffer,
num_accepted_tokens_device, req_id_to_old_request, resource_manager,
maybe_graph)
@torch.inference_mode()
@with_model_extra_attrs(lambda self: self.model.extra_attrs)
@ -3347,12 +3409,11 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata = self.spec_metadata
else:
spec_metadata = None
inputs, gather_ids = self._prepare_inputs(
padded_requests, kv_cache_manager, attn_metadata, spec_metadata,
new_tensors_device, cache_indirection_buffer,
num_accepted_tokens_device, req_id_to_old_request,
resource_manager)
resource_manager, can_run_graph)
with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
if not can_run_graph:

View File

@ -11,6 +11,7 @@ import torch
import tensorrt_llm
import tensorrt_llm.bindings
from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig
from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
PybindMirror)
from tensorrt_llm.lora_helper import LoraConfig
@ -1554,6 +1555,9 @@ class PeftCacheManager(BaseResourceManager):
model_config=ModelConfigPython.from_model_config_cpp(model_config),
cpp_peft_cache_manager=self.impl)
self._batch_peft_table: Optional[Dict[int, list[
TaskLayerModuleConfig]]] = None # task_id -> layer-module-configs mapping for the current batch
def get_lora_manager(self) -> LoraManager:
return self._lora_manager
@ -1600,18 +1604,9 @@ class PeftCacheManager(BaseResourceManager):
for req in context_batch:
self.add_request_peft(req)
py_lora_task_layer_module_configs = self.impl.ensure_batch(
self._batch_peft_table, _ = self.impl.ensure_batch_map_task_id(
context_batch, generation_batch, False)
for req in context_batch:
req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[
req.
py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None
for req in generation_batch:
req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[
req.
py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None
def update_resources(self, scheduled_batch: ScheduledRequests):
pass
@ -1620,3 +1615,12 @@ class PeftCacheManager(BaseResourceManager):
def shutdown(self):
pass
def get_and_reset_batch_peft_table(
self) -> Dict[int, list[TaskLayerModuleConfig]]:
batch_peft_table = self._batch_peft_table
self._batch_peft_table = None
return batch_peft_table
def is_task_cached_device(self, task_id: int) -> bool:
return self.impl.is_task_cached_device(task_id)
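The get_and_reset pair above is a consume-once handoff: prepare_resources stores the per-batch table and the first reader takes it, leaving None so a later batch cannot pick up stale entries. A toy stand-in showing the pattern; the _BatchTableHolder class is hypothetical and illustrative only.

from typing import Dict, List, Optional

class _BatchTableHolder:
    def __init__(self) -> None:
        self._batch_peft_table: Optional[Dict[int, List]] = None

    def prepare(self, table: Dict[int, List]) -> None:
        # populated once per batch, e.g. when resources are prepared
        self._batch_peft_table = table

    def get_and_reset(self) -> Optional[Dict[int, List]]:
        table, self._batch_peft_table = self._batch_peft_table, None
        return table

holder = _BatchTableHolder()
holder.prepare({7: ["layer-module-config"]})
assert holder.get_and_reset() == {7: ["layer-module-config"]}
assert holder.get_and_reset() is None  # a second read never sees a stale table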

View File

@ -59,6 +59,8 @@ def test_register_fake(custom_ops):
# TODO: add fake impl for these ops in follow-up PRs.
to_fix = {
"trtllm::lora_grouped_gemm",
"trtllm::lora_grouped_gemm_cuda_graph",
"trtllm::lora_group_gemm_param_fill_row_reorder_fusion",
"trtllm::mtp_relaxed_acceptance_op",
"trtllm::mtp_update_hidden_states_op",
"trtllm::mtp_prepare_drafter_inputs_op",

View File

@ -1,18 +1,28 @@
import json
import tarfile
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import List, OrderedDict, Type
from typing import List, Optional, OrderedDict, Tuple, Type, Union
import pytest
import torch
from utils.llm_data import llm_models_root
from utils.util import duplicate_list_to_length, flatten_list, similar
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.peft.lora.cuda_graph_lora_params import \
CudaGraphLoraParams
from tensorrt_llm._torch.peft.lora.layer import (GroupedGemmParamsInput,
GroupedGemmParamsOutput,
LoraLayer)
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi.llm import BaseLLM
from tensorrt_llm.llmapi.llm_args import CudaGraphConfig
from tensorrt_llm.lora_helper import LoraConfig
from .test_utils import DelayedAssert
_RU_LORA_ADAPTER_PROMPTS = [
"Назови главную площадь в центре Москвы.",
"Напиши полное предложение, описывающее, что в музее не хватает женских скульптур. Используй фразу \"не хватает\".",
@ -284,3 +294,233 @@ def create_mock_nemo_lora_checkpoint(
tar.add(config_path, arcname="model_config.yaml")
return nemo_path
@dataclass
class CUDAGraphLoRATestParams:
batch_slot_ids: List[int]
input_hidden_size: int
slot_ranks: List[int]
max_lora_rank: int
output_hidden_sizes: List[int]
layer_module_mask: Optional[Union[torch.Tensor, bool]]
dtype: torch.dtype
seed: int
def __post_init__(self):
assert self.layer_module_mask is None or isinstance(
self.layer_module_mask,
bool) or self.layer_module_mask.shape == (self.module_count,
self.slot_count)
assert all(0 <= idx <= self.slot_count for idx in self.batch_slot_ids)
assert all(0 <= rank <= self.max_lora_rank for rank in self.slot_ranks)
if isinstance(self.layer_module_mask, torch.Tensor):
self.layer_module_mask = self.layer_module_mask.to(dtype=torch.bool)
elif self.layer_module_mask is not None:
self.layer_module_mask = bool(self.layer_module_mask)
else:
self.layer_module_mask = True
@property
def module_count(self):
return len(self.output_hidden_sizes)
@property
def slot_count(self):
return len(self.slot_ranks)
@property
def batch_size(self):
return len(self.batch_slot_ids)
@property
def sum_output_hidden_size(self):
return sum(self.output_hidden_sizes)
def create_grouped_gemm_params_filler_input(
test_params: Optional[CUDAGraphLoRATestParams] = None
) -> Tuple[GroupedGemmParamsInput, LoraLayer]:
if test_params is None:
test_params = CUDAGraphLoRATestParams(
batch_slot_ids=[0, 3, 3, 4, 5, 8],
input_hidden_size=4096,
slot_ranks=[8, 12, 4, 3] * 2,
max_lora_rank=64,
output_hidden_sizes=[4096, 4096],
layer_module_mask=None,
dtype=torch.bfloat16,
seed=42,
)
with torch.random.fork_rng():
torch.manual_seed(test_params.seed)
shape_2d = (test_params.module_count, test_params.slot_count)
x = torch.randn(test_params.batch_size,
test_params.input_hidden_size,
dtype=test_params.dtype,
device="cuda")
output_buffer = torch.randn(test_params.batch_size,
test_params.sum_output_hidden_size,
dtype=test_params.dtype,
device="cuda")
b_ptrs = torch.randint(1,
1000000,
shape_2d,
dtype=CudaGraphLoraParams.PTR_DTYPE)
b_prime_ptrs = torch.randint(1,
1000000,
shape_2d,
dtype=CudaGraphLoraParams.PTR_DTYPE)
b_ptrs *= test_params.layer_module_mask
b_prime_ptrs *= test_params.layer_module_mask
b_ptrs = b_ptrs.to(device="cuda")
b_prime_ptrs = b_prime_ptrs.to(device="cuda")
slot_ranks = torch.tensor(test_params.slot_ranks,
dtype=CudaGraphLoraParams.SIZES_DTYPE,
device="cuda")
intermediate_buffer = torch.randn(test_params.module_count,
test_params.batch_size,
test_params.max_lora_rank,
dtype=test_params.dtype,
device="cuda")
slot_counts = CudaGraphLoraParams.get_slot_counts(
test_params.batch_slot_ids, test_params.slot_count)
slot_offsets_full = CudaGraphLoraParams.get_offset_from_counts(
slot_counts, full=True)
sorted_ids = CudaGraphLoraParams.get_sorted_indices(
test_params.batch_slot_ids)
slot_offsets_full = slot_offsets_full.to(device="cuda",
dtype=torch.int64)
slot_counts = slot_counts.to(device="cuda", dtype=torch.int32)
sorted_ids = sorted_ids.to(device="cuda", dtype=torch.int64)
output_hidden_sizes = torch.tensor(
test_params.output_hidden_sizes,
dtype=CudaGraphLoraParams.SIZES_DTYPE,
device="cuda")
output_sizes_offset = CudaGraphLoraParams.get_offset_from_counts(
output_hidden_sizes).to(dtype=CudaGraphLoraParams.PTR_DTYPE,
device="cuda")
layer = LoraLayer([0] * test_params.module_count,
test_params.output_hidden_sizes)
inputs = GroupedGemmParamsInput(
x=x,
output_buffer=output_buffer,
intermediate_buffer=intermediate_buffer,
max_lora_size=test_params.slot_count,
max_rank=test_params.max_lora_rank,
slot_counts=slot_counts,
slot_ranks=slot_ranks,
slot_offsets_full=slot_offsets_full,
sorted_ids=sorted_ids,
b_ptrs=b_ptrs,
b_prime_ptrs=b_prime_ptrs,
output_hidden_sizes=output_hidden_sizes,
output_sizes_offset=output_sizes_offset,
)
return inputs, layer
def compare_grouped_gemm_params(
params: GroupedGemmParamsOutput,
ref: GroupedGemmParamsOutput,
params_input: GroupedGemmParamsInput,
params_to_store_msg: List[str] | None = ['splitk_offsets'],
params_exclude_msg: List[str] | None = None,
):
assert not (params_to_store_msg and params_exclude_msg)
bs, input_hidden_size = params.reordered_input.shape
asserter = DelayedAssert()
params_dict = asdict(params)
ref_dict = asdict(ref)
if not params_to_store_msg:
params_to_store_msg = set(params_dict.keys())
if params_exclude_msg:
for name in params_exclude_msg:
params_to_store_msg.discard(name)
def get_msg(name: str, v: torch.Tensor, ref_v: torch.Tensor):
is_get_msg = any(p in name or name in p for p in params_to_store_msg)
header = f"\n\n{name=}\n"
return f"{header} {v=}\n {ref_v=}\n diff:\n{v - ref_v}" if is_get_msg else header
for name in params_dict.keys():
v = params_dict[name]
ref_v = ref_dict[name]
if name not in ("reordered_input", "a_offset"):
asserter.add(
v.allclose(ref_v),
get_msg(name, v, ref_v),
)
# Test a_offset separately
offset = params.a_offset - params.reordered_input.data_ptr()
ref_offset = ref.a_offset - ref.reordered_input.data_ptr()
asserter.add(
(offset == ref_offset).all(),
get_msg("a_offset", offset, ref_offset))
# Test reordered_input separately
valid_row = params_input.slot_offsets_full[-1].cpu().item()
valid_rows = params.reordered_input[:valid_row]
ref_valid_rows = ref.reordered_input[:valid_row]
asserter.add(
valid_rows.allclose(ref_valid_rows),
get_msg(f"valid part({valid_row=}, {bs=}) of reordered_input",
valid_rows, ref_valid_rows))
# check intermediate buffer and output buffer are all zeros
asserter.add(
torch.all(params_input.intermediate_buffer == 0),
get_msg("intermediate buffer", params_input.intermediate_buffer, 0))
asserter.add(torch.all(params_input.output_buffer == 0),
get_msg("output buffer", params_input.output_buffer, 0))
if valid_row < bs:
invalid_rows = params.reordered_input[valid_row:]
ref_invalid_rows = ref.reordered_input[valid_row:]
asserter.add(
torch.all(invalid_rows == 0),
get_msg("invalid part of reordered_input", invalid_rows,
ref_invalid_rows))
else:
asserter.add(
True,
f"valid_row is full {valid_row=} v. bs: {params_dict['reordered_input'].shape[0]=}"
)
asserter.assert_all()
def compare_cuda_graph_lora_params_filler(test_params: CUDAGraphLoRATestParams):
grouped_gemm_params_filler_input, layer = create_grouped_gemm_params_filler_input(
test_params)
output_fused = layer._prepare_grouped_gemm_buffers_fused(
grouped_gemm_params_filler_input)
assert torch.all(
grouped_gemm_params_filler_input.intermediate_buffer == 0
), f"intermediate_buffer is not all zeros: {grouped_gemm_params_filler_input.intermediate_buffer}; non zero / zeros: {(grouped_gemm_params_filler_input.intermediate_buffer != 0).sum()} / {grouped_gemm_params_filler_input.intermediate_buffer.numel()}"
assert torch.all(
grouped_gemm_params_filler_input.output_buffer == 0
), f"output_buffer is not all zeros: {grouped_gemm_params_filler_input.output_buffer}; non zero / zeros: {(grouped_gemm_params_filler_input.output_buffer != 0).sum()} / {grouped_gemm_params_filler_input.output_buffer.numel()}"
output_pytorch = layer.prepare_grouped_gemm_buffers(
grouped_gemm_params_filler_input)
compare_grouped_gemm_params(output_fused,
output_pytorch,
grouped_gemm_params_filler_input,
params_to_store_msg=None)
test_lora_with_and_without_cuda_graph = pytest.mark.parametrize(
"cuda_graph_config", [CudaGraphConfig(max_batch_size=10), None])

View File

@ -9,7 +9,8 @@ from tensorrt_llm.sampling_params import SamplingParams
from .lora_test_utils import (
check_llama_7b_multi_lora_from_request_test_harness,
check_phi3_lora_fused_modules_output_tp2_identical_to_tp1)
check_phi3_lora_fused_modules_output_tp2_identical_to_tp1,
test_lora_with_and_without_cuda_graph)
from .test_llm import (_test_llm_capture_request_error, llama_model_path,
llm_get_stats_test_harness,
llm_return_logprobs_test_harness,
@ -46,9 +47,10 @@ def test_llama_7b_lora_tp2():
kv_cache_config=global_kv_cache_config)
@pytest.mark.gpu2
@pytest.mark.skip(reason="https://nvbugs/5682551")
def test_llama_7b_multi_lora_tp2():
@pytest.mark.gpu4
@skip_ray # https://nvbugs/5682551
@test_lora_with_and_without_cuda_graph
def test_llama_7b_multi_lora_tp4(cuda_graph_config):
# For LoRA checkpoints without finetuned embedding and lm_head, we can either:
# (1) specify lora_target_modules, or
# (2) provide a lora_dir to infer the lora_target_modules.
@ -59,21 +61,18 @@ def test_llama_7b_multi_lora_tp2():
check_llama_7b_multi_lora_from_request_test_harness(
LLM,
lora_config=lora_config,
tensor_parallel_size=2,
tensor_parallel_size=4,
kv_cache_config=global_kv_cache_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
cuda_graph_config=cuda_graph_config)
@skip_ray # https://nvbugs/5727075
@pytest.mark.gpu2
def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None:
@test_lora_with_and_without_cuda_graph
def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1(
cuda_graph_config) -> None:
check_phi3_lora_fused_modules_output_tp2_identical_to_tp1(
LLM,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
LLM, cuda_graph_config=cuda_graph_config)
@skip_ray

View File

@ -20,7 +20,8 @@ from tensorrt_llm.sampling_params import SamplingParams
from .lora_test_utils import (
check_llama_7b_multi_lora_from_request_test_harness,
check_llama_7b_multi_unique_lora_adapters_from_request,
create_mock_nemo_lora_checkpoint)
create_mock_nemo_lora_checkpoint, compare_cuda_graph_lora_params_filler,
CUDAGraphLoRATestParams, test_lora_with_and_without_cuda_graph)
from .test_llm import (_test_llm_capture_request_error, get_model_path,
global_kvcache_config, llama_model_path,
llm_get_stats_async_test_harness,
@ -42,6 +43,7 @@ import torch
from peft import LoraConfig as PeftLoraConfig
from peft import get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import replace
# isort: on
@ -283,19 +285,76 @@ def test_embedding_bias_with_torch_sampler_strategies():
)
def test_lora_cuda_graph_params_filling_kernel_special_cases():
torch.cuda.set_device(0)
# test all requests have the same LoRA id case
test_params = CUDAGraphLoRATestParams(
batch_slot_ids=[0] * 10,
input_hidden_size=4096,
slot_ranks=[64] * 10,
max_lora_rank=64,
output_hidden_sizes=[123],
layer_module_mask=None,
dtype=torch.bfloat16,
seed=42,
)
compare_cuda_graph_lora_params_filler(test_params)
# test no LoRA in a batch case
test_params2 = replace(test_params,
batch_slot_ids=[len(test_params.slot_ranks)] * 10)
compare_cuda_graph_lora_params_filler(test_params2)
# test the case where every slot has three output modules
test_params3 = replace(test_params, output_hidden_sizes=[123, 456, 789])
compare_cuda_graph_lora_params_filler(test_params3)
# test the case where some layer modules have invalid weight pointers
mask = torch.full((test_params3.module_count, test_params3.slot_count),
True,
dtype=torch.bool)
mask[0, 0] = False
mask[1, 7] = False
mask[2, 3] = False
test_params4 = replace(test_params3, layer_module_mask=mask)
compare_cuda_graph_lora_params_filler(test_params4)
# test mixed slot ids case
test_params5 = CUDAGraphLoRATestParams(
batch_slot_ids=[6, 2, 0, 1, 1, 1, 5, 6],
input_hidden_size=512,
slot_ranks=[8, 12, 4] * 2,
max_lora_rank=15,
output_hidden_sizes=[123, 456, 789],
layer_module_mask=None,
dtype=torch.bfloat16,
seed=42,
)
compare_cuda_graph_lora_params_filler(test_params5)
# test mixed slot ids with invalid weight pointers
mask = torch.full((test_params5.module_count, test_params5.slot_count),
True,
dtype=torch.bool)
mask[1, 3] = False
mask[2, 5] = False
mask[-1, -4] = False
test_params6 = replace(test_params5, layer_module_mask=mask)
compare_cuda_graph_lora_params_filler(test_params6)
def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:
lora_config = LoraConfig(
lora_dir=[f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"],
max_lora_rank=8,
max_loras=2,
max_cpu_loras=2)
llm = LLM(
model=f"{llm_models_root()}/llama-models/llama-7b-hf",
lora_config=lora_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None,
**llm_kwargs)
if "cuda_graph_config" not in llm_kwargs:
llm_kwargs["cuda_graph_config"] = None
llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf",
lora_config=lora_config,
**llm_kwargs)
try:
prompts = [
"美国的首都在哪里? \n答案:",
@ -318,22 +377,21 @@ def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:
@skip_gpu_memory_less_than_40gb
@pytest.mark.part0
def test_llama_7b_lora():
llama_7b_lora_from_dir_test_harness()
@test_lora_with_and_without_cuda_graph
def test_llama_7b_lora(cuda_graph_config):
llama_7b_lora_from_dir_test_harness(cuda_graph_config=cuda_graph_config)
@skip_gpu_memory_less_than_40gb
def test_llama_7b_lora_default_modules() -> None:
@test_lora_with_and_without_cuda_graph
def test_llama_7b_lora_default_modules(cuda_graph_config) -> None:
lora_config = LoraConfig(max_lora_rank=64, max_loras=2, max_cpu_loras=2)
hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"
llm = LLM(
model=hf_model_dir,
lora_config=lora_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
llm = LLM(model=hf_model_dir,
lora_config=lora_config,
cuda_graph_config=cuda_graph_config)
hf_lora_dir = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"
try:
@ -359,7 +417,8 @@ def test_llama_7b_lora_default_modules() -> None:
def _check_llama_7b_multi_lora_evict_load_new_adapters(
lora_adapter_count_per_call: list[int], max_loras: int,
max_cpu_loras: int, repeat_calls: int, repeats_per_call: int):
max_cpu_loras: int, repeat_calls: int, repeats_per_call: int,
**llm_kwargs):
# For LoRA checkpoints without finetuned embedding and lm_head, we can either:
# (1) specify lora_target_modules, or
# (2) provide a lora_dir to infer the lora_target_modules.
@ -373,15 +432,14 @@ def _check_llama_7b_multi_lora_evict_load_new_adapters(
repeats_per_call,
LLM,
lora_config=lora_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
**llm_kwargs)
@skip_gpu_memory_less_than_40gb
@skip_ray # https://nvbugs/5682551
@pytest.mark.part3
def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache():
@test_lora_with_and_without_cuda_graph
def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache(cuda_graph_config):
"""Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single
llm.generate call, that's repeated twice.
""" # noqa: D205
@ -390,12 +448,15 @@ def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache():
max_loras=1,
max_cpu_loras=2,
repeat_calls=2,
repeats_per_call=3)
repeats_per_call=3,
cuda_graph_config=cuda_graph_config)
@skip_gpu_memory_less_than_40gb
@pytest.mark.part1
def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache():
@test_lora_with_and_without_cuda_graph
def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache(
cuda_graph_config):
"""Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU
cache size < LoRA CPU cache size.
""" # noqa: D205
@ -404,25 +465,29 @@ def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache():
max_loras=1,
max_cpu_loras=3,
repeat_calls=1,
repeats_per_call=1)
repeats_per_call=1,
cuda_graph_config=cuda_graph_config)
@skip_gpu_memory_less_than_40gb
@pytest.mark.part0
def test_llama_7b_multi_lora_read_from_cache_after_insert():
@test_lora_with_and_without_cuda_graph
def test_llama_7b_multi_lora_read_from_cache_after_insert(cuda_graph_config):
"""Test that loading and then using the same adapters loaded in cache works."""
_check_llama_7b_multi_lora_evict_load_new_adapters(
lora_adapter_count_per_call=[3],
max_loras=3,
max_cpu_loras=3,
repeat_calls=2,
repeats_per_call=1)
repeats_per_call=1,
cuda_graph_config=cuda_graph_config)
@skip_gpu_memory_less_than_40gb
@pytest.mark.part3
@test_lora_with_and_without_cuda_graph
def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_cache(
):
cuda_graph_config):
"""Test eviction, reloading new adapters and reloading previously evicted adapters from the LoRA CPU cache & GPU
cache over multiple llm.generate call repeated twice (two calls with the same requests):
At the end of the 1st llm.generate call:
@ -439,12 +504,14 @@ def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_ca
max_loras=2,
max_cpu_loras=2,
repeat_calls=2,
repeats_per_call=1)
repeats_per_call=1,
cuda_graph_config=cuda_graph_config)
@skip_gpu_memory_less_than_40gb
@pytest.mark.part2
def test_llama_7b_peft_cache_config_affects_peft_cache_size():
@test_lora_with_and_without_cuda_graph
def test_llama_7b_peft_cache_config_affects_peft_cache_size(cuda_graph_config):
"""Tests that LLM arg of peft_cache_config affects the peft cache sizes.
NOTE: The caller can't get the actual LoRA cache sizes, so we instead we
@ -464,9 +531,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size():
lora_config=lora_config_no_cache_size_values,
peft_cache_config=PeftCacheConfig(
host_cache_size=1), # size in bytes
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
cuda_graph_config=cuda_graph_config)
# Test that a too-small PeftCacheConfig.device_cache_percent causes a failure
with pytest.raises(RuntimeError):
@ -474,15 +539,14 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size():
LLM,
lora_config=lora_config_no_cache_size_values,
peft_cache_config=PeftCacheConfig(device_cache_percent=0.0000001),
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
cuda_graph_config=cuda_graph_config)
@skip_ray # https://nvbugs/5682551
@skip_gpu_memory_less_than_40gb
@pytest.mark.part1
def test_llama_7b_lora_config_overrides_peft_cache_config():
@test_lora_with_and_without_cuda_graph
def test_llama_7b_lora_config_overrides_peft_cache_config(cuda_graph_config):
"""Tests that cache size args in lora_config LLM arg override the cache size
parameters in peft_cache_config LLM arg.
""" # noqa: D205
@ -496,9 +560,7 @@ def test_llama_7b_lora_config_overrides_peft_cache_config():
peft_cache_config=PeftCacheConfig(
host_cache_size=1, # size in bytes
device_cache_percent=0.0000001),
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
cuda_graph_config=cuda_graph_config)
# TODO smor: currently, memory consumption of Nemotron-Super-49B-v1 with LoRA is overly high
@ -506,7 +568,8 @@ def test_llama_7b_lora_config_overrides_peft_cache_config():
@pytest.mark.skip(reason="https://nvbugs/5448464")
@skip_gpu_memory_less_than_138gb
@pytest.mark.part1
def test_nemotron_nas_lora() -> None:
@test_lora_with_and_without_cuda_graph
def test_nemotron_nas_lora(cuda_graph_config) -> None:
lora_config = LoraConfig(lora_dir=[
f"{llm_models_root()}/nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1-lora-adapter_r64"
],
@ -518,7 +581,7 @@ def test_nemotron_nas_lora() -> None:
model=
f"{llm_models_root()}/nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1",
lora_config=lora_config,
)
cuda_graph_config=cuda_graph_config)
prompts = [
"Hello, how are you?",
@ -539,7 +602,8 @@ def test_nemotron_nas_lora() -> None:
@skip_gpu_memory_less_than_80gb
@pytest.mark.part0
def test_llama_3_1_8b_fp8_with_bf16_lora() -> None:
@test_lora_with_and_without_cuda_graph
def test_llama_3_1_8b_fp8_with_bf16_lora(cuda_graph_config) -> None:
skip_fp8_pre_ada(use_fp8=True)
model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
lora_dir = f"{llm_models_root()}/lora/llama-3-chinese-8b-instruct-v2-lora"
@ -552,12 +616,9 @@ def test_llama_3_1_8b_fp8_with_bf16_lora() -> None:
max_cpu_loras=2)
lora_req = LoRARequest("lora-chinese", 0, lora_dir)
llm = LLM(
model_dir,
lora_config=lora_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
llm = LLM(model_dir,
lora_config=lora_config,
cuda_graph_config=cuda_graph_config)
try:
output = llm.generate(prompt,
@ -570,7 +631,8 @@ def test_llama_3_1_8b_fp8_with_bf16_lora() -> None:
@skip_gpu_memory_less_than_80gb
@pytest.mark.part2
def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
@test_lora_with_and_without_cuda_graph
def test_bielik_11b_v2_2_instruct_multi_lora(cuda_graph_config) -> None:
model_dir = f"{llm_models_root()}/Bielik-11B-v2.2-Instruct"
target_modules = ['attn_q', 'attn_k', 'attn_v']
@ -600,12 +662,9 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
max_lora_rank=8,
max_loras=2,
max_cpu_loras=2)
llm = LLM(
model_dir,
lora_config=trtllm_lora_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
llm = LLM(model_dir,
lora_config=trtllm_lora_config,
cuda_graph_config=cuda_graph_config)
prompts = [
"Kim był Mikołaj Kopernik i z czego zasłynął?",
@ -624,7 +683,8 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
@pytest.mark.part2
def test_gemma3_1b_instruct_multi_lora() -> None:
@test_lora_with_and_without_cuda_graph
def test_gemma3_1b_instruct_multi_lora(cuda_graph_config) -> None:
model_dir = f"{llm_models_root()}/gemma/gemma-3-1b-it"
target_modules = ['attn_q', 'attn_k', 'attn_v']
@ -662,7 +722,8 @@ def test_gemma3_1b_instruct_multi_lora() -> None:
)
llm = LLM(model_dir,
lora_config=trtllm_lora_config,
kv_cache_config=kv_cache_config)
kv_cache_config=kv_cache_config,
cuda_graph_config=cuda_graph_config)
prompts = [
"Is it ok to fill diesel in a petrol car?",
@ -745,7 +806,8 @@ def test_nemo_lora_unsupported_modules_validation(tmp_path):
@force_ampere
@pytest.mark.part1
def test_gqa_nemo_lora(tmp_path):
@test_lora_with_and_without_cuda_graph
def test_gqa_nemo_lora(tmp_path, cuda_graph_config):
"""
Test NeMo-format LoRA checkpoint loading and GQA support in TinyLlama.
@ -790,6 +852,7 @@ def test_gqa_nemo_lora(tmp_path):
model=model_path,
lora_config=lora_config,
kv_cache_config=global_kvcache_config,
cuda_graph_config=cuda_graph_config,
)
try:

View File

@ -32,3 +32,33 @@ def test_generate_api_docs_as_docstring():
doc = generate_api_docs_as_docstring(LlmArgs)
assert ":tag:`beta`" in doc, "the label is not generated"
print(doc)
class DelayedAssert:
def __init__(self, store_stack: bool = False):
self.assertions = []
self.store_stack = store_stack
def add(self, result: bool, msg: str):
import traceback
self.assertions.append(
(bool(result), str(msg), traceback.format_stack()))
def get_msg(self):
ret = []
for result, msg, stack in self.assertions:
ret.append('\n'.join([
f'Assert result: {result}', msg,
''.join(stack) if self.store_stack else ''
]))
ret = '\n-----------------------------------------\n'.join(ret)
return 'Some assertions failed:\n' + ret
def clear(self):
self.assertions.clear()
def assert_all(self):
assert all(ret[0] for ret in self.assertions), self.get_msg()
self.clear()
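A minimal usage sketch of this helper (made-up tensors and messages, for illustration only): checks are collected without aborting on the first failure, and assert_all() raises a single AssertionError with the aggregated message.

import torch

def _example_delayed_assert():  # hypothetical function, not part of the test suite
    asserter = DelayedAssert()
    a = torch.zeros(4)
    b = torch.ones(4)
    asserter.add(torch.equal(a, b), "a and b should match")      # fails
    asserter.add(bool((a == 0).all()), "a should be all zeros")  # passes
    # If any recorded check failed, raises AssertionError listing every
    # recorded check and its result.
    asserter.assert_all()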