Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-05 02:31:33 +08:00)
[https://nvbugs/5322131][feat] Multi-LoRA serving with CUDA Graph (#8279)
Signed-off-by: Jiayu Chang <jiayuc@nvidia.com>
parent: cdb9ffd0ab
commit: 1dc49b266e
@ -57,7 +57,10 @@ class BasePeftCacheManager
|
||||
public:
|
||||
using LlmRequestPtr = std::shared_ptr<LlmRequest>;
|
||||
using RequestVector = std::vector<LlmRequestPtr>;
|
||||
using PeftTable = std::map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
|
||||
using PeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
|
||||
using TaskPeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
|
||||
using TaskIdToReqIds = std::unordered_map<uint64_t, std::vector<uint64_t>>;
|
||||
using EnsureBatchTaskResult = std::tuple<TaskPeftTable, TaskIdToReqIds>;
|
||||
|
||||
virtual ~BasePeftCacheManager() = default;
|
||||
|
||||
@ -99,6 +102,8 @@ public:
|
||||
class PeftCacheManager : public BasePeftCacheManager
|
||||
{
|
||||
public:
|
||||
using EnsureBatchTaskResult = BasePeftCacheManager::EnsureBatchTaskResult;
|
||||
|
||||
PeftCacheManager(PeftCacheManagerConfig const& config, runtime::ModelConfig const& modelConfig,
|
||||
runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
|
||||
|
||||
@ -109,12 +114,17 @@ public:
|
||||
PeftTable ensureBatch(RequestVector const& contextRequests, RequestVector const& generationRequests,
|
||||
bool resetGpuCache = false) override;
|
||||
|
||||
EnsureBatchTaskResult ensureBatchMapTaskId(
|
||||
RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache = false);
|
||||
|
||||
[[nodiscard]] bool isTaskCached(uint64_t taskId) const;
|
||||
|
||||
[[nodiscard]] bool isTaskDone(uint64_t taskId) const;
|
||||
|
||||
[[nodiscard]] bool isTaskDoneDevice(uint64_t taskId) const;
|
||||
|
||||
[[nodiscard]] bool isTaskCachedDevice(uint64_t const taskId) const;
|
||||
|
||||
void resetDeviceCache() override;
|
||||
|
||||
void markRequestDone(LlmRequest const& llmReq, bool pause = false) override;
|
||||
@ -159,7 +169,7 @@ private:
|
||||
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToReqIds;
|
||||
std::unordered_map<uint64_t, std::unordered_set<uint64_t>> mTaskIdToPausedReqIds;
|
||||
|
||||
std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>> getTaskMaps(
|
||||
std::tuple<std::unordered_map<uint64_t, std::future<void>>, TaskIdToReqIds> getTaskMaps(
|
||||
RequestVector const& contextRequests, RequestVector const& generationRequests);
|
||||
|
||||
runtime::ModelConfig mModelConfig;
|
||||
|
||||
@ -373,11 +373,11 @@ void PeftCacheManager::addRequestPeft(std::shared_ptr<LlmRequest> llmRequest, bo
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
std::tuple<std::map<uint64_t, std::future<void>>, std::map<uint64_t, std::vector<uint64_t>>>
|
||||
std::tuple<std::unordered_map<uint64_t, std::future<void>>, BasePeftCacheManager::TaskIdToReqIds>
|
||||
PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVector const& generationRequests)
|
||||
{
|
||||
std::map<uint64_t, std::vector<uint64_t>> taskIdToReqIds;
|
||||
std::map<uint64_t, std::future<void>> taskIdToFuture;
|
||||
TaskIdToReqIds taskIdToReqIds;
|
||||
std::unordered_map<uint64_t, std::future<void>> taskIdToFuture;
|
||||
std::lock_guard<std::mutex> futuresLock(mPutFuturesMutex);
|
||||
for (auto const& requests : {contextRequests, generationRequests})
|
||||
{
|
||||
@ -415,7 +415,7 @@ PeftCacheManager::getTaskMaps(RequestVector const& contextRequests, RequestVecto
|
||||
return {std::move(taskIdToFuture), taskIdToReqIds};
|
||||
}
|
||||
|
||||
PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
|
||||
PeftCacheManager::EnsureBatchTaskResult PeftCacheManager::ensureBatchMapTaskId(
|
||||
RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
@ -426,7 +426,7 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
|
||||
auto [taskIdToFuture_, taskIdToReqIds] = getTaskMaps(contextRequests, generationRequests);
|
||||
auto taskIdToFuture = std::move(taskIdToFuture_); // captured structured bindings are a C++20 extension
|
||||
|
||||
std::map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
|
||||
std::unordered_map<uint64_t, std::future<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>> ensureFutures;
|
||||
for (auto& [taskId, taskFuture] : taskIdToFuture)
|
||||
{
|
||||
auto fn = [&taskIdToFuture, taskId = taskId, this]() -> std::vector<runtime::LoraCache::TaskLayerModuleConfig>
|
||||
@ -457,18 +457,31 @@ PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
|
||||
ensureFutures.try_emplace(taskId, std::move(f));
|
||||
}
|
||||
|
||||
PeftTable peftTable{};
|
||||
TaskPeftTable peftTable{};
|
||||
for (auto const& [taskId, reqIds] : taskIdToReqIds)
|
||||
{
|
||||
auto&& f = ensureFutures.at(taskId);
|
||||
auto const values = f.get();
|
||||
for (auto const& reqId : reqIds)
|
||||
{
|
||||
peftTable.try_emplace(reqId, values);
|
||||
}
|
||||
peftTable.try_emplace(taskId, values);
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
return peftTable;
|
||||
return {std::move(peftTable), std::move(taskIdToReqIds)};
|
||||
}
|
||||
|
||||
PeftCacheManager::PeftTable PeftCacheManager::ensureBatch(
|
||||
RequestVector const& contextRequests, RequestVector const& generationRequests, bool resetGpuCache)
|
||||
{
|
||||
auto [taskTable, taskIdToReqIds] = ensureBatchMapTaskId(contextRequests, generationRequests, resetGpuCache);
|
||||
PeftTable requestTable{};
|
||||
for (auto const& [taskId, values] : taskTable)
|
||||
{
|
||||
auto const& reqIds = taskIdToReqIds.at(taskId);
|
||||
for (auto const reqId : reqIds)
|
||||
{
|
||||
requestTable.try_emplace(reqId, values);
|
||||
}
|
||||
}
|
||||
return requestTable;
|
||||
}
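
The task-keyed result lets a caller that deduplicates adapter state by task (for example, a Python-side slot manager) touch each LoRA task once instead of once per request. A minimal consumption sketch, not part of the commit; uploadTaskWeights and assignRequestToTask are hypothetical callbacks and the include path is assumed:

#include "tensorrt_llm/batch_manager/peftCacheManager.h"

template <typename UploadFn, typename AssignFn>
void consumeTaskTable(tensorrt_llm::batch_manager::PeftCacheManager& peftCacheManager,
    tensorrt_llm::batch_manager::BasePeftCacheManager::RequestVector const& contextRequests,
    tensorrt_llm::batch_manager::BasePeftCacheManager::RequestVector const& generationRequests,
    UploadFn&& uploadTaskWeights, AssignFn&& assignRequestToTask)
{
    auto [taskTable, taskIdToReqIds] = peftCacheManager.ensureBatchMapTaskId(contextRequests, generationRequests);
    for (auto const& [taskId, layerConfigs] : taskTable)
    {
        // Upload each task's TaskLayerModuleConfig list once, however many requests share it.
        uploadTaskWeights(taskId, layerConfigs);
        for (auto const reqId : taskIdToReqIds.at(taskId))
        {
            // Point every request in the batch at the task whose weights it should use.
            assignRequestToTask(reqId, taskId);
        }
    }
}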
|
||||
|
||||
bool PeftCacheManager::isTaskCached(uint64_t taskId) const
|
||||
@ -486,6 +499,11 @@ bool PeftCacheManager::isTaskDoneDevice(uint64_t taskId) const
|
||||
return mDeviceLoraCache->isDone(taskId);
|
||||
}
|
||||
|
||||
bool PeftCacheManager::isTaskCachedDevice(uint64_t const taskId) const
|
||||
{
|
||||
return mDeviceLoraCache->has(taskId);
|
||||
}
|
||||
|
||||
void PeftCacheManager::updateTaskState(uint64_t taskId, uint64_t reqId, bool terminate, bool pause)
|
||||
{
|
||||
if (!terminate)
|
||||
@ -645,3 +663,5 @@ SizeType32 NoOpPeftCacheManager::determineNumPages(std::shared_ptr<LlmRequest> l
|
||||
return 0;
|
||||
}
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
// TODO: merge C++ LoRA caching status with Py Slot manager
|
||||
|
||||
cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.cu (new file, 372 lines)
@ -0,0 +1,372 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cuda_graph_grouped_gemm.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/device/gemm_grouped.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
||||
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h"
|
||||
#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h"
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
/**
|
||||
* Template for CUDA Graph compatible grouped GEMM that directly uses GPU tensors
|
||||
*/
|
||||
template <int M1, int N1, int K1, int M2, int N2, int K2, typename cutlassType, int kAlignmentAB, int kAlignmentC,
|
||||
int kStages>
|
||||
void cudaGraphGroupedGemmTemplate(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
|
||||
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
|
||||
cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream)
|
||||
{
|
||||
using ElementA = cutlassType;
|
||||
using ElementB = cutlassType;
|
||||
using ElementOutput = cutlassType;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<ElementA, LayoutA,
|
||||
cutlass::ComplexTransform::kNone, kAlignmentAB, ElementB, LayoutB, cutlass::ComplexTransform::kNone,
|
||||
kAlignmentAB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<M1, N1, K1>, cutlass::gemm::GemmShape<M2, N2, K2>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<ElementOutput, kAlignmentC, ElementAccumulator,
|
||||
ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, kStages,
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>::GemmKernel;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
typename Gemm::EpilogueOutputOp::Params epilogueOp(alpha, beta);
|
||||
|
||||
auto ptrA = reinterpret_cast<ElementA**>(ptrAGpu);
|
||||
auto ptrB = reinterpret_cast<ElementB**>(ptrBGpu);
|
||||
auto ptrC = reinterpret_cast<ElementOutput**>(ptrCGpu);
|
||||
auto ptrD = reinterpret_cast<ElementOutput**>(ptrDGpu);
|
||||
|
||||
Gemm gemmOp;
|
||||
|
||||
int threadblockCount = Gemm::sufficient(nullptr, problemCount);
|
||||
|
||||
typename Gemm::Arguments args(problemSizesPtr, // GPU problem sizes
|
||||
problemCount, // Problem count
|
||||
threadblockCount, // Threadblock count
|
||||
epilogueOp, // Epilogue operation
|
||||
ptrA, // GPU pointer array A
|
||||
ptrB, // GPU pointer array B
|
||||
ptrC, // GPU pointer array C (can be nullptr)
|
||||
ptrD, // GPU pointer array D
|
||||
ldaGpu, // Precomputed leading dimension A (on GPU)
|
||||
ldbGpu, // Precomputed leading dimension B (on GPU)
|
||||
ldcGpu, // Precomputed leading dimension C (on GPU)
|
||||
lddGpu, // Precomputed leading dimension D (on GPU)
|
||||
hostMaxProblemSizesPtr);
|
||||
|
||||
static_assert(Gemm::BaseKernel::ProblemVisitor::kRequiresPrecomputation == false,
    "Grouped GEMM with CUDA Graph cannot use precomputation.");
{
|
||||
cutlass::Status status = gemmOp.can_implement(args);
|
||||
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
|
||||
"Grouped GEMM cannot be implemented with the given arguments, Error: %s",
|
||||
cutlass::cutlassGetStatusString(status));
|
||||
}
|
||||
|
||||
at::Tensor workspace;
|
||||
void* gemmWorkspace = nullptr;
|
||||
size_t const requiredWorkspace = gemmOp.get_workspace_size(args);
|
||||
if (requiredWorkspace > 0)
|
||||
{
|
||||
auto const workspaceTensorOptions = at::TensorOptions().dtype(at::kByte).device(at::kCUDA);
|
||||
workspace = at::empty({static_cast<int64_t>(requiredWorkspace)}, workspaceTensorOptions);
|
||||
gemmWorkspace = workspace.data_ptr();
|
||||
}
|
||||
|
||||
cutlass::Status status = gemmOp.initialize(args, gemmWorkspace);
|
||||
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to initialize grouped GEMM");
|
||||
|
||||
status = gemmOp.run(stream);
|
||||
sync_check_cuda_error(stream);
|
||||
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess, "Failed to execute grouped GEMM");
|
||||
}
|
||||
|
||||
template <int M1, int N1, int K1, int M2, int N2, int K2, int kAlignmentAB, int kAlignmentC, int kStages>
|
||||
void cudaGraphGroupedGemmType(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
|
||||
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
|
||||
nvinfer1::DataType dataType, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream)
|
||||
{
|
||||
if (dataType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
cudaGraphGroupedGemmTemplate<M1, N1, K1, M2, N2, K2, cutlass::half_t, kAlignmentAB, kAlignmentC, kStages>(
|
||||
problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu,
|
||||
hostMaxProblemSizesPtr, stream);
|
||||
}
|
||||
#ifdef ENABLE_BF16
|
||||
else if (dataType == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
cudaGraphGroupedGemmTemplate<M1, N1, K1, M2, N2, K2, cutlass::bfloat16_t, kAlignmentAB, kAlignmentC, kStages>(
|
||||
problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu,
|
||||
hostMaxProblemSizesPtr, stream);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Unsupported data type for CUDA Graph grouped GEMM");
|
||||
}
|
||||
}
|
||||
|
||||
void cudaGraphGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, void** ptrBGpu,
|
||||
void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, bool isLoraIn,
|
||||
nvinfer1::DataType dataType, int minKN, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream)
|
||||
{
|
||||
if (isLoraIn)
|
||||
{
|
||||
if (minKN >= 8)
|
||||
{
|
||||
cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
|
||||
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
|
||||
}
|
||||
else if (minKN >= 4)
|
||||
{
|
||||
cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 4, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
|
||||
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
|
||||
}
|
||||
else if (minKN >= 2)
|
||||
{
|
||||
cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 2, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
|
||||
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
|
||||
}
|
||||
else if (minKN >= 1)
|
||||
{
|
||||
cudaGraphGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 1, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
|
||||
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (minKN >= 8)
|
||||
{
|
||||
cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
|
||||
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
|
||||
}
|
||||
else if (minKN >= 4)
|
||||
{
|
||||
cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 4, 8, 4>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
|
||||
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
|
||||
}
|
||||
else if (minKN >= 2)
|
||||
{
|
||||
cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 2, 8, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
|
||||
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
cudaGraphGroupedGemmType<32, 128, 32, 32, 32, 32, 1, 8, 2>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu,
|
||||
ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, hostMaxProblemSizesPtr, stream);
|
||||
}
|
||||
}
|
||||
}
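
// Dispatch summary (inferred from the template arguments above, not stated in the commit): the LoRA-in GEMMs use a
// 16x32x64 threadblock tile and vary the C/D alignment with minKN, since their output width is the (possibly tiny)
// adapter rank; the LoRA-out GEMMs use a 32x128x32 threadblock tile and instead vary the A/B alignment, since there
// the rank is the K dimension. Smaller minKN values also reduce the pipeline depth from 4 stages to 2.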
|
||||
|
||||
/**
|
||||
* Template for CUDA Graph compatible split-K grouped GEMM
|
||||
*/
|
||||
template <int M1, int N1, int K1, int M2, int N2, int K2, typename cutlassType, int kAlignmentAB, int kAlignmentC,
|
||||
int kStages>
|
||||
void cudaGraphSplitKGroupedGemmTemplate(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
|
||||
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
|
||||
int splitKSlices, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, int64_t* splitKOffsetsGpu, cudaStream_t stream)
|
||||
{
|
||||
using ElementA = cutlassType;
|
||||
using ElementB = cutlassType;
|
||||
using ElementOutput = cutlassType;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
using GemmKernel = typename cutlass::gemm::kernel::DefaultSplitkGemmGrouped<ElementA, LayoutA,
|
||||
cutlass::ComplexTransform::kNone, kAlignmentAB, ElementB, LayoutB, cutlass::ComplexTransform::kNone,
|
||||
kAlignmentAB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
|
||||
cutlass::gemm::GemmShape<M1, N1, K1>, cutlass::gemm::GemmShape<M2, N2, K2>, cutlass::gemm::GemmShape<16, 8, 16>,
|
||||
cutlass::epilogue::thread::LinearCombination<ElementOutput, kAlignmentC, ElementAccumulator,
|
||||
ElementAccumulator>,
|
||||
cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, kStages,
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>::GemmKernel;
|
||||
|
||||
using Gemm = cutlass::gemm::device::SplitkGemmGrouped<GemmKernel>;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
typename Gemm::EpilogueOutputOp::Params epilogueOp(alpha, beta);
|
||||
|
||||
auto ptrA = reinterpret_cast<ElementA**>(ptrAGpu);
|
||||
auto ptrB = reinterpret_cast<ElementB**>(ptrBGpu);
|
||||
auto ptrC = reinterpret_cast<ElementOutput**>(ptrCGpu);
|
||||
auto ptrD = reinterpret_cast<ElementOutput**>(ptrDGpu);
|
||||
|
||||
Gemm gemmOp;
|
||||
|
||||
int threadblockCount = Gemm::sufficient(nullptr, problemCount);
|
||||
|
||||
// Set up arguments for the split-K grouped GEMM, using precomputed leading dimensions held in GPU tensors
|
||||
typename Gemm::Arguments args(problemSizesPtr, // GPU problem sizes
|
||||
problemCount, // Problem count
|
||||
threadblockCount, // Threadblock count
|
||||
epilogueOp, // Epilogue operation
|
||||
ptrA, // GPU pointer array A
|
||||
ptrB, // GPU pointer array B
|
||||
ptrC, // GPU pointer array C
|
||||
ptrD, // GPU pointer array D
|
||||
ldaGpu, // Precomputed leading dimension A (on GPU)
|
||||
ldbGpu, // Precomputed leading dimension B (on GPU)
|
||||
ldcGpu, // Precomputed leading dimension C (on GPU)
|
||||
lddGpu, // Precomputed leading dimension D (on GPU)
|
||||
hostMaxProblemSizesPtr, // Host problem sizes
|
||||
splitKSlices, // Split-K factor
|
||||
splitKOffsetsGpu);
|
||||
|
||||
{
|
||||
cutlass::Status status = gemmOp.can_implement(args);
|
||||
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
|
||||
"Split-K grouped GEMM cannot be implemented with the given arguments. Problem count: %d, Split-K slices: "
|
||||
"%d, Error: %s",
|
||||
problemCount, splitKSlices, cutlass::cutlassGetStatusString(status));
|
||||
}
|
||||
|
||||
at::Tensor workspace;
|
||||
void* gemmWorkspace = nullptr;
|
||||
size_t const requiredWorkspace = gemmOp.get_workspace_size(args);
|
||||
if (requiredWorkspace > 0)
|
||||
{
|
||||
workspace = at::empty(
|
||||
{static_cast<int64_t>(requiredWorkspace)}, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
gemmWorkspace = workspace.data_ptr();
|
||||
}
|
||||
|
||||
cutlass::Status status = gemmOp.initialize(args, gemmWorkspace);
|
||||
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
|
||||
"Failed to initialize split-K grouped GEMM. Problem count: %d, Split-K slices: %d, Error: %s", problemCount,
|
||||
splitKSlices, cutlass::cutlassGetStatusString(status));
|
||||
|
||||
status = gemmOp.run(stream);
|
||||
sync_check_cuda_error(stream);
|
||||
TLLM_CHECK_WITH_INFO(status == cutlass::Status::kSuccess,
|
||||
"Failed to execute split-K grouped GEMM. Problem count: %d, Split-K slices: %d, Error: %s", problemCount,
|
||||
splitKSlices, cutlass::cutlassGetStatusString(status));
|
||||
}
|
||||
|
||||
template <int M1, int N1, int K1, int M2, int N2, int K2, int kAlignmentAB, int kAlignmentC, int kStages>
|
||||
void cudaGraphSplitKGroupedGemmType(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
|
||||
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
|
||||
nvinfer1::DataType dataType, int splitKSlices, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr,
|
||||
int64_t* splitKOffsetsGpu, cudaStream_t stream)
|
||||
{
|
||||
if (dataType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
cudaGraphSplitKGroupedGemmTemplate<M1, N1, K1, M2, N2, K2, cutlass::half_t, kAlignmentAB, kAlignmentC, kStages>(
|
||||
problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu,
|
||||
splitKSlices, hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
|
||||
}
|
||||
#ifdef ENABLE_BF16
|
||||
else if (dataType == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
cudaGraphSplitKGroupedGemmTemplate<M1, N1, K1, M2, N2, K2, cutlass::bfloat16_t, kAlignmentAB, kAlignmentC,
|
||||
kStages>(problemSizesPtr, problemCount, ptrAGpu, ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu,
|
||||
splitKSlices, hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Unsupported data type for CUDA Graph split-K grouped GEMM");
|
||||
}
|
||||
}
|
||||
|
||||
void cudaGraphSplitKGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
|
||||
void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
|
||||
bool isLoraIn, nvinfer1::DataType dataType, int splitKSlices, int minKN,
|
||||
cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, int64_t* splitKOffsetsGpu, cudaStream_t stream)
|
||||
{
|
||||
if (isLoraIn)
|
||||
{
|
||||
if (minKN >= 8)
|
||||
{
|
||||
cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu,
|
||||
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
|
||||
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
|
||||
}
|
||||
else if (minKN >= 4)
|
||||
{
|
||||
cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 4, 4>(problemSizesPtr, problemCount, ptrAGpu,
|
||||
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
|
||||
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
|
||||
}
|
||||
else if (minKN >= 2)
|
||||
{
|
||||
cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 2, 2>(problemSizesPtr, problemCount, ptrAGpu,
|
||||
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
|
||||
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
|
||||
}
|
||||
else if (minKN >= 1)
|
||||
{
|
||||
cudaGraphSplitKGroupedGemmType<16, 32, 64, 16, 32, 64, 8, 1, 2>(problemSizesPtr, problemCount, ptrAGpu,
|
||||
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
|
||||
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (minKN >= 8)
|
||||
{
|
||||
cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 8, 8, 4>(problemSizesPtr, problemCount, ptrAGpu,
|
||||
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
|
||||
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
|
||||
}
|
||||
else if (minKN >= 4)
|
||||
{
|
||||
cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 4, 8, 4>(problemSizesPtr, problemCount, ptrAGpu,
|
||||
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
|
||||
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
|
||||
}
|
||||
else if (minKN >= 2)
|
||||
{
|
||||
cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 2, 8, 2>(problemSizesPtr, problemCount, ptrAGpu,
|
||||
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
|
||||
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
cudaGraphSplitKGroupedGemmType<32, 128, 32, 32, 32, 32, 1, 8, 2>(problemSizesPtr, problemCount, ptrAGpu,
|
||||
ptrBGpu, ptrCGpu, ptrDGpu, ldaGpu, ldbGpu, ldcGpu, lddGpu, dataType, splitKSlices,
|
||||
hostMaxProblemSizesPtr, splitKOffsetsGpu, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
cpp/tensorrt_llm/kernels/cuda_graph_grouped_gemm.h (new file, 63 lines)
@ -0,0 +1,63 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "tensorrt_llm/common/config.h"
|
||||
#include <NvInferRuntime.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
/**
 * @brief CUDA Graph compatible wrapper for grouped GEMM operations.
 *
 * This function accepts GPU pointers directly, without any host-side parameter workspace,
 * making it fully compatible with CUDA Graph capture and replay.
 *
 * @param problemSizesPtr GPU pointer to array of cutlass::gemm::GemmCoord
 * @param problemCount Number of GEMM problems
 * @param ptrAGpu GPU pointer to array of A matrix pointers
 * @param ptrBGpu GPU pointer to array of B matrix pointers
 * @param ptrCGpu GPU pointer to array of C matrix pointers (can be nullptr)
 * @param ptrDGpu GPU pointer to array of D matrix pointers
 * @param ldaGpu GPU pointer to array of precomputed leading dimensions of A
 * @param ldbGpu GPU pointer to array of precomputed leading dimensions of B
 * @param ldcGpu GPU pointer to array of precomputed leading dimensions of C
 * @param lddGpu GPU pointer to array of precomputed leading dimensions of D
 * @param isLoraIn Whether this is for the LoRA input transformation
 * @param dataType Data type of the matrices
 * @param minKN Minimum K*N value for kernel selection
 * @param hostMaxProblemSizesPtr Host pointer to the maximum problem sizes
 * @param stream CUDA stream
 */
void cudaGraphGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu, void** ptrBGpu,
    void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu, bool isLoraIn,
    nvinfer1::DataType dataType, int minKN, cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, cudaStream_t stream);
|
||||
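
A minimal host-side sketch of how the device-resident argument arrays might be handed to this wrapper. It is illustrative only: the ATen tensors and treating hostMaxProblemSizes as one entry per problem are assumptions, while reusing the D pointer array as the C array (beta is 0) mirrors the thop call site later in this commit.

#include "tensorrt_llm/kernels/cuda_graph_grouped_gemm.h"

#include <ATen/ATen.h>
#include <vector>

// problemSizes: int32 [problemCount, 3] on the GPU; ptrA/ptrB/ptrD and lda/ldb/ldd: int64 [problemCount] on the GPU,
// already filled with device addresses and leading dimensions (e.g. by a fill kernel, so replay stays graph-safe).
void runGroupedGemmOnce(at::Tensor const& problemSizes, at::Tensor const& ptrA, at::Tensor const& ptrB,
    at::Tensor const& ptrD, at::Tensor const& lda, at::Tensor const& ldb, at::Tensor const& ldd,
    std::vector<cutlass::gemm::GemmCoord>& hostMaxProblemSizes, int problemCount, cudaStream_t stream)
{
    tensorrt_llm::kernels::cudaGraphGroupedGemm(
        reinterpret_cast<cutlass::gemm::GemmCoord*>(problemSizes.data_ptr()), problemCount,
        reinterpret_cast<void**>(ptrA.data_ptr()), reinterpret_cast<void**>(ptrB.data_ptr()),
        reinterpret_cast<void**>(ptrD.data_ptr()), // C pointer array: reuse D, since beta == 0 in the epilogue
        reinterpret_cast<void**>(ptrD.data_ptr()),
        reinterpret_cast<int64_t*>(lda.data_ptr()), reinterpret_cast<int64_t*>(ldb.data_ptr()),
        reinterpret_cast<int64_t*>(ldd.data_ptr()), reinterpret_cast<int64_t*>(ldd.data_ptr()),
        /*isLoraIn=*/true, nvinfer1::DataType::kHALF, /*minKN=*/8, hostMaxProblemSizes.data(), stream);
}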
|
||||
/**
 * @brief CUDA Graph compatible wrapper for split-K grouped GEMM operations.
 *
 * Similar to cudaGraphGroupedGemm, but uses the split-K algorithm for better
 * performance on certain problem sizes. No parameter workspace is needed.
 */
void cudaGraphSplitKGroupedGemm(cutlass::gemm::GemmCoord* problemSizesPtr, int problemCount, void** ptrAGpu,
    void** ptrBGpu, void** ptrCGpu, void** ptrDGpu, int64_t* ldaGpu, int64_t* ldbGpu, int64_t* ldcGpu, int64_t* lddGpu,
    bool isLoraIn, nvinfer1::DataType dataType, int splitKSlices, int minKN,
    cutlass::gemm::GemmCoord* hostMaxProblemSizesPtr, int64_t* splitKOffsetsGpu, cudaStream_t stream);
|
||||
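
For reference, the splitKOffsetsGpu array consumed here is produced on the device by the fused parameter-fill kernel added later in this commit, via an exclusive prefix sum over the active problems. A host-side sketch of what that convention appears to be, where each active problem contributes its first-GEMM output size M * N; this is an assumption read off the SPLITK_OFFSETS_INDEX branch of that kernel, not a documented contract:

#include <cstdint>
#include <vector>

std::vector<int64_t> computeSplitKOffsets(
    std::vector<int64_t> const& problemM, std::vector<int64_t> const& problemN, std::vector<bool> const& active)
{
    std::vector<int64_t> offsets(problemM.size(), 0);
    int64_t running = 0;
    for (size_t i = 0; i < problemM.size(); ++i)
    {
        offsets[i] = running; // exclusive prefix sum
        if (active[i])
        {
            running += problemM[i] * problemN[i];
        }
    }
    return offsets;
}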
|
||||
} // namespace kernels
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
@ -299,7 +299,7 @@ int LoraImpl::run(int64_t numTokens, int64_t numReqs, void const* input, int32_t
|
||||
+ (loraModuleIdx * numTokens * mMaxLowRank + handled_token_num * mMaxLowRank) * typeSize));
|
||||
|
||||
auto const N2 = mOutHiddenSizes[loraModuleIdx];
|
||||
cutlass::gemm::GemmCoord problem_2(M, N2, N);
|
||||
cutlass::gemm::GemmCoord problem_2(M, N2, N); // token_num, module_output_size, lora_rank
|
||||
problem_sizes_2.push_back(problem_2);
|
||||
ptrA_2.push_back(static_cast<void*>(static_cast<char*>(lowRankWorkSpace)
|
||||
+ (loraModuleIdx * numTokens * mMaxLowRank + handled_token_num * mMaxLowRank) * typeSize));
|
||||
|
||||
@ -0,0 +1,408 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "loraGroupGEMMParamFillRowReorderFusion.h"
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace kernels
|
||||
{
|
||||
namespace
|
||||
{
|
||||
|
||||
template <class T>
|
||||
__forceinline__ constexpr T (&as_singleton_array(T& obj))[1]
|
||||
{
|
||||
return reinterpret_cast<T(&)[1]>(obj);
|
||||
}
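
// The helper above reinterprets a single scalar as a one-element array so that cub::BlockLoad / cub::BlockStore,
// which expect a per-thread item array, can be used with exactly one item per thread.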
|
||||
|
||||
enum ParamIndex
|
||||
{
|
||||
IN_SIZES_INDEX,
|
||||
OUT_SIZES_INDEX,
|
||||
LDA_INDEX,
|
||||
LDD_INDEX,
|
||||
LDB_PRIME_INDEX,
|
||||
LDD_PRIME_INDEX,
|
||||
A_PTRS_INDEX,
|
||||
D_PTRS_INDEX,
|
||||
D_PRIME_PTRS_INDEX,
|
||||
SPLITK_OFFSETS_INDEX,
|
||||
PARAM_COUNT,
|
||||
};
|
||||
|
||||
int constexpr VECTOR_LOAD_WIDTH = 16;
|
||||
} // namespace
|
||||
|
||||
/**
 * Fused kernel for LoRA group GEMM parameter filling, row gathering, and zero filling.
 * It must be launched with at least PARAM_COUNT blocks, and the total number of threads must be large enough to
 * reorder `input` and zero-fill the intermediate and output buffers. Specifically, n_threads >=
 * max(input_bytes_to_reorder, intermediate_bytes_to_fill, output_bytes_to_fill) / VECTOR_LOAD_WIDTH.
 *
 * Template parameters:
 * - BlockDim: Number of threads per block (1D, >= max_lora_count * module_count, >= 256, divisible by 32)
 * - MODULE_COUNT: Number of modules per layer
 */
template <int BlockDim, int MODULE_COUNT>
|
||||
__global__ void loraGroupGEMMParamFillRowReorderFusionKernel(
|
||||
// Output parameters
|
||||
int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs, int64_t* d_ptrs, int64_t* d_prime_ptrs, int64_t* lda,
|
||||
int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime, int64_t* splitk_offsets, uint8_t* reordered_input,
|
||||
// Input parameters
|
||||
int32_t max_lora_count, int32_t max_lora_rank, int32_t sum_output_hidden_size, int32_t input_hidden_size,
|
||||
int64_t dtype_element_size, int64_t batch_size, int64_t a_base, int64_t d_base, int64_t d_prime_base,
|
||||
int32_t const* slot_counts, int32_t const* slot_ranks, int64_t const* slot_offsets, int32_t const* module_out_sizes,
|
||||
int64_t const* module_out_prefix, int64_t const* b_ptrs, int64_t const* b_prime_ptrs, uint8_t const* input,
|
||||
int64_t const* sorted_ids)
|
||||
{
|
||||
int const linearIdx = threadIdx.x;
|
||||
int const blockLinearIdx = blockIdx.x + blockIdx.y * gridDim.x;
|
||||
int constexpr THREADS_PER_BLOCK = BlockDim;
|
||||
|
||||
// Calculate lora_id and module_id from linearIdx
|
||||
int const lora_id = linearIdx % max_lora_count;
|
||||
int const module_id = linearIdx / max_lora_count;
|
||||
|
||||
using BlockLoad = cub::BlockLoad<int32_t, BlockDim, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStore = cub::BlockStore<int32_t, BlockDim, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
using BlockLoad64 = cub::BlockLoad<int64_t, BlockDim, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStore64 = cub::BlockStore<int64_t, BlockDim, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
using BlockScan = cub::BlockScan<int64_t, BlockDim, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
using BlockStore3 = cub::BlockStore<int32_t, BlockDim, 3, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
|
||||
__shared__ union
|
||||
{
|
||||
typename BlockStore3::TempStorage storage3;
|
||||
typename BlockScan::TempStorage scan;
|
||||
} large_shared;
|
||||
|
||||
__shared__ int32_t
|
||||
row_count_to_gather; // rows > row_count_to_gather belong to non-LoRA requests, write zero directly
|
||||
|
||||
switch (blockLinearIdx)
|
||||
{
|
||||
case IN_SIZES_INDEX:
|
||||
{
|
||||
int32_t slot_count = slot_counts[lora_id];
|
||||
int32_t rank = slot_ranks[lora_id];
|
||||
int64_t b_ptr = b_ptrs[linearIdx % (max_lora_count * MODULE_COUNT)];
|
||||
int32_t row[3] = {0};
|
||||
if (b_ptr != 0)
|
||||
{
|
||||
row[0] = slot_count;
|
||||
row[1] = rank;
|
||||
row[2] = input_hidden_size;
|
||||
}
|
||||
BlockStore3(large_shared.storage3).Store(in_sizes, row, max_lora_count * MODULE_COUNT * 3);
|
||||
}
|
||||
break;
|
||||
case OUT_SIZES_INDEX:
|
||||
{
|
||||
int32_t slot_count = slot_counts[lora_id];
|
||||
int32_t output_hidden_size = module_out_sizes[module_id];
|
||||
int32_t rank = slot_ranks[lora_id];
|
||||
int64_t b_ptr = b_ptrs[linearIdx % (max_lora_count * MODULE_COUNT)];
|
||||
int32_t row[3] = {0};
|
||||
if (b_ptr != 0)
|
||||
{
|
||||
row[0] = slot_count;
|
||||
row[1] = output_hidden_size;
|
||||
row[2] = rank;
|
||||
}
|
||||
BlockStore3(large_shared.storage3).Store(out_sizes, row, max_lora_count * MODULE_COUNT * 3);
|
||||
}
|
||||
break;
|
||||
case LDA_INDEX:
|
||||
{
|
||||
int64_t input_hidden_size_64[1] = {static_cast<int64_t>(input_hidden_size)};
|
||||
BlockStore64().Store(lda, input_hidden_size_64, max_lora_count * MODULE_COUNT);
|
||||
}
|
||||
break;
|
||||
case LDD_INDEX:
|
||||
{
|
||||
int64_t max_lora_rank_64[1] = {static_cast<int64_t>(max_lora_rank)};
|
||||
BlockStore64().Store(ldd, max_lora_rank_64, max_lora_count * MODULE_COUNT);
|
||||
}
|
||||
break;
|
||||
case LDB_PRIME_INDEX:
|
||||
{
|
||||
int64_t rank = slot_ranks[lora_id];
|
||||
BlockStore64().Store(ldb_prime, as_singleton_array(rank), max_lora_count * MODULE_COUNT);
|
||||
}
|
||||
break;
|
||||
case LDD_PRIME_INDEX:
|
||||
{
|
||||
int64_t sum_output_hidden_size_64[1] = {static_cast<int64_t>(sum_output_hidden_size)};
|
||||
BlockStore64().Store(ldd_prime, sum_output_hidden_size_64, max_lora_count * MODULE_COUNT);
|
||||
}
|
||||
break;
|
||||
case A_PTRS_INDEX:
|
||||
{
|
||||
int64_t slot_offset = 0;
|
||||
BlockLoad64().Load(slot_offsets, as_singleton_array(slot_offset), max_lora_count + 1);
|
||||
if (linearIdx == max_lora_count)
|
||||
{
|
||||
row_count_to_gather = static_cast<int32_t>(slot_offset);
|
||||
}
|
||||
slot_offset *= input_hidden_size;
|
||||
slot_offset *= dtype_element_size;
|
||||
slot_offset += a_base;
|
||||
for (int i = 0; i < MODULE_COUNT; i++)
|
||||
{
|
||||
BlockStore64().Store(a_ptrs + i * max_lora_count, as_singleton_array(slot_offset), max_lora_count);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case D_PTRS_INDEX:
|
||||
{
|
||||
int64_t slot_offset = 0;
|
||||
BlockLoad64().Load(slot_offsets, as_singleton_array(slot_offset), max_lora_count + 1);
|
||||
for (int i = 0; i < MODULE_COUNT; i++)
|
||||
{
|
||||
int64_t offset = slot_offset;
|
||||
offset += i * batch_size;
|
||||
offset *= max_lora_rank;
|
||||
offset *= dtype_element_size;
|
||||
offset += d_base;
|
||||
BlockStore64().Store(d_ptrs + i * max_lora_count, as_singleton_array(offset), max_lora_count);
|
||||
}
|
||||
if (linearIdx == max_lora_count)
|
||||
{
|
||||
row_count_to_gather = static_cast<int32_t>(slot_offset);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case D_PRIME_PTRS_INDEX:
|
||||
{
|
||||
int64_t slot_offset = 0;
|
||||
BlockLoad64().Load(slot_offsets, as_singleton_array(slot_offset), max_lora_count + 1);
|
||||
if (linearIdx == max_lora_count)
|
||||
{
|
||||
row_count_to_gather = static_cast<int32_t>(slot_offset);
|
||||
}
|
||||
slot_offset *= sum_output_hidden_size;
|
||||
for (int i = 0; i < MODULE_COUNT; i++)
|
||||
{
|
||||
int64_t offset = slot_offset;
|
||||
offset += module_out_prefix[i];
|
||||
offset *= dtype_element_size;
|
||||
offset += d_prime_base;
|
||||
BlockStore64().Store(d_prime_ptrs + i * max_lora_count, as_singleton_array(offset), max_lora_count);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case SPLITK_OFFSETS_INDEX:
|
||||
{
|
||||
int64_t slot_count = slot_counts[lora_id];
|
||||
int64_t rank = slot_ranks[lora_id];
|
||||
int64_t b_ptr = b_ptrs[linearIdx % (max_lora_count * MODULE_COUNT)];
|
||||
int64_t splitk_offset = (b_ptr == 0) ? 0 : (slot_count * rank);
|
||||
BlockScan(large_shared.scan).ExclusiveSum(splitk_offset, splitk_offset);
|
||||
BlockStore64().Store(splitk_offsets, as_singleton_array(splitk_offset), max_lora_count * MODULE_COUNT);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Set row_count_to_gather for non-pointer blocks
|
||||
switch (blockLinearIdx)
|
||||
{
|
||||
case A_PTRS_INDEX:
|
||||
case D_PTRS_INDEX:
|
||||
case D_PRIME_PTRS_INDEX: break;
|
||||
default:
|
||||
if (linearIdx == 0)
|
||||
{
|
||||
row_count_to_gather = static_cast<int32_t>(slot_offsets[max_lora_count]);
|
||||
}
|
||||
}
|
||||
|
||||
int constexpr ITEM_PER_THREAD = VECTOR_LOAD_WIDTH;
|
||||
using BlockStoreRow = cub::BlockStore<uint8_t, BlockDim, ITEM_PER_THREAD, cub::BLOCK_STORE_VECTORIZE>;
|
||||
|
||||
{
|
||||
// Write zero to intermediate buffer and output buffer
|
||||
auto intermediate_cast = reinterpret_cast<uint8_t*>(d_base);
|
||||
auto model_output_cast = reinterpret_cast<uint8_t*>(d_prime_base);
|
||||
|
||||
int intermediate_size = MODULE_COUNT * batch_size * max_lora_rank * dtype_element_size;
|
||||
int output_size = batch_size * sum_output_hidden_size * dtype_element_size;
|
||||
|
||||
uint8_t all_zeroes[ITEM_PER_THREAD] = {0};
|
||||
|
||||
int const blockOffset = THREADS_PER_BLOCK * ITEM_PER_THREAD * blockLinearIdx;
|
||||
BlockStoreRow().Store(intermediate_cast + blockOffset, all_zeroes, intermediate_size - blockOffset);
|
||||
BlockStoreRow().Store(model_output_cast + blockOffset, all_zeroes, output_size - blockOffset);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Row gather
|
||||
if (blockIdx.y < batch_size)
|
||||
{
|
||||
using BlockLoadRow = cub::BlockLoad<uint8_t, BlockDim, ITEM_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE>;
|
||||
|
||||
auto const row_size = input_hidden_size * dtype_element_size;
|
||||
|
||||
auto output_cast = reinterpret_cast<uint8_t*>(reordered_input);
|
||||
|
||||
uint8_t tile[ITEM_PER_THREAD] = {0};
|
||||
int constexpr x_stride = THREADS_PER_BLOCK * ITEM_PER_THREAD;
|
||||
int const y_stride = row_size;
|
||||
int tail = row_size - blockIdx.x * x_stride;
|
||||
if (blockIdx.y < row_count_to_gather)
|
||||
{
|
||||
auto const input_cast = reinterpret_cast<uint8_t const*>(input);
|
||||
auto const src_row = sorted_ids[blockIdx.y];
|
||||
BlockLoadRow().Load(input_cast + blockIdx.x * x_stride + src_row * y_stride, tile, tail);
|
||||
}
|
||||
BlockStoreRow().Store(output_cast + blockIdx.x * x_stride + blockIdx.y * y_stride, tile, tail);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Launch function that instantiates the appropriate kernel based on module count.
|
||||
*/
|
||||
template <int BlockDim>
|
||||
void launchKernelWithModuleCount(int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs, int64_t* d_ptrs,
|
||||
int64_t* d_prime_ptrs, int64_t* lda, int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime, int64_t* splitk_offsets,
|
||||
void* reordered_input, int32_t max_lora_count, int32_t max_lora_rank, int32_t sum_output_hidden_size,
|
||||
int32_t input_hidden_size, int64_t dtype_element_size, int64_t batch_size, int64_t a_base, int64_t d_base,
|
||||
int64_t d_prime_base, int32_t const* slot_counts, int32_t const* slot_ranks, int64_t const* slot_offsets,
|
||||
int32_t const* module_out_sizes, int64_t const* module_out_prefix, int64_t const* b_ptrs,
|
||||
int64_t const* b_prime_ptrs, void const* input, int64_t const* sorted_ids, int32_t module_count,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
int constexpr THREADS_PER_BLOCK = BlockDim;
|
||||
|
||||
// Grid dimensions for row gather
|
||||
int constexpr ITEMS_PER_BLOCK = THREADS_PER_BLOCK * VECTOR_LOAD_WIDTH;
|
||||
int const gridDimX = common::ceilDiv(input_hidden_size * dtype_element_size, ITEMS_PER_BLOCK);
|
||||
int gridDimY = std::max(
|
||||
static_cast<int>(common::ceilDiv(static_cast<int>(PARAM_COUNT), gridDimX)), static_cast<int>(batch_size));
|
||||
|
||||
// calculate threads needed for writing zeros to intermediate buffer and output buffer
|
||||
int const itemsPerRow = ITEMS_PER_BLOCK * gridDimX;
|
||||
gridDimY = std::max(gridDimY,
|
||||
common::ceilDiv(static_cast<int>(module_count * batch_size * max_lora_rank * dtype_element_size), itemsPerRow));
|
||||
gridDimY = std::max(gridDimY,
|
||||
common::ceilDiv(static_cast<int>(batch_size * sum_output_hidden_size * dtype_element_size), itemsPerRow));
|
||||
|
||||
dim3 grid(gridDimX, gridDimY);
|
||||
dim3 block(BlockDim);
|
||||
|
||||
auto* reordered_input_cast = reinterpret_cast<uint8_t*>(reordered_input);
|
||||
auto const* input_cast = reinterpret_cast<uint8_t const*>(input);
|
||||
|
||||
// Dispatch based on module count
|
||||
switch (module_count)
|
||||
{
|
||||
case 1:
|
||||
loraGroupGEMMParamFillRowReorderFusionKernel<BlockDim, 1><<<grid, block, 0, stream>>>(in_sizes, out_sizes,
|
||||
a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast,
|
||||
max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size,
|
||||
a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix,
|
||||
b_ptrs, b_prime_ptrs, input_cast, sorted_ids);
|
||||
break;
|
||||
case 2:
|
||||
loraGroupGEMMParamFillRowReorderFusionKernel<BlockDim, 2><<<grid, block, 0, stream>>>(in_sizes, out_sizes,
|
||||
a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast,
|
||||
max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size,
|
||||
a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix,
|
||||
b_ptrs, b_prime_ptrs, input_cast, sorted_ids);
|
||||
break;
|
||||
case 3:
|
||||
loraGroupGEMMParamFillRowReorderFusionKernel<BlockDim, 3><<<grid, block, 0, stream>>>(in_sizes, out_sizes,
|
||||
a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast,
|
||||
max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size,
|
||||
a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix,
|
||||
b_ptrs, b_prime_ptrs, input_cast, sorted_ids);
|
||||
break;
|
||||
case 4:
|
||||
loraGroupGEMMParamFillRowReorderFusionKernel<BlockDim, 4><<<grid, block, 0, stream>>>(in_sizes, out_sizes,
|
||||
a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime, ldd_prime, splitk_offsets, reordered_input_cast,
|
||||
max_lora_count, max_lora_rank, sum_output_hidden_size, input_hidden_size, dtype_element_size, batch_size,
|
||||
a_base, d_base, d_prime_base, slot_counts, slot_ranks, slot_offsets, module_out_sizes, module_out_prefix,
|
||||
b_ptrs, b_prime_ptrs, input_cast, sorted_ids);
|
||||
break;
|
||||
default: TLLM_CHECK_WITH_INFO(false, "Unsupported module_count: %d (max 4)", module_count);
|
||||
}
|
||||
}
|
||||
|
||||
void launchLoraGroupGEMMParamFillRowReorderFusion(int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs,
|
||||
int64_t* d_ptrs, int64_t* d_prime_ptrs, int64_t* lda, int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime,
|
||||
int64_t* splitk_offsets, void* reordered_input, int32_t max_lora_count, int32_t max_lora_rank,
|
||||
int32_t sum_output_hidden_size, int32_t input_hidden_size, int64_t dtype_element_size, int64_t batch_size,
|
||||
int64_t a_base, int64_t d_base, int64_t d_prime_base, int32_t const* slot_counts, int32_t const* slot_ranks,
|
||||
int64_t const* slot_offsets, int32_t const* module_out_sizes, int64_t const* module_out_prefix,
|
||||
int64_t const* b_ptrs, int64_t const* b_prime_ptrs, void const* input, int64_t const* sorted_ids,
|
||||
int32_t module_count, nvinfer1::DataType dtype, cudaStream_t stream)
|
||||
{
|
||||
// Determine block dimensions (1D)
|
||||
// Requirements: 1) >= max_lora_count * module_count 2) >= 256 3) divisible by 32
|
||||
|
||||
int constexpr MIN_THREADS = 256;
|
||||
int constexpr WARP_SIZE = 32;
|
||||
|
||||
int const min_threads_needed = max_lora_count * module_count;
|
||||
int const threads_per_block = std::max(MIN_THREADS, common::ceilDiv(min_threads_needed, WARP_SIZE) * WARP_SIZE);
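
// Worked example (illustrative): max_lora_count = 8 with module_count = 4 gives min_threads_needed = 32, which is
// clamped up to the 256-thread minimum; only when max_lora_count * module_count exceeds 256 is a wider variant
// selected below, e.g. 66 * 4 = 264 rounds up to the 288-thread instantiation.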
|
||||
|
||||
if (threads_per_block == 256)
|
||||
{
|
||||
launchKernelWithModuleCount<256>(in_sizes, out_sizes, a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime,
|
||||
ldd_prime, splitk_offsets, reordered_input, max_lora_count, max_lora_rank, sum_output_hidden_size,
|
||||
input_hidden_size, dtype_element_size, batch_size, a_base, d_base, d_prime_base, slot_counts, slot_ranks,
|
||||
slot_offsets, module_out_sizes, module_out_prefix, b_ptrs, b_prime_ptrs, input, sorted_ids, module_count,
|
||||
stream);
|
||||
}
|
||||
else if (threads_per_block == 288)
|
||||
{
|
||||
launchKernelWithModuleCount<288>(in_sizes, out_sizes, a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime,
|
||||
ldd_prime, splitk_offsets, reordered_input, max_lora_count, max_lora_rank, sum_output_hidden_size,
|
||||
input_hidden_size, dtype_element_size, batch_size, a_base, d_base, d_prime_base, slot_counts, slot_ranks,
|
||||
slot_offsets, module_out_sizes, module_out_prefix, b_ptrs, b_prime_ptrs, input, sorted_ids, module_count,
|
||||
stream);
|
||||
}
|
||||
else if (threads_per_block == 320)
|
||||
{
|
||||
launchKernelWithModuleCount<320>(in_sizes, out_sizes, a_ptrs, d_ptrs, d_prime_ptrs, lda, ldd, ldb_prime,
|
||||
ldd_prime, splitk_offsets, reordered_input, max_lora_count, max_lora_rank, sum_output_hidden_size,
|
||||
input_hidden_size, dtype_element_size, batch_size, a_base, d_base, d_prime_base, slot_counts, slot_ranks,
|
||||
slot_offsets, module_out_sizes, module_out_prefix, b_ptrs, b_prime_ptrs, input, sorted_ids, module_count,
|
||||
stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false,
|
||||
"Unsupported threads_per_block: %d (calculated from max_lora_count=%d * module_count=%d)",
|
||||
threads_per_block, max_lora_count, module_count);
|
||||
}
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
@ -0,0 +1,77 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/config.h"
|
||||
#include <NvInferRuntime.h>
|
||||
#include <cstdint>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
/**
|
||||
 * @brief Fused kernel that fills group GEMM parameters, performs row reordering, and zero-fills buffers for CUDA
 * Graph compatible LoRA.
*
|
||||
* @param in_sizes Output: [module_count, max_lora_count, 3] problem sizes for first GEMM
|
||||
* @param out_sizes Output: [module_count, max_lora_count, 3] problem sizes for second GEMM
|
||||
* @param a_ptrs Output: [module_count, max_lora_count] input matrix pointers
|
||||
* @param d_ptrs Output: [module_count, max_lora_count] intermediate output pointers
|
||||
* @param d_prime_ptrs Output: [module_count, max_lora_count] final output pointers
|
||||
* @param lda Output: [module_count, max_lora_count] leading dimensions for A matrices
|
||||
* @param ldd Output: [module_count, max_lora_count] leading dimensions for D matrices
|
||||
* @param ldb_prime Output: [module_count, max_lora_count] leading dimensions for B' matrices
|
||||
* @param ldd_prime Output: [module_count, max_lora_count] leading dimensions for D' matrices
|
||||
* @param splitk_offsets Output: [module_count, max_lora_count] split-K work offsets
|
||||
* @param reordered_input Output: [batch_size, input_hidden_size] reordered input matrix
|
||||
* @param max_lora_count Maximum number of LoRA adapters
|
||||
* @param max_lora_rank Maximum rank of LoRA adapters
|
||||
* @param sum_output_hidden_size Sum of output hidden sizes across modules
|
||||
* @param input_hidden_size Input hidden dimension
|
||||
* @param dtype_element_size Size of data type in bytes
|
||||
* @param batch_size Batch size
|
||||
* @param a_base Base pointer for input matrices
|
||||
* @param d_base Base pointer for intermediate output matrices
|
||||
* @param d_prime_base Base pointer for final output matrices
|
||||
* @param slot_counts Input: [max_lora_count] number of requests per LoRA slot
|
||||
* @param slot_ranks Input: [max_lora_count] rank of each LoRA adapter
|
||||
* @param slot_offsets Input: [max_lora_count + 1] cumulative offsets (last element = total rows)
|
||||
* @param module_out_sizes Input: [module_count] output hidden size per module
|
||||
* @param module_out_prefix Input: [module_count] prefix sum of output hidden sizes
|
||||
* @param b_ptrs Input: [module_count, max_lora_count] weight pointers for first GEMM
|
||||
* @param b_prime_ptrs Input: [module_count, max_lora_count] weight pointers for second GEMM
|
||||
* @param input Input: [batch_size, input_hidden_size] original input matrix
|
||||
* @param sorted_ids Input: [batch_size] indices for row reordering
|
||||
* @param module_count Number of modules per layer
|
||||
* @param dtype Data type of matrices
|
||||
* @param stream CUDA stream
|
||||
*/
|
||||
void launchLoraGroupGEMMParamFillRowReorderFusion(int32_t* in_sizes, int32_t* out_sizes, int64_t* a_ptrs,
|
||||
int64_t* d_ptrs, int64_t* d_prime_ptrs, int64_t* lda, int64_t* ldd, int64_t* ldb_prime, int64_t* ldd_prime,
|
||||
int64_t* splitk_offsets, void* reordered_input, int32_t max_lora_count, int32_t max_lora_rank,
|
||||
int32_t sum_output_hidden_size, int32_t input_hidden_size, int64_t dtype_element_size, int64_t batch_size,
|
||||
int64_t a_base, int64_t d_base, int64_t d_prime_base, int32_t const* slot_counts, int32_t const* slot_ranks,
|
||||
int64_t const* slot_offsets, int32_t const* module_out_sizes, int64_t const* module_out_prefix,
|
||||
int64_t const* b_ptrs, int64_t const* b_prime_ptrs, void const* input, int64_t const* sorted_ids,
|
||||
int32_t module_count, nvinfer1::DataType dtype, cudaStream_t stream);
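
A compact allocation sketch for the parameter buffers this launcher fills, following the shapes documented above. The dtypes (int32 for the size triples, int64 for pointers, leading dimensions, and split-K offsets) and the use of ATen allocations are assumptions for illustration; the struct and function names are hypothetical:

#include <ATen/ATen.h>

struct LoraGroupGemmParams
{
    at::Tensor inSizes, outSizes;                            // int32 [module_count, max_lora_count, 3]
    at::Tensor aPtrs, dPtrs, dPrimePtrs;                     // int64 [module_count, max_lora_count]
    at::Tensor lda, ldd, ldbPrime, lddPrime, splitkOffsets;  // int64 [module_count, max_lora_count]
};

LoraGroupGemmParams allocateLoraGroupGemmParams(int64_t moduleCount, int64_t maxLoraCount)
{
    auto const i32 = at::TensorOptions().dtype(at::kInt).device(at::kCUDA);
    auto const i64 = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);
    LoraGroupGemmParams p;
    p.inSizes = at::zeros({moduleCount, maxLoraCount, 3}, i32);
    p.outSizes = at::zeros({moduleCount, maxLoraCount, 3}, i32);
    for (at::Tensor* t : {&p.aPtrs, &p.dPtrs, &p.dPrimePtrs, &p.lda, &p.ldd, &p.ldbPrime, &p.lddPrime, &p.splitkOffsets})
    {
        *t = at::zeros({moduleCount, maxLoraCount}, i64);
    }
    return p;
}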
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
@ -34,6 +34,7 @@
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/tuple.h>
|
||||
#include <nanobind/stl/unique_ptr.h>
|
||||
#include <nanobind/stl/unordered_map.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
#include <nanobind/trampoline.h>
|
||||
#include <torch/extension.h>
|
||||
@ -563,6 +564,11 @@ void tb::BasePeftCacheManagerBindings::initBindings(nb::module_& m)
|
||||
nb::arg("config"), nb::arg("model_config"), nb::arg("world_config"), nb::arg("buffer_manager"),
|
||||
nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("is_task_cached", &tb::PeftCacheManager::isTaskCached, nb::arg("taskId"),
|
||||
nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("is_task_cached_device", &tb::PeftCacheManager::isTaskCachedDevice, nb::arg("taskId"),
|
||||
nb::call_guard<nb::gil_scoped_release>())
|
||||
.def("ensure_batch_map_task_id", &tb::PeftCacheManager::ensureBatchMapTaskId, nb::arg("context_requests"),
|
||||
nb::arg("generation_requests"), nb::arg("reset_gpu_cache") = false,
|
||||
nb::call_guard<nb::gil_scoped_release>());
|
||||
|
||||
nb::class_<tb::NoOpPeftCacheManager, tb::BasePeftCacheManager>(m, "NoOpPeftCacheManager")
|
||||
|
||||
@ -558,6 +558,11 @@ void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
|
||||
py::arg("config"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("is_task_cached", &tb::PeftCacheManager::isTaskCached, py::arg("taskId"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("is_task_cached_device", &tb::PeftCacheManager::isTaskCachedDevice, py::arg("taskId"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("ensure_batch_map_task_id", &tb::PeftCacheManager::ensureBatchMapTaskId, py::arg("context_requests"),
|
||||
py::arg("generation_requests"), py::arg("reset_gpu_cache") = false,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
py::classh<tb::NoOpPeftCacheManager, tb::BasePeftCacheManager>(m, "NoOpPeftCacheManager")
|
||||
|
||||
@ -42,7 +42,7 @@ public:
|
||||
using LoraReqTensors = std::tuple<LoraWeightsTensorPtr, LoraConfigTensorPtr>;
|
||||
using TaskIdType = std::int64_t;
|
||||
using PeftValues = std::vector<runtime::LoraCache::TaskLayerModuleConfig>;
|
||||
using PeftTable = std::map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
|
||||
using PeftTable = std::unordered_map<uint64_t, std::vector<runtime::LoraCache::TaskLayerModuleConfig>>;
|
||||
|
||||
explicit LoraManager() {}
|
||||
|
||||
|
||||
@ -18,13 +18,16 @@
|
||||
#include "tensorrt_llm/common/cublasMMWrapper.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/opUtils.h"
|
||||
#include "tensorrt_llm/kernels/cuda_graph_grouped_gemm.h"
|
||||
#include "tensorrt_llm/kernels/lora/lora.h"
|
||||
#include "tensorrt_llm/kernels/lora/loraGroupGEMMParamFillRowReorderFusion.h"
|
||||
#include "tensorrt_llm/kernels/selectiveScan/selectiveScan.h"
|
||||
#include "tensorrt_llm/thop/thUtils.h"
|
||||
|
||||
namespace th = torch;
|
||||
namespace tk = tensorrt_llm::kernels;
|
||||
using tensorrt_llm::common::fmtstr;
|
||||
namespace tc = tensorrt_llm::common;
|
||||
|
||||
TRTLLM_NAMESPACE_BEGIN
|
||||
|
||||
@ -174,6 +177,171 @@ std::vector<th::Tensor> lora_grouped_gemm(th::Tensor const& input, th::Tensor co
|
||||
return output_torch;
|
||||
}
|
||||
|
||||
void lora_grouped_gemm_cuda_graph(th::Tensor const& lora_in_sizes, // [layer_module_num, max_lora_size, 3]
|
||||
th::Tensor const& lora_out_sizes, // [layer_module_num, max_lora_size, 3]
|
||||
th::Tensor const& a_offsets, // [layer_module_num, max_lora_size]
|
||||
th::Tensor const& b_ptrs, // [layer_module_num, max_lora_size]
|
||||
th::Tensor const& d_offsets, // [layer_module_num, max_lora_size]
|
||||
th::Tensor const& b_prime_ptrs, // [layer_module_num, max_lora_size]
|
||||
th::Tensor const& d_prime_offsets, // [layer_module_num, max_lora_size]
|
||||
int64_t problem_count,
|
||||
th::Tensor const& lda, // Leading dimensions for A matrices [layer_module_num, max_lora_size]
|
||||
th::Tensor const& ldb, // Leading dimensions for B matrices [layer_module_num, max_lora_size]
|
||||
th::Tensor const& ldd, // Leading dimensions for C matrices [layer_module_num, max_lora_size] (unused)
|
||||
th::Tensor const& ldb_prime, th::Tensor const& ldd_prime, th::Tensor const& host_max_in_sizes,
|
||||
th::Tensor const& host_max_out_sizes, th::Tensor const& splitk_offsets, c10::ScalarType dtype, int64_t minKN,
|
||||
int64_t splitKSlices = 16)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
|
||||
auto* a_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(a_offsets.data_ptr()));
|
||||
auto* d_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(d_offsets.data_ptr()));
|
||||
auto* a_prime_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(d_offsets.data_ptr())); // second GEMM's A is the first GEMM's D (intermediate buffer)
|
||||
auto* d_prime_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(d_prime_offsets.data_ptr()));
|
||||
|
||||
auto* problem_sizes_1_ptr = reinterpret_cast<cutlass::gemm::GemmCoord*>(lora_in_sizes.data_ptr());
|
||||
auto* problem_sizes_2_ptr = reinterpret_cast<cutlass::gemm::GemmCoord*>(lora_out_sizes.data_ptr());
|
||||
|
||||
auto* host_max_in_sizes_ptr = reinterpret_cast<cutlass::gemm::GemmCoord*>(host_max_in_sizes.data_ptr());
|
||||
auto* host_max_out_sizes_ptr = reinterpret_cast<cutlass::gemm::GemmCoord*>(host_max_out_sizes.data_ptr());
|
||||
|
||||
auto* b_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(b_ptrs.data_ptr()));
|
||||
auto* b_prime_ptrs_gpu = reinterpret_cast<void**>(const_cast<void*>(b_prime_ptrs.data_ptr()));
|
||||
|
||||
auto* lda_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(lda.data_ptr()));
|
||||
auto* ldb_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(ldb.data_ptr()));
|
||||
auto* ldd_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(ldd.data_ptr()));
|
||||
auto* ldb_prime_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(ldb_prime.data_ptr()));
|
||||
auto* ldd_prime_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(ldd_prime.data_ptr()));
|
||||
|
||||
auto* splitk_offsets_gpu = reinterpret_cast<int64_t*>(const_cast<void*>(splitk_offsets.data_ptr()));
|
||||
|
||||
// Get data type
|
||||
nvinfer1::DataType loraRuntimeDataType;
|
||||
switch (dtype)
|
||||
{
|
||||
case torch::kFloat16: loraRuntimeDataType = nvinfer1::DataType::kHALF; break;
|
||||
case torch::kBFloat16: loraRuntimeDataType = nvinfer1::DataType::kBF16; break;
|
||||
default: TORCH_CHECK(false, "Invalid dtype, only supports float16, bfloat16, got %s", c10::toString(dtype));
|
||||
}
|
||||
|
||||
int const minKnInt = std::max(1, static_cast<int>(minKN));
|
||||
|
||||
// Call CUDA Graph compatible grouped GEMM for lora_in (split-K)
|
||||
if (problem_count > 0)
|
||||
{
|
||||
TLLM_LOG_TRACE("Start Grouped GEMM for LoRA in.");
|
||||
|
||||
tk::cudaGraphSplitKGroupedGemm(problem_sizes_1_ptr, problem_count, a_ptrs_gpu, b_ptrs_gpu,
|
||||
d_ptrs_gpu, // ptrC (no bias)
|
||||
d_ptrs_gpu, lda_gpu, ldb_gpu, ldd_gpu, ldd_gpu, // Precomputed leading dimensions
|
||||
true, // isLoraIn
|
||||
loraRuntimeDataType,
|
||||
static_cast<int>(splitKSlices), // splitKSlices
|
||||
minKnInt, // minKN
|
||||
host_max_in_sizes_ptr, splitk_offsets_gpu, stream);
|
||||
sync_check_cuda_error(stream);
|
||||
|
||||
// Call CUDA Graph compatible grouped GEMM for lora_out
|
||||
TLLM_LOG_TRACE("Start Grouped GEMM for LoRA out.");
|
||||
tk::cudaGraphGroupedGemm(problem_sizes_2_ptr, problem_count, a_prime_ptrs_gpu, b_prime_ptrs_gpu,
|
||||
d_prime_ptrs_gpu, // ptrC (no bias)
|
||||
d_prime_ptrs_gpu, ldd_gpu, ldb_prime_gpu, ldd_prime_gpu, ldd_prime_gpu, // Precomputed leading dimensions
|
||||
false, // isLoraIn
|
||||
loraRuntimeDataType,
|
||||
minKnInt, // minKN
|
||||
host_max_out_sizes_ptr, stream);
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
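For reference, the math the two grouped GEMM launches above implement, written as plain PyTorch. This is a sketch under assumptions: rows of the input are already reordered so each slot's tokens are contiguous, slot_offsets_full holds prefix sums with a leading zero, and the per-slot weights lora_a[s] ([rank_s, hidden]) and lora_b[s] ([out_size, rank_s]) stand in for the pointer tables the kernel consumes.

import torch

def lora_grouped_gemm_reference(x_reordered, slot_offsets_full, lora_a, lora_b):
    # x_reordered: [tokens, hidden]; slot_offsets_full: [num_slots + 1] prefix sums
    outs = []
    for s in range(len(lora_a)):
        rows = x_reordered[slot_offsets_full[s]:slot_offsets_full[s + 1]]
        inter = rows @ lora_a[s].T        # first grouped GEMM (split-K in the kernel)
        outs.append(inter @ lora_b[s].T)  # second grouped GEMM into the output slice
    return torch.cat(outs, dim=0)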
|
||||
|
||||
void lora_group_gemm_param_fill_row_reorder_fusion(th::Tensor const& in_sizes, // [module_count, max_lora_count, 3]
|
||||
th::Tensor const& out_sizes, // [module_count, max_lora_count, 3]
|
||||
th::Tensor const& a_ptrs, // [module_count, max_lora_count]
|
||||
th::Tensor const& d_ptrs, // [module_count, max_lora_count]
|
||||
th::Tensor const& d_prime_ptrs, // [module_count, max_lora_count]
|
||||
th::Tensor const& lda, // [module_count, max_lora_count]
|
||||
th::Tensor const& ldd, // [module_count, max_lora_count]
|
||||
th::Tensor const& ldb_prime, // [module_count, max_lora_count]
|
||||
th::Tensor const& ldd_prime, // [module_count, max_lora_count]
|
||||
th::Tensor const& splitk_offsets, // [module_count, max_lora_count]
|
||||
th::Tensor const& reordered_input, // [batch_size, input_hidden_size]
|
||||
int64_t max_lora_count, int64_t max_lora_rank, int64_t sum_output_hidden_size, int64_t input_hidden_size,
|
||||
int64_t batch_size,
|
||||
th::Tensor const& slot_counts, // [max_lora_count]
|
||||
th::Tensor const& slot_ranks, // [max_lora_count]
|
||||
th::Tensor const& slot_offsets, // [max_lora_count + 1]
|
||||
th::Tensor const& module_out_sizes, // [module_count]
|
||||
th::Tensor const& module_out_prefix, // [module_count]
|
||||
th::Tensor const& b_ptrs, // [module_count, max_lora_count]
|
||||
th::Tensor const& b_prime_ptrs, // [module_count, max_lora_count]
|
||||
th::Tensor const& input, // [batch_size, input_hidden_size]
|
||||
th::Tensor const& sorted_ids, // [batch_size]
|
||||
th::Tensor const& intermediate_buffer, // [batch_size, max_lora_rank]
|
||||
th::Tensor const& output_buffer, // [batch_size, sum_output_hidden_size]
|
||||
c10::ScalarType dtype)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
// Validate inputs
|
||||
TORCH_CHECK(in_sizes.device().is_cuda(), "in_sizes must be a CUDA tensor");
|
||||
TORCH_CHECK(out_sizes.device().is_cuda(), "out_sizes must be a CUDA tensor");
|
||||
TORCH_CHECK(reordered_input.device().is_cuda(), "reordered_input must be a CUDA tensor");
|
||||
TORCH_CHECK(input.device().is_cuda(), "input must be a CUDA tensor");
|
||||
|
||||
// Get module count from tensor shapes
|
||||
int32_t const module_count = static_cast<int32_t>(in_sizes.size(0));
|
||||
|
||||
// Get data type info
|
||||
nvinfer1::DataType loraRuntimeDataType;
|
||||
switch (dtype)
|
||||
{
|
||||
case torch::kFloat16: loraRuntimeDataType = nvinfer1::DataType::kHALF; break;
|
||||
case torch::kBFloat16: loraRuntimeDataType = nvinfer1::DataType::kBF16; break;
|
||||
default: TORCH_CHECK(false, "Invalid dtype, only supports float16, bfloat16, got %s", c10::toString(dtype));
|
||||
}
|
||||
|
||||
int64_t const dtype_element_size = input.element_size();
|
||||
|
||||
int64_t const a_base_ptr = reinterpret_cast<int64_t>(reordered_input.data_ptr());
|
||||
int64_t const d_base_ptr = reinterpret_cast<int64_t>(intermediate_buffer.data_ptr());
|
||||
int64_t const d_prime_base_ptr = reinterpret_cast<int64_t>(output_buffer.data_ptr());
|
||||
|
||||
tk::launchLoraGroupGEMMParamFillRowReorderFusion(reinterpret_cast<int32_t*>(const_cast<void*>(in_sizes.data_ptr())),
|
||||
reinterpret_cast<int32_t*>(const_cast<void*>(out_sizes.data_ptr())),
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(a_ptrs.data_ptr())),
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(d_ptrs.data_ptr())),
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(d_prime_ptrs.data_ptr())),
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(lda.data_ptr())),
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(ldd.data_ptr())),
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(ldb_prime.data_ptr())),
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(ldd_prime.data_ptr())),
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(splitk_offsets.data_ptr())),
|
||||
const_cast<void*>(reordered_input.data_ptr()), static_cast<int32_t>(max_lora_count),
|
||||
static_cast<int32_t>(max_lora_rank), static_cast<int32_t>(sum_output_hidden_size),
|
||||
static_cast<int32_t>(input_hidden_size), dtype_element_size, batch_size, a_base_ptr, d_base_ptr,
|
||||
d_prime_base_ptr, reinterpret_cast<int32_t const*>(slot_counts.data_ptr()),
|
||||
reinterpret_cast<int32_t const*>(slot_ranks.data_ptr()),
|
||||
reinterpret_cast<int64_t const*>(slot_offsets.data_ptr()),
|
||||
reinterpret_cast<int32_t const*>(module_out_sizes.data_ptr()),
|
||||
reinterpret_cast<int64_t const*>(module_out_prefix.data_ptr()),
|
||||
reinterpret_cast<int64_t const*>(b_ptrs.data_ptr()), reinterpret_cast<int64_t const*>(b_prime_ptrs.data_ptr()),
|
||||
input.data_ptr(), reinterpret_cast<int64_t const*>(sorted_ids.data_ptr()), module_count, loraRuntimeDataType,
|
||||
stream);
|
||||
|
||||
sync_check_cuda_error(stream);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
} // namespace torch_ext
|
||||
|
||||
TRTLLM_NAMESPACE_END
|
||||
@ -192,9 +360,65 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
"int max_low_rank, "
|
||||
"int weight_index, "
|
||||
"bool isRemoveInputPadding) -> Tensor[]");
|
||||
|
||||
m.def(
|
||||
"lora_grouped_gemm_cuda_graph("
|
||||
"Tensor lora_in_sizes, "
|
||||
"Tensor lora_out_sizes, "
|
||||
"Tensor a_offsets, "
|
||||
"Tensor b_ptrs, "
|
||||
"Tensor d_offsets, "
|
||||
"Tensor b_prime_ptrs, "
|
||||
"Tensor d_prime_offsets, "
|
||||
"int problem_count, "
|
||||
"Tensor lda, "
|
||||
"Tensor ldb, "
|
||||
"Tensor ldd, "
|
||||
"Tensor ldb_prime, "
|
||||
"Tensor ldd_prime, "
|
||||
"Tensor host_max_in_sizes, "
|
||||
"Tensor host_max_out_sizes, "
|
||||
"Tensor splitk_offsets, "
|
||||
"ScalarType dtype, "
|
||||
"int minKN, "
|
||||
"int splitKSlices=16) -> ()");
|
||||
|
||||
m.def(
|
||||
"lora_group_gemm_param_fill_row_reorder_fusion("
|
||||
"Tensor in_sizes, "
|
||||
"Tensor out_sizes, "
|
||||
"Tensor a_ptrs, "
|
||||
"Tensor d_ptrs, "
|
||||
"Tensor d_prime_ptrs, "
|
||||
"Tensor lda, "
|
||||
"Tensor ldd, "
|
||||
"Tensor ldb_prime, "
|
||||
"Tensor ldd_prime, "
|
||||
"Tensor splitk_offsets, "
|
||||
"Tensor reordered_input, "
|
||||
"int max_lora_count, "
|
||||
"int max_lora_rank, "
|
||||
"int sum_output_hidden_size, "
|
||||
"int input_hidden_size, "
|
||||
"int batch_size, "
|
||||
"Tensor slot_counts, "
|
||||
"Tensor slot_ranks, "
|
||||
"Tensor slot_offsets, "
|
||||
"Tensor module_out_sizes, "
|
||||
"Tensor module_out_prefix, "
|
||||
"Tensor b_ptrs, "
|
||||
"Tensor b_prime_ptrs, "
|
||||
"Tensor input, "
|
||||
"Tensor sorted_ids, "
|
||||
"Tensor intermediate_buffer, "
|
||||
"Tensor output_buffer, "
|
||||
"ScalarType dtype) -> ()");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("lora_grouped_gemm", &tensorrt_llm::torch_ext::lora_grouped_gemm);
|
||||
m.impl("lora_grouped_gemm_cuda_graph", &tensorrt_llm::torch_ext::lora_grouped_gemm_cuda_graph);
|
||||
m.impl("lora_group_gemm_param_fill_row_reorder_fusion",
|
||||
&tensorrt_llm::torch_ext::lora_group_gemm_param_fill_row_reorder_fusion);
|
||||
}
|
||||
|
||||
@ -317,9 +317,6 @@ class Attention(nn.Module):
|
||||
self.fused_qkv_lora = LoraLayer([LoraModuleType.ATTENTION_QKV],
|
||||
[self.q_size + 2 * self.kv_size])
|
||||
|
||||
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
|
||||
[self.hidden_size])
|
||||
|
||||
# Whether to fuse RoPE into the attention OP.
|
||||
# If true, RoPE will be applied in self.attn.forward.
|
||||
# If false, RoPE will be applied in self.apply_rope.
|
||||
|
||||
130 tensorrt_llm/_torch/peft/lora/adapter_slot_manager.py Normal file
@ -0,0 +1,130 @@
|
||||
"""
|
||||
AdapterSlotManager for managing slots that stores LoRA indices.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import List, Optional
|
||||
|
||||
from ...pyexecutor.resource_manager import PeftCacheManager
|
||||
from ...pyexecutor.scheduler import RequestList
|
||||
|
||||
|
||||
class AdapterSlotManager:
|
||||
"""
|
||||
Manages max_num_adapters ordered slots for distinct task_ids to enable CUDA Graph compatibility.
|
||||
|
||||
Each slot can hold one adapter (task_id) and maintains a consistent ordering that allows
|
||||
the CUDA Graph to be captured with fixed buffer layouts.
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_adapters: int):
|
||||
"""
|
||||
Initialize the AdapterSlotManager.
|
||||
|
||||
Args:
|
||||
max_num_adapters: Maximum number of LoRA adapters that can be active simultaneously
|
||||
"""
|
||||
self.max_num_adapters = max_num_adapters
|
||||
|
||||
# Slot management
|
||||
self.slot2task: List[Optional[int]] = [None] * max_num_adapters
|
||||
self.task2slot: OrderedDict[int, int] = OrderedDict() # represent LRU order
|
||||
|
||||
# State tracking
|
||||
self.slots_changed = False
|
||||
|
||||
def find_free_slot(self) -> int:
"""
Return the slot_id of the first free slot. Raises ValueError if no slot is free;
callers check capacity against max_num_adapters before calling.
"""
return self.slot2task.index(None)
|
||||
|
||||
def remove_task(self, task_id: int) -> Optional[int]:
|
||||
"""
|
||||
Remove a task_id from slots. Return its slot_id if present, otherwise return None.
|
||||
"""
|
||||
slot_id = self.task2slot.pop(task_id, None)
|
||||
if slot_id is not None:
|
||||
self.slots_changed = True
|
||||
self.slot2task[slot_id] = None
|
||||
return slot_id
|
||||
|
||||
def get_or_assign_task(self, task_id: int) -> tuple[int, Optional[int]]:
|
||||
"""
|
||||
Assign a task_id to a slot and do LRU eviction if necessary.
|
||||
If already in any slot, update LRU order.
|
||||
Return: pair (assigned slot_id, evicted task_id)
|
||||
"""
|
||||
evicted_task = None
|
||||
if task_id in self.task2slot:
|
||||
self.task2slot.move_to_end(task_id)
|
||||
else:
|
||||
self.slots_changed = True
|
||||
if len(self.task2slot) < self.max_num_adapters:
|
||||
free_slot = self.find_free_slot()
|
||||
self.slot2task[free_slot] = task_id
|
||||
self.task2slot[task_id] = free_slot
|
||||
else:
|
||||
# evict lru
|
||||
evicted_task, evicted_slot = self.task2slot.popitem(last=False)
|
||||
self.slot2task[evicted_slot] = task_id
|
||||
self.task2slot[task_id] = evicted_slot
|
||||
return self.task2slot[task_id], evicted_task
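A short usage sketch of the LRU behaviour above (hypothetical task ids, not part of the diff):

mgr = AdapterSlotManager(max_num_adapters=3)
for t in (101, 102, 103):
    mgr.get_or_assign_task(t)             # fills slots 0, 1, 2 in order
slot, evicted = mgr.get_or_assign_task(104)
assert (slot, evicted) == (0, 101)        # the least recently used task's slot is reused
assert mgr.has_slots_changed()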
|
||||
|
||||
def remove_evicted_slots_in_cpp(self, peft_cache_manager: PeftCacheManager):
|
||||
"""
|
||||
Validate slots by removing tasks that are not cached in PeftCacheManager.
|
||||
"""
|
||||
for task_id in self.slot2task:
|
||||
if task_id is not None:
|
||||
if not peft_cache_manager.is_task_cached_device(task_id):
|
||||
self.remove_task(task_id)
|
||||
|
||||
def update_slots(
|
||||
self, requests: RequestList, peft_cache_manager: PeftCacheManager
|
||||
) -> list[int]:
|
||||
"""
|
||||
Get slot mapping for all requests in a scheduled batch.
|
||||
|
||||
Args:
|
||||
scheduled_requests: The scheduled requests for the current batch
|
||||
|
||||
Returns:
|
||||
Dict mapping request_id to slot_id, with slot_id=max_num_adapters for base model requests
|
||||
"""
|
||||
# remove task evicted in PeftCacheManager in C++
|
||||
self.remove_evicted_slots_in_cpp(peft_cache_manager)
|
||||
|
||||
# check if total number of unique tasks in the requests is not larger than max_num_adapters
|
||||
tasks = [request.lora_task_id for request in requests]
|
||||
unique_tasks = {t for t in tasks if t is not None}
|
||||
assert len(unique_tasks) <= self.max_num_adapters, (
|
||||
f"Batch with more unique LoRA adapters ({len(unique_tasks)}) than max_num_adapters={self.max_num_adapters} "
|
||||
"is not supported"
|
||||
)
|
||||
|
||||
# assign slots to tasks
|
||||
for i, task in enumerate(tasks):
|
||||
if task is None:
|
||||
tasks[i] = self.max_num_adapters
|
||||
else:
|
||||
tasks[i], evicted_task = self.get_or_assign_task(task)
|
||||
|
||||
return tasks
|
||||
|
||||
def get_slot_to_task_mapping(self) -> tuple[Optional[int], ...]:
|
||||
"""
|
||||
Get current slot to task mapping.
|
||||
|
||||
Returns:
|
||||
Tuple mapping slot_id to task_id (or None if slot is empty)
|
||||
"""
|
||||
return tuple(self.slot2task)
|
||||
|
||||
def has_slots_changed(self) -> bool:
|
||||
"""Check if slot assignments have changed since last check."""
|
||||
return self.slots_changed
|
||||
|
||||
def reset_slots_changed(self):
|
||||
"""Reset the slots_changed flag."""
|
||||
self.slots_changed = False
|
||||
175 tensorrt_llm/_torch/peft/lora/cuda_graph_lora_manager.py Normal file
@ -0,0 +1,175 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...._utils import nvtx_range
|
||||
from ....logger import logger
|
||||
from ....lora_manager import LoraManager, LoraModelConfig
|
||||
from ...attention_backend.interface import AttentionMetadata
|
||||
from ...pyexecutor.resource_manager import PeftCacheManager
|
||||
from ...pyexecutor.scheduler import ScheduledRequests
|
||||
from .adapter_slot_manager import AdapterSlotManager
|
||||
from .cuda_graph_lora_params import CudaGraphLoraParams
|
||||
from .layer import LoraLayer
|
||||
|
||||
|
||||
class CudaGraphLoraManager:
|
||||
"""
|
||||
Manager that coordinates adapter slots and CUDA Graph compatible LoRA parameters.
|
||||
|
||||
This class bridges the gap between the current LoRA implementation and the new
|
||||
CUDA Graph compatible design by managing adapter slots and preparing persistent
|
||||
device tensors for group GEMM operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_lora_size: int,
|
||||
max_batch_size: int,
|
||||
max_lora_rank: int,
|
||||
model: torch.nn.Module,
|
||||
lora_model_config: Optional[LoraModelConfig],
|
||||
device: str = "cuda",
|
||||
):
|
||||
"""
|
||||
Initialize the CUDA Graph LoRA manager.
|
||||
|
||||
Args:
|
||||
max_lora_size: Maximum number of LoRA adapters that can be active
|
||||
max_batch_size: Maximum batch size for CUDA graphs
|
||||
max_lora_rank: Maximum LoRA rank across all layers
|
||||
model: Model to get layerwise LoRA info
|
||||
lora_model_config: LoRA model configuration
|
||||
device: Device to allocate tensors on
|
||||
"""
|
||||
self.max_lora_size = max_lora_size
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_lora_rank = max_lora_rank
|
||||
self.device = device
|
||||
|
||||
self.adapter_slot_manager = AdapterSlotManager(max_lora_size)
|
||||
self.lora_model_config = lora_model_config
|
||||
lora_target_modules = lora_model_config.lora_target_modules
|
||||
self.target_modules_ids: Optional[tuple[int, ...]] = (
|
||||
tuple(map(LoraManager.LORA_MODULE_IDS.__getitem__, lora_target_modules))
|
||||
if bool(lora_target_modules)
|
||||
else None
|
||||
)
|
||||
if not self.target_modules_ids:
|
||||
logger.debug(
|
||||
"No LoRA target modules provided in LoRA config, using all modules in PyTorch Module!"
|
||||
)
|
||||
|
||||
# Single CudaGraphLoraParams instance for all batch sizes. Its shape will be allocated to accommodate the max
|
||||
# batch size.
|
||||
self.cuda_graph_lora_params: Optional[CudaGraphLoraParams] = None
|
||||
self.layer_info: (
|
||||
Dict[CudaGraphLoraParams.LoraLayerKey, CudaGraphLoraParams.LoraLayerInfo] | None
|
||||
) = None
|
||||
# Initialize layer_info from model.
|
||||
self._initialize_from_model(model)
|
||||
self.cuda_graph_lora_params = CudaGraphLoraParams(
|
||||
max_batch_size=self.max_batch_size,
|
||||
max_lora_size=self.max_lora_size,
|
||||
max_rank=self.max_lora_rank,
|
||||
layer_info=self.layer_info,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def _initialize_from_model(self, model: torch.nn.Module):
|
||||
"""
|
||||
Initialize LoRALayerInfo from model.
|
||||
"""
|
||||
self.layer_info = dict()
|
||||
|
||||
def get_layer_idx(
|
||||
model: torch.nn.Module, lora_module: LoraLayer, lora_module_name: str
|
||||
) -> Optional[int]:
|
||||
"""Find the layer index of the given LoRA module in the model."""
|
||||
module = lora_module
|
||||
name = lora_module_name
|
||||
while module is not None and (
|
||||
(not hasattr(module, "layer_idx")) or module.layer_idx is None
|
||||
):
|
||||
name = name.rsplit(".", 1)
|
||||
name = name[0] if len(name) > 1 else None
|
||||
if name is not None:
|
||||
module = model.get_submodule(name)
|
||||
else:
|
||||
module = None
|
||||
if hasattr(module, "layer_idx") and module.layer_idx is not None:
|
||||
return module.layer_idx
|
||||
return None
|
||||
|
||||
# Ignore LoRA layers without at least one of the target modules.
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
layer_idx = get_layer_idx(model, module, name)
|
||||
# if target_modules_ids is None, by default enable all modules
|
||||
if self.target_modules_ids and not any(
|
||||
module_id in self.target_modules_ids for module_id in module.lora_module_types
|
||||
):
|
||||
logger.debug(f"Layer {name} does not have any of the target modules, skipping")
|
||||
continue
|
||||
layer_key = CudaGraphLoraParams.LoraLayerKey(
|
||||
layer_idx=layer_idx, module_ids=tuple(module.lora_module_types)
|
||||
)
|
||||
assert layer_key not in self.layer_info, f"Layer {layer_key} already exists"
|
||||
|
||||
self.layer_info[layer_key] = CudaGraphLoraParams.LoraLayerInfo(
|
||||
module_num=len(module.lora_module_types),
|
||||
output_sizes=module.output_hidden_sizes,
|
||||
)
|
||||
|
||||
@nvtx_range("prepare_cuda_graph_lora_params")
|
||||
def prepare_cuda_graph_lora_params(
|
||||
self,
|
||||
scheduled_requests: "ScheduledRequests",
|
||||
attn_metadata: "AttentionMetadata",
|
||||
peft_cache_manager: PeftCacheManager,
|
||||
) -> Optional[Dict]:
|
||||
"""
|
||||
Prepare LoRA parameters from scheduled requests.
|
||||
|
||||
Args:
|
||||
scheduled_requests: The scheduled requests for the current batch
|
||||
attn_metadata: Attention metadata containing batch information
|
||||
peft_cache_manager: PEFT cache manager providing the per-batch PEFT table (task_id to layer-module configs)
|
||||
|
||||
Returns:
|
||||
LoRA parameters dictionary.
|
||||
"""
|
||||
assert len(scheduled_requests.context_requests) == 0, (
|
||||
"Context requests are not supported with LoRA CUDA Graph path. "
|
||||
f"Have {len(scheduled_requests.context_requests)} context requests"
|
||||
)
|
||||
request_list = scheduled_requests.generation_requests
|
||||
|
||||
peft_table = peft_cache_manager.get_and_reset_batch_peft_table()
|
||||
|
||||
# Update slot manager and get slot assignments for this batch
|
||||
request_slot_ids = self.adapter_slot_manager.update_slots(request_list, peft_cache_manager)
|
||||
|
||||
cuda_graph_lora_params = self.cuda_graph_lora_params
|
||||
cuda_graph_lora_params.update_sorted_indices(request_slot_ids)
|
||||
|
||||
# Get current slot to task mapping
|
||||
slot2task = self.adapter_slot_manager.get_slot_to_task_mapping()
|
||||
|
||||
# Update weight pointers if slot assignments changed
|
||||
if self.adapter_slot_manager.has_slots_changed():
|
||||
cuda_graph_lora_params.update_weight_pointers(peft_table, slot2task)
|
||||
self.adapter_slot_manager.reset_slots_changed()
|
||||
|
||||
# Update GEMM sizes and prefix sums using batch
|
||||
cuda_graph_lora_params.update_slots_params(batch_slot_ids=request_slot_ids)
|
||||
|
||||
lora_params = {
|
||||
"cuda_graph_params": cuda_graph_lora_params,
|
||||
"host_request_types": attn_metadata.host_request_types,
|
||||
"prompt_lens_cpu": attn_metadata.prompt_lens_cpu,
|
||||
"num_seqs": attn_metadata.num_seqs,
|
||||
"use_cuda_graph_mode": True, # Flag to indicate new mode
|
||||
}
|
||||
|
||||
return lora_params
|
||||
341 tensorrt_llm/_torch/peft/lora/cuda_graph_lora_params.py Normal file
@ -0,0 +1,341 @@
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoraLayerParams:
|
||||
"""
|
||||
Parameters for a single LoRA layer.
|
||||
All tensors are persistent device tensors that can be updated outside of graph replay.
|
||||
"""
|
||||
|
||||
# Weight pointers
|
||||
# Shape: [layer_module_num, max_lora_size]
|
||||
d_b_ptrs: torch.Tensor # Lora_in weight pointers
|
||||
d_b_prime_ptrs: torch.Tensor # Lora_out weight pointers
|
||||
h_b_ptrs: torch.Tensor # Lora_in weight pointers in host
|
||||
h_b_prime_ptrs: torch.Tensor # Lora_out weight pointers in host
|
||||
|
||||
d_output_sizes: torch.Tensor
|
||||
d_output_sizes_offset: torch.Tensor
|
||||
h_output_sizes: torch.Tensor
|
||||
h_output_sizes_offset: torch.Tensor
|
||||
|
||||
|
||||
class CudaGraphLoraParams:
|
||||
"""
|
||||
CUDA Graph compatible LoRA parameters for all layers and batch management.
|
||||
|
||||
This structure maintains persistent device tensors that can be updated outside
|
||||
of CUDA Graph replay to support different LoRA combinations per batch.
|
||||
"""
|
||||
|
||||
LoraLayerKey = namedtuple("LoraLayerKey", ["layer_idx", "module_ids"])
|
||||
|
||||
PTR_DTYPE = torch.int64
|
||||
LD_DTYPE = torch.int64
|
||||
SIZES_DTYPE = torch.int32
|
||||
|
||||
@dataclass
|
||||
class LoraLayerInfo:
|
||||
module_num: int = 0
|
||||
output_sizes: List[int] | torch.Tensor | None = None
|
||||
input_hidden_size: int = 0
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
return self.input_hidden_size > 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_batch_size: int,
|
||||
max_lora_size: int,
|
||||
max_rank: int,
|
||||
layer_info: Dict[LoraLayerKey, LoraLayerInfo],
|
||||
device: str = "cuda",
|
||||
):
|
||||
"""
|
||||
Initialize CUDA Graph compatible LoRA parameters.
|
||||
|
||||
Args:
|
||||
max_batch_size: Maximum batch size for this graph
|
||||
max_lora_size: Maximum number of LoRA adapters
|
||||
max_rank: Maximum rank for all layers
|
||||
layer_info: Layer information for each layer
device: Device to allocate tensors on
|
||||
"""
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_lora_size = max_lora_size
|
||||
self.max_rank = max_rank
|
||||
self.layer_info = layer_info
|
||||
self.layer_module2key = self._calculate_layer_module2key()
|
||||
self.device = device
|
||||
|
||||
self.layer_params: Dict[self.LoraLayerKey, LoraLayerParams] = dict()
|
||||
|
||||
# sorted indices using slot ids as keys, mainly to group requests with the same slot id together in a batch
|
||||
self.sorted_ids = torch.zeros(max_batch_size, dtype=torch.int64, device=device)
|
||||
self.sorted_ids_host = torch.zeros_like(self.sorted_ids, device="cpu", pin_memory=True)
|
||||
|
||||
# persistent values for gen-only batch with cuda graph
|
||||
self.persistent_sorted_ids = self.sorted_ids
|
||||
|
||||
self.slot_ids = torch.zeros(max_batch_size, dtype=torch.int64, device=device)
|
||||
|
||||
self.slot_counts = torch.zeros(max_lora_size, dtype=torch.int32, device=device)
|
||||
self.slot_counts_host = torch.zeros_like(self.slot_counts, device="cpu", pin_memory=True)
|
||||
self.slot_offsets_full = torch.zeros(max_lora_size + 1, dtype=torch.int64, device=device)
|
||||
self.slot_offsets = self.slot_offsets_full[:-1]
|
||||
self.slot_offsets_full_host = torch.zeros_like(
|
||||
self.slot_offsets_full, device="cpu", pin_memory=True
|
||||
)
|
||||
|
||||
self.slot_ranks = torch.zeros(max_lora_size, dtype=torch.int32, device=device)
|
||||
self.slot_ranks_host = torch.zeros_like(self.slot_ranks, device="cpu", pin_memory=True)
|
||||
|
||||
for key, info in self.layer_info.items():
|
||||
assert (
|
||||
info.module_num > 0
|
||||
and info.output_sizes is not None
|
||||
and len(info.output_sizes) == info.module_num
|
||||
)
|
||||
# Allocate layer parameters
|
||||
self.layer_params[key] = self._allocate_layer_params(
|
||||
key, info.module_num, info.output_sizes
|
||||
)
|
||||
|
||||
def _calculate_layer_module2key(self) -> Dict[Tuple[int, int], LoraLayerKey]:
|
||||
layer_module2key = dict()
|
||||
for key in self.layer_info.keys():
|
||||
layer_id = key.layer_idx
|
||||
module_ids = key.module_ids
|
||||
for module_id in module_ids:
|
||||
layer_module2key[(layer_id, module_id)] = key
|
||||
return layer_module2key
|
||||
|
||||
def _allocate_layer_params(
|
||||
self, key: LoraLayerKey, layer_module_num: int, module_output_sizes: torch.Tensor
|
||||
) -> LoraLayerParams:
|
||||
"""
|
||||
Create LoraLayerParams for a specific layer.
|
||||
|
||||
Args:
|
||||
key: Key of the layer
|
||||
layer_module_num: Number of modules in this layer
|
||||
module_output_sizes: Output sizes for each module in this layer
|
||||
|
||||
Returns:
|
||||
LoraLayerParams for the specified layer
|
||||
"""
|
||||
# GEMM parameter tensors only need max_lora_size (no dummy slot for base model)
|
||||
# Base model requests are handled separately and don't participate in GEMM operations
|
||||
shape_2d = (layer_module_num, self.max_lora_size)
|
||||
|
||||
output_hidden_sizes = torch.tensor(module_output_sizes, dtype=self.SIZES_DTYPE)
|
||||
output_hidden_sizes_device = output_hidden_sizes.to(device="cuda")
|
||||
|
||||
output_sizes_offset = self.get_offset_from_counts(output_hidden_sizes).to(
|
||||
dtype=self.PTR_DTYPE
|
||||
) # [num_layer_modules]
|
||||
output_sizes_offset_device = output_sizes_offset.to(device="cuda")
|
||||
|
||||
return LoraLayerParams(
|
||||
# Weight pointers - managed by PEFT cache manager
|
||||
d_b_ptrs=torch.zeros(shape_2d, dtype=torch.int64, device=self.device),
|
||||
d_b_prime_ptrs=torch.zeros(shape_2d, dtype=torch.int64, device=self.device),
|
||||
h_b_ptrs=torch.zeros(shape_2d, dtype=torch.int64, pin_memory=True),
|
||||
h_b_prime_ptrs=torch.zeros(shape_2d, dtype=torch.int64, pin_memory=True),
|
||||
d_output_sizes=output_hidden_sizes_device,
|
||||
d_output_sizes_offset=output_sizes_offset_device,
|
||||
h_output_sizes=output_hidden_sizes,
|
||||
h_output_sizes_offset=output_sizes_offset,
|
||||
)
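The pinned-host / device tensor pairs allocated here exist so the CUDA graph can capture the device tensors once while only their contents change between replays. A minimal standalone sketch of that update pattern (assumed shapes, not the production buffers):

import torch

host_buf = torch.zeros(4, 8, dtype=torch.int64, pin_memory=True)
dev_buf = torch.zeros(4, 8, dtype=torch.int64, device="cuda")

def refresh(new_values: torch.Tensor):
    host_buf.copy_(new_values)                  # cheap CPU-side staging write
    dev_buf.copy_(host_buf, non_blocking=True)  # async H2D copy, done outside graph replay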
|
||||
|
||||
@staticmethod
|
||||
def get_sorted_indices(slot_ids: List[int]) -> torch.Tensor:
|
||||
"""
|
||||
Get sorted indices for the given slot IDs.
|
||||
"""
|
||||
slot_ids = torch.tensor(slot_ids, dtype=torch.int64)
|
||||
|
||||
# Compute sorted indices for gather/scatter operations
|
||||
sorted_slot_ids, sorted_indices = torch.sort(slot_ids, stable=True)
|
||||
return sorted_indices
|
||||
|
||||
def update_sorted_indices(self, slot_ids: List[int]):
|
||||
"""
|
||||
Update slot IDs for the current batch and compute sorted indices.
|
||||
|
||||
Args:
|
||||
slot_ids: List of slot IDs for each token in the batch
|
||||
"""
|
||||
actual_batch_size = len(slot_ids)
|
||||
assert actual_batch_size <= self.max_batch_size, (
|
||||
f"Actual batch size {actual_batch_size} exceeds max {self.max_batch_size}"
|
||||
)
|
||||
sorted_indices = self.get_sorted_indices(slot_ids)
|
||||
|
||||
# Update sorted_ids tensor with the computed indices (batch size already validated above)
|
||||
if actual_batch_size <= self.max_batch_size:
|
||||
# if can fit in persistent, use it
|
||||
self.sorted_ids = self.persistent_sorted_ids
|
||||
sorted_ids_host = self.sorted_ids_host[:actual_batch_size]
|
||||
sorted_ids_host.copy_(sorted_indices)
|
||||
self.sorted_ids[:actual_batch_size].copy_(sorted_ids_host, non_blocking=True)
|
||||
else:
|
||||
# otherwise not a gen-only batch; fall back to a newly allocated sorted_ids tensor
|
||||
self.sorted_ids = sorted_indices.to(device=self.device)
|
||||
|
||||
def update_weight_pointers(
|
||||
self, peft_table: Dict[int, List], slot_to_task_mapping: tuple[Optional[int], ...]
|
||||
):
|
||||
"""
|
||||
Update weight pointers from PEFT cache manager.
|
||||
|
||||
Args:
|
||||
peft_table: PEFT table from cache manager containing weight pointers, map task id to list of layer
|
||||
module configs
|
||||
slot_to_task_mapping: Mapping from slot_id to task_id, tuple of None for empty slots
|
||||
"""
|
||||
|
||||
# get slot ranks
|
||||
# assume ranks are the same for a given slot,
|
||||
# input_hidden_size are the same within a layer
|
||||
# output sizes are the same for all slots with the same module
|
||||
def zero_out_weight_pointers(slot_id: int):
|
||||
"""
|
||||
Zero out all weight pointers for a given slot_id for all layers
|
||||
"""
|
||||
for layer_param in self.layer_params.values():
|
||||
layer_param.h_b_ptrs[:, slot_id] = 0
|
||||
layer_param.h_b_prime_ptrs[:, slot_id] = 0
|
||||
|
||||
for slot_id in range(self.max_lora_size):
|
||||
task_id = slot_to_task_mapping[slot_id]
|
||||
if task_id is None: # empty slot
|
||||
self.slot_ranks_host[slot_id] = 0
|
||||
zero_out_weight_pointers(slot_id)
|
||||
elif (
|
||||
task_id not in peft_table
|
||||
): # task has not changed in the slot, retain old rank / weight pointers
|
||||
continue
|
||||
else: # task might have changed in the slot, update its rank
|
||||
task_configs = peft_table[task_id]
|
||||
config = task_configs[0] # assume all layerModuleConfigs have the same rank
|
||||
self.slot_ranks_host[slot_id] = config.adapter_size
|
||||
|
||||
zero_out_weight_pointers(
|
||||
slot_id
|
||||
) # in case new task in slot do not have LoRA adapter for some module in some layer
|
||||
for config in task_configs:
|
||||
layer_id = config.layer_id
|
||||
module_id = config.module_id
|
||||
key = self.layer_module2key[(layer_id, module_id)]
|
||||
layer_param = self.layer_params[key]
|
||||
local_module_id = key.module_ids.index(module_id)
|
||||
|
||||
assert key in self.layer_params, (
|
||||
f"Layer {layer_id} not found in layer_params, assumption that all LoRA has their adapters on "
|
||||
"the same layers is broken"
|
||||
)
|
||||
|
||||
# Validate LoRA rank
|
||||
rank = config.adapter_size
|
||||
assert rank <= self.max_rank, (
|
||||
f"LoRA rank {rank} in layer {layer_id} exceeds configured max_rank {self.max_rank}. "
|
||||
)
|
||||
|
||||
layer_param.h_b_ptrs[local_module_id, slot_id] = config.weights_in_pointer
|
||||
layer_param.h_b_prime_ptrs[local_module_id, slot_id] = (
|
||||
config.weights_out_pointer
|
||||
)
|
||||
|
||||
self.slot_ranks.copy_(self.slot_ranks_host, non_blocking=True)
|
||||
|
||||
for layer_param in self.layer_params.values():
|
||||
layer_param.d_b_ptrs.copy_(layer_param.h_b_ptrs, non_blocking=True)
|
||||
layer_param.d_b_prime_ptrs.copy_(layer_param.h_b_prime_ptrs, non_blocking=True)
|
||||
|
||||
@staticmethod
|
||||
def get_offset_from_counts(
|
||||
counts: torch.Tensor, full: bool = False, out: torch.Tensor = None
|
||||
) -> torch.Tensor:
|
||||
if out is None:
|
||||
if full:
|
||||
offset = torch.empty(counts.shape[0] + 1, dtype=torch.int64, device=counts.device)
|
||||
else:
|
||||
offset = torch.empty(counts.shape[0], dtype=torch.int64, device=counts.device)
|
||||
else:
|
||||
assert (full and out.shape[0] == counts.shape[0] + 1) or (
|
||||
(not full) and out.shape[0] == counts.shape[0]
|
||||
)
|
||||
offset = out
|
||||
|
||||
offset[0] = 0
|
||||
|
||||
if full:
|
||||
offset[1:] = counts
|
||||
else:
|
||||
offset[1:] = counts[:-1]
|
||||
offset[1:].cumsum_(dim=0)
|
||||
return offset
|
||||
|
||||
@staticmethod
|
||||
def get_slot_counts(batch_slot_ids: List[int], max_lora_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Get the number of tokens for each slot_id in the batch.
|
||||
"""
|
||||
slot_counts = torch.bincount(
|
||||
torch.tensor(batch_slot_ids, dtype=torch.int32), minlength=max_lora_size
|
||||
)
|
||||
assert slot_counts.size(0) <= max_lora_size + 1
|
||||
slot_counts = slot_counts[:max_lora_size]
|
||||
return slot_counts
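A worked example of the two helpers above on a hypothetical batch:

import torch

batch_slot_ids = [0, 2, 2, 1]  # four generation requests
counts = CudaGraphLoraParams.get_slot_counts(batch_slot_ids, max_lora_size=3)
# counts -> tensor([1, 1, 2]): one request on slot 0 and slot 1, two on slot 2
offsets = CudaGraphLoraParams.get_offset_from_counts(counts)
# offsets -> tensor([0, 1, 2]): exclusive prefix sums (first row owned by each slot)
offsets_full = CudaGraphLoraParams.get_offset_from_counts(counts, full=True)
# offsets_full -> tensor([0, 1, 2, 4]): last entry equals the total request count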
|
||||
|
||||
def update_slots_params(self, batch_slot_ids: List[int]):
|
||||
"""
|
||||
Update GEMM sizes and buffer offsets based on current batch composition.
|
||||
|
||||
Args:
|
||||
batch_slot_ids: Slot IDs for each token in the batch
|
||||
"""
|
||||
slot_counts = self.get_slot_counts(batch_slot_ids, self.max_lora_size)
|
||||
self.slot_counts_host.copy_(slot_counts)
|
||||
self.get_offset_from_counts(slot_counts, full=True, out=self.slot_offsets_full_host)
|
||||
self.slot_counts.copy_(self.slot_counts_host, non_blocking=True)
|
||||
self.slot_offsets_full.copy_(self.slot_offsets_full_host, non_blocking=True)
|
||||
|
||||
def get_problem_count(self, layer_key: LoraLayerKey) -> int:
|
||||
"""
|
||||
Get the number of GEMM problems for a layer.
|
||||
|
||||
Args:
|
||||
layer_key: Key of the layer
|
||||
|
||||
Returns:
|
||||
Number of GEMM problems (layer_module_num * max_lora_size)
|
||||
Returns 0 if layer has no LoRA modules
|
||||
Note: Only actual LoRA slots are counted, not the dummy base model slot
|
||||
"""
|
||||
if layer_key not in self.layer_params:
|
||||
return 0 # Layer has no LoRA modules
|
||||
return self.layer_info[layer_key].module_num * self.max_lora_size
|
||||
|
||||
def get_layer_params(self, layer_key: LoraLayerKey) -> Optional[LoraLayerParams]:
|
||||
"""
|
||||
Get LoRA parameters for a specific layer.
|
||||
|
||||
Args:
|
||||
layer_key: Key of the layer
|
||||
|
||||
Returns:
|
||||
LoraLayerParams for the specified layer, or None if layer has no LoRA modules
|
||||
"""
|
||||
return self.layer_params.get(layer_key)
|
||||
@ -1,8 +1,48 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .cuda_graph_lora_params import CudaGraphLoraParams
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupedGemmParamsOutput:
|
||||
in_sizes: Optional[torch.Tensor] = None
|
||||
out_sizes: Optional[torch.Tensor] = None
|
||||
a_offset: Optional[torch.Tensor] = None
|
||||
d_offset: Optional[torch.Tensor] = None
|
||||
d_prime_offset: Optional[torch.Tensor] = None
|
||||
lda: Optional[torch.Tensor] = None
|
||||
ldb: Optional[torch.Tensor] = None
|
||||
ldd: Optional[torch.Tensor] = None
|
||||
ldb_prime: Optional[torch.Tensor] = None
|
||||
ldd_prime: Optional[torch.Tensor] = None
|
||||
splitk_offsets: Optional[torch.Tensor] = None
|
||||
reordered_input: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupedGemmParamsInput:
|
||||
x: torch.Tensor
|
||||
output_buffer: torch.Tensor
|
||||
intermediate_buffer: torch.Tensor
|
||||
max_lora_size: int
|
||||
max_rank: int
|
||||
slot_counts: torch.Tensor
|
||||
slot_ranks: torch.Tensor
|
||||
slot_offsets_full: torch.Tensor
|
||||
b_ptrs: torch.Tensor
|
||||
b_prime_ptrs: torch.Tensor
|
||||
sorted_ids: torch.Tensor
|
||||
output_hidden_sizes: torch.Tensor
|
||||
output_sizes_offset: torch.Tensor
|
||||
|
||||
@property
|
||||
def slot_offsets(self):
|
||||
return self.slot_offsets_full[:-1]
|
||||
|
||||
|
||||
class LoraModuleType(IntEnum):
|
||||
"""Enum class representing different types of modules that can have LoRA adapters.
|
||||
@ -100,57 +140,385 @@ class LoraLayer(torch.nn.Module):
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
if bool(lora_params):
|
||||
lora_ranks = []
|
||||
lora_weight_pointers = []
|
||||
active_lora_module_ids = []
|
||||
for module_idx in self.lora_module_types:
|
||||
module_idx = int(module_idx)
|
||||
if module_idx in lora_params[layer_idx]:
|
||||
active_lora_module_ids.append(module_idx)
|
||||
lora_ranks.append(
|
||||
lora_params[layer_idx][module_idx]['adapter_size'])
|
||||
lora_weight_pointers.append(
|
||||
lora_params[layer_idx][module_idx]['weight_pointers'])
|
||||
# Check if we're using CUDA Graph mode
|
||||
use_cuda_graph_mode = lora_params.get('use_cuda_graph_mode', False)
|
||||
|
||||
num_seqs = lora_params['num_seqs']
|
||||
|
||||
if len(active_lora_module_ids) == 0:
|
||||
return None
|
||||
if use_cuda_graph_mode:
|
||||
return self._forward_cuda_graph_mode(x, lora_params, layer_idx)
|
||||
else:
|
||||
lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
|
||||
x,
|
||||
lora_params['host_request_types'][:num_seqs],
|
||||
lora_ranks,
|
||||
lora_weight_pointers,
|
||||
lora_params['prompt_lens_cpu'][:num_seqs],
|
||||
self.output_hidden_sizes,
|
||||
False, # transA
|
||||
True, # transB
|
||||
max([r.max() for r in lora_ranks]),
|
||||
0,
|
||||
True, # TODO smor- should be lora_params["remove_input_padding"], support in loraOp as well
|
||||
)
|
||||
if isinstance(lora_outputs, torch.Tensor):
|
||||
return lora_outputs
|
||||
else:
|
||||
# For multiple LoRA modules, some might not be executed in grouped gemm.
|
||||
# For those modules not executed, we create zero tensors with matching dimensions.
|
||||
# Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order.
|
||||
lora_output = []
|
||||
for module_idx in self.lora_module_types:
|
||||
if int(module_idx) in active_lora_module_ids:
|
||||
lora_output.append(lora_outputs.pop(0))
|
||||
else:
|
||||
lora_output.append(
|
||||
torch.zeros(list(x.shape[:-1]) + [
|
||||
self.output_hidden_sizes[
|
||||
self.lora_module_types.index(
|
||||
module_idx)]
|
||||
],
|
||||
dtype=x.dtype,
|
||||
device=x.device))
|
||||
lora_output = torch.cat(lora_output, dim=-1)
|
||||
return lora_output
|
||||
|
||||
return self._forward_eager_mode(x, lora_params, layer_idx)
|
||||
else:
|
||||
return None
|
||||
|
||||
def prepare_grouped_gemm_buffers(self, input: GroupedGemmParamsInput):
|
||||
device = input.x.device
|
||||
bs, input_hidden_size = input.x.shape
|
||||
shape_2d = (len(self.lora_module_types), input.max_lora_size
|
||||
) # [num_layer_modules, max_lora_size]
|
||||
shape_3d = shape_2d + (3, )
|
||||
sum_out_sizes = sum(self.output_hidden_sizes)
|
||||
|
||||
input.output_buffer.fill_(0)
|
||||
input.intermediate_buffer.fill_(0)
|
||||
|
||||
# reorder input
|
||||
reordered_input = torch.index_select(input.x, 0, input.sorted_ids[:bs])
|
||||
|
||||
# a [bs, hidden]
|
||||
lda = torch.full(shape_2d,
|
||||
input_hidden_size,
|
||||
dtype=CudaGraphLoraParams.LD_DTYPE,
|
||||
device=device)
|
||||
|
||||
# b [input_hidden_size, lora_rank]
|
||||
ldb = lda
|
||||
|
||||
# a_prime / d [num_layer_modules, bs, max_rank]
|
||||
ldd = torch.full(shape_2d,
|
||||
input.max_rank,
|
||||
dtype=CudaGraphLoraParams.LD_DTYPE,
|
||||
device=device)
|
||||
|
||||
# b_prime [lora_rank, module_output_size]
|
||||
ldb_prime = input.slot_ranks.unsqueeze(0).to(
|
||||
dtype=CudaGraphLoraParams.LD_DTYPE).repeat(shape_2d[0], 1)
|
||||
|
||||
# d_prime [bs, sum_of_each_module_output_sizes]
|
||||
ldd_prime = torch.full(shape_2d,
|
||||
sum_out_sizes,
|
||||
dtype=CudaGraphLoraParams.LD_DTYPE,
|
||||
device=device)
|
||||
|
||||
# reordered a [bs, hidden], each module has the same offset
|
||||
a_offset = input.slot_offsets * input_hidden_size
|
||||
a_offset = a_offset.unsqueeze(0).repeat(shape_2d[0], 1)
|
||||
|
||||
# d [num_layer_modules, bs, max_rank]
|
||||
d_offset = (input.slot_offsets.unsqueeze(0) + torch.arange(
|
||||
shape_2d[0], device=device, dtype=CudaGraphLoraParams.PTR_DTYPE).
|
||||
unsqueeze(1) * bs) * input.max_rank
|
||||
|
||||
# d' [bs, sum_of_each_module_output_sizes]
|
||||
bs_offset = input.slot_offsets.unsqueeze(0) # [1, max_lora_size]
|
||||
bs_offset = bs_offset * sum_out_sizes
|
||||
out_offset = input.output_sizes_offset.unsqueeze(
|
||||
1) # [num_layer_modules, 1]
|
||||
d_prime_offset = bs_offset + out_offset
|
||||
|
||||
# sizes
|
||||
in_sizes = torch.empty(shape_3d,
|
||||
dtype=CudaGraphLoraParams.SIZES_DTYPE,
|
||||
device=device)
|
||||
out_sizes = torch.empty_like(in_sizes)
|
||||
|
||||
slot_counts = input.slot_counts.unsqueeze(0) # [1, max_lora_size]
|
||||
ranks = input.slot_ranks.unsqueeze(0) # [1, max_lora_size]
|
||||
output_hidden_sizes = input.output_hidden_sizes.unsqueeze(
|
||||
1) # [num_layer_modules, 1]
|
||||
|
||||
in_sizes[:, :, 0] = slot_counts
|
||||
in_sizes[:, :, 1] = ranks
|
||||
in_sizes[:, :, 2] = input_hidden_size
|
||||
|
||||
out_sizes[:, :, 0] = slot_counts
|
||||
out_sizes[:, :, 1] = output_hidden_sizes
|
||||
out_sizes[:, :, 2] = ranks
|
||||
|
||||
# disable unused modules / lora with ptr being zeros
|
||||
in_sizes *= (input.b_ptrs != 0).unsqueeze(-1)
|
||||
out_sizes *= (input.b_prime_ptrs != 0).unsqueeze(-1)
|
||||
|
||||
# splitk_offsets: [num_layer_modules, max_lora_size]
|
||||
# splitk offsets (m * n) for the first grouped gemm with (m, n, k) = (slot_counts, slot_ranks, input_hidden_size)
|
||||
splitk_offsets = torch.zeros(shape_2d,
|
||||
dtype=CudaGraphLoraParams.LD_DTYPE,
|
||||
device=device)
|
||||
|
||||
splitk_offsets.view(-1)[1:] = in_sizes.view(-1, 3)[:-1, 0] # = M
|
||||
splitk_offsets.view(-1)[1:] *= in_sizes.view(-1, 3)[:-1, 1] # *= N
|
||||
splitk_offsets.view(-1).cumsum_(dim=0)
|
||||
|
||||
# add base addresses to offset tensors on GPU
|
||||
dtype_element_size = input.x.element_size()
|
||||
a_offset *= dtype_element_size
|
||||
a_offset += reordered_input.data_ptr()
|
||||
|
||||
d_offset *= dtype_element_size
|
||||
d_offset += input.intermediate_buffer.data_ptr()
|
||||
|
||||
d_prime_offset *= dtype_element_size
|
||||
d_prime_offset += input.output_buffer.data_ptr()
|
||||
|
||||
return GroupedGemmParamsOutput(in_sizes=in_sizes,
|
||||
out_sizes=out_sizes,
|
||||
a_offset=a_offset,
|
||||
d_offset=d_offset,
|
||||
d_prime_offset=d_prime_offset,
|
||||
lda=lda,
|
||||
ldb=ldb,
|
||||
ldd=ldd,
|
||||
ldb_prime=ldb_prime,
|
||||
ldd_prime=ldd_prime,
|
||||
splitk_offsets=splitk_offsets,
|
||||
reordered_input=reordered_input)
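A tiny sanity check of the offset arithmetic above (hypothetical shapes; element offsets shown before the element-size scaling and base-pointer addition that follow in the code):

import torch

bs, hidden, max_rank = 3, 4, 2
slot_offsets = torch.tensor([0, 2])   # slot 0 owns rows 0-1, slot 1 owns row 2
num_modules = 2

a_offset = (slot_offsets * hidden).unsqueeze(0).repeat(num_modules, 1)
# tensor([[0, 8], [0, 8]]) -- every module reads the same reordered input rows

d_offset = (slot_offsets.unsqueeze(0)
            + torch.arange(num_modules).unsqueeze(1) * bs) * max_rank
# tensor([[0, 4], [6, 10]]) -- module m, slot s starts at (m * bs + slot_offsets[s]) * max_rank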
|
||||
|
||||
def _prepare_grouped_gemm_buffers_fused(self,
|
||||
input: GroupedGemmParamsInput):
|
||||
device = input.x.device
|
||||
bs, input_hidden_size = input.x.shape
|
||||
shape_2d = (len(self.lora_module_types), input.max_lora_size
|
||||
) # [num_layer_modules, max_lora_size]
|
||||
shape_3d = shape_2d + (3, )
|
||||
sum_out_sizes = sum(self.output_hidden_sizes)
|
||||
|
||||
in_sizes = torch.empty(shape_3d,
|
||||
dtype=CudaGraphLoraParams.SIZES_DTYPE,
|
||||
device=device)
|
||||
out_sizes = torch.empty_like(in_sizes)
|
||||
a_offset = torch.empty(shape_2d,
|
||||
dtype=CudaGraphLoraParams.PTR_DTYPE,
|
||||
device=device)
|
||||
d_offset = torch.empty_like(a_offset)
|
||||
d_prime_offset = torch.empty_like(a_offset)
|
||||
lda = torch.empty(shape_2d,
|
||||
dtype=CudaGraphLoraParams.LD_DTYPE,
|
||||
device=device)
|
||||
ldb = lda
|
||||
ldd = torch.empty_like(lda)
|
||||
ldb_prime = torch.empty_like(lda)
|
||||
ldd_prime = torch.empty_like(lda)
|
||||
splitk_offsets = torch.empty(shape_2d,
|
||||
dtype=CudaGraphLoraParams.LD_DTYPE,
|
||||
device=device)
|
||||
reordered_input = torch.empty_like(input.x)
|
||||
torch.ops.trtllm.lora_group_gemm_param_fill_row_reorder_fusion(
|
||||
# output parameters
|
||||
in_sizes,
|
||||
out_sizes,
|
||||
a_offset,
|
||||
d_offset,
|
||||
d_prime_offset,
|
||||
lda,
|
||||
ldd,
|
||||
ldb_prime,
|
||||
ldd_prime,
|
||||
splitk_offsets,
|
||||
reordered_input,
|
||||
|
||||
# input parameters
|
||||
input.max_lora_size,
|
||||
input.max_rank,
|
||||
sum_out_sizes,
|
||||
input_hidden_size,
|
||||
bs, # batch_size
|
||||
input.slot_counts,
|
||||
input.slot_ranks,
|
||||
input.slot_offsets,
|
||||
input.output_hidden_sizes,
|
||||
input.output_sizes_offset,
|
||||
input.b_ptrs,
|
||||
input.b_prime_ptrs,
|
||||
input.x,
|
||||
input.sorted_ids[:bs],
|
||||
input.intermediate_buffer,
|
||||
input.output_buffer,
|
||||
input.x.dtype)
|
||||
|
||||
return GroupedGemmParamsOutput(in_sizes=in_sizes,
|
||||
out_sizes=out_sizes,
|
||||
a_offset=a_offset,
|
||||
d_offset=d_offset,
|
||||
d_prime_offset=d_prime_offset,
|
||||
lda=lda,
|
||||
ldb=ldb,
|
||||
ldd=ldd,
|
||||
ldb_prime=ldb_prime,
|
||||
ldd_prime=ldd_prime,
|
||||
splitk_offsets=splitk_offsets,
|
||||
reordered_input=reordered_input)
|
||||
|
||||
def _prepare_max_sizes_cpu(self,
|
||||
cuda_graph_lora_params: CudaGraphLoraParams,
|
||||
layer_key: CudaGraphLoraParams.LoraLayerKey,
|
||||
bs: int, input_hidden_size: int):
|
||||
layer_params = cuda_graph_lora_params.get_layer_params(layer_key)
|
||||
shape_2d = (len(self.lora_module_types),
|
||||
cuda_graph_lora_params.max_lora_size
|
||||
) # [num_layer_modules, max_lora_size]
|
||||
shape_3d = shape_2d + (3, )
|
||||
# dummy max sizes, on CPU
|
||||
host_max_in_sizes = torch.empty(
|
||||
shape_3d, dtype=CudaGraphLoraParams.SIZES_DTYPE
|
||||
) # m: batch_size, n: max_lora_rank, k: input_hidden_size
|
||||
host_max_out_sizes = torch.empty_like(
|
||||
host_max_in_sizes
|
||||
) # m: batch_size, n: max_output_hidden_size, k: max_lora_rank
|
||||
host_max_in_sizes[:, :, 0] = bs
|
||||
host_max_in_sizes[:, :, 1] = cuda_graph_lora_params.max_rank
|
||||
host_max_in_sizes[:, :, 2] = input_hidden_size
|
||||
|
||||
host_max_out_sizes[:, :, 0] = bs
|
||||
host_max_out_sizes[:, :, 1] = layer_params.h_output_sizes.unsqueeze(1)
|
||||
host_max_out_sizes[:, :, 2] = cuda_graph_lora_params.max_rank
|
||||
|
||||
return host_max_in_sizes, host_max_out_sizes
|
||||
|
||||
def _forward_cuda_graph_mode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
lora_params: Dict,
|
||||
layer_idx: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Forward pass using CUDA Graph compatible LoRA parameters.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
lora_params: CUDA Graph compatible LoRA parameters
|
||||
layer_idx: Current layer index
|
||||
|
||||
Returns:
|
||||
LoRA output tensor or None
|
||||
"""
|
||||
|
||||
cuda_graph_params: CudaGraphLoraParams = lora_params.get(
|
||||
'cuda_graph_params')
|
||||
# Get layer-specific parameters
|
||||
layer_key = CudaGraphLoraParams.LoraLayerKey(
|
||||
layer_idx=layer_idx, module_ids=tuple(self.lora_module_types))
|
||||
|
||||
if not cuda_graph_params or not cuda_graph_params.layer_info or layer_key not in cuda_graph_params.layer_info:
|
||||
return None
|
||||
|
||||
layer_params = cuda_graph_params.get_layer_params(layer_key)
|
||||
|
||||
# Skip layers that don't have LoRA modules
|
||||
if layer_params is None:
|
||||
return 0 # Pass-through for layers without LoRA modules
|
||||
|
||||
batch_size, hidden_size = x.shape[0], x.shape[-1]
|
||||
num_layer_modules = len(self.lora_module_types)
|
||||
max_rank = cuda_graph_params.max_rank
|
||||
total_output_size = sum(self.output_hidden_sizes)
|
||||
min_kn = min(
|
||||
hidden_size, 8, max_rank
|
||||
) # TODO: hardcoded to 8 for now for kernel alignment; ranks smaller than 8 may hit alignment errors!
|
||||
|
||||
output_buffer = torch.empty(batch_size,
|
||||
total_output_size,
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
|
||||
host_max_in_sizes, host_max_out_sizes = self._prepare_max_sizes_cpu(
|
||||
cuda_graph_params, layer_key, batch_size, hidden_size)
|
||||
|
||||
# Intermediate buffer: [num_layer_modules, batch_size, max_rank]
|
||||
intermediate_buffer = torch.empty(
|
||||
[num_layer_modules, batch_size, max_rank],
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
|
||||
params_fill_input = GroupedGemmParamsInput(
|
||||
x=x,
|
||||
output_buffer=output_buffer,
|
||||
intermediate_buffer=intermediate_buffer,
|
||||
max_lora_size=cuda_graph_params.max_lora_size,
|
||||
max_rank=cuda_graph_params.max_rank,
|
||||
slot_counts=cuda_graph_params.slot_counts,
|
||||
slot_ranks=cuda_graph_params.slot_ranks,
|
||||
slot_offsets_full=cuda_graph_params.slot_offsets_full,
|
||||
b_ptrs=layer_params.d_b_ptrs,
|
||||
b_prime_ptrs=layer_params.d_b_prime_ptrs,
|
||||
sorted_ids=cuda_graph_params.sorted_ids,
|
||||
output_hidden_sizes=layer_params.d_output_sizes,
|
||||
output_sizes_offset=layer_params.d_output_sizes_offset)
|
||||
grouped_gemm_params = self._prepare_grouped_gemm_buffers_fused(
|
||||
params_fill_input)
|
||||
|
||||
torch.ops.trtllm.lora_grouped_gemm_cuda_graph(
|
||||
grouped_gemm_params.in_sizes, grouped_gemm_params.out_sizes,
|
||||
grouped_gemm_params.a_offset, layer_params.d_b_ptrs,
|
||||
grouped_gemm_params.d_offset, layer_params.d_b_prime_ptrs,
|
||||
grouped_gemm_params.d_prime_offset,
|
||||
cuda_graph_params.get_problem_count(layer_key),
|
||||
grouped_gemm_params.lda, grouped_gemm_params.ldb,
|
||||
grouped_gemm_params.ldd, grouped_gemm_params.ldb_prime,
|
||||
grouped_gemm_params.ldd_prime, host_max_in_sizes,
|
||||
host_max_out_sizes, grouped_gemm_params.splitk_offsets,
|
||||
grouped_gemm_params.reordered_input.dtype, min_kn)
|
||||
|
||||
# TODO: move to kernel
|
||||
restored_output = torch.zeros_like(output_buffer)
|
||||
restored_output.index_copy_(0,
|
||||
cuda_graph_params.sorted_ids[:batch_size],
|
||||
output_buffer)
|
||||
return restored_output
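Standalone sketch of why the index_copy_ above undoes the index_select done during buffer preparation (arbitrary shapes, not the production buffers):

import torch

x = torch.arange(8, dtype=torch.float32).reshape(4, 2)   # 4 requests
sorted_ids = torch.tensor([2, 0, 3, 1])                   # grouped-by-slot order
reordered = torch.index_select(x, 0, sorted_ids)          # what the grouped GEMM consumed
restored = torch.zeros_like(reordered)
restored.index_copy_(0, sorted_ids, reordered)            # row i returns to row sorted_ids[i]
assert torch.equal(restored, x)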
|
||||
|
||||
def _forward_eager_mode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
lora_params: Dict,
|
||||
layer_idx: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Eager-mode forward pass using the original LoRA implementation.
|
||||
|
||||
Args:
|
||||
x: Input tensor
|
||||
lora_params: LoRA parameters for eager mode
|
||||
layer_idx: Current layer index
|
||||
|
||||
Returns:
|
||||
LoRA output tensor or None
|
||||
"""
|
||||
lora_ranks = []
|
||||
lora_weight_pointers = []
|
||||
active_lora_module_ids = []
|
||||
|
||||
for module_idx in self.lora_module_types:
|
||||
module_idx = int(module_idx)
|
||||
if module_idx in lora_params[layer_idx]:
|
||||
active_lora_module_ids.append(module_idx)
|
||||
lora_ranks.append(
|
||||
lora_params[layer_idx][module_idx]['adapter_size'])
|
||||
lora_weight_pointers.append(
|
||||
lora_params[layer_idx][module_idx]['weight_pointers'])
|
||||
|
||||
num_seqs = lora_params['num_seqs']
|
||||
|
||||
if len(active_lora_module_ids) == 0:
|
||||
return None
|
||||
else:
|
||||
lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
|
||||
x,
|
||||
lora_params['host_request_types'][:num_seqs],
|
||||
lora_ranks,
|
||||
lora_weight_pointers,
|
||||
lora_params['prompt_lens_cpu'][:num_seqs],
|
||||
self.output_hidden_sizes,
|
||||
False, # transA
|
||||
True, # transB
|
||||
max([r.max() for r in lora_ranks]),
|
||||
0,
|
||||
True, # TODO smor- should be lora_params["remove_input_padding"], support in loraOp as well
|
||||
)
|
||||
if isinstance(lora_outputs, torch.Tensor):
|
||||
return lora_outputs
|
||||
else:
|
||||
# For multiple LoRA modules, some might not be executed in grouped gemm.
|
||||
# For those modules not executed, we create zero tensors with matching dimensions.
|
||||
# Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order.
|
||||
lora_output = []
|
||||
for module_idx in self.lora_module_types:
|
||||
if int(module_idx) in active_lora_module_ids:
|
||||
lora_output.append(lora_outputs.pop(0))
|
||||
else:
|
||||
lora_output.append(
|
||||
torch.zeros(list(x.shape[:-1]) + [
|
||||
self.output_hidden_sizes[
|
||||
self.lora_module_types.index(module_idx)]
|
||||
],
|
||||
dtype=x.dtype,
|
||||
device=x.device))
|
||||
lora_output = torch.cat(lora_output, dim=-1)
|
||||
return lora_output
|
||||
|
||||
@ -834,6 +834,8 @@ def create_py_executor_instance(
lora_config.lora_target_modules,
lora_config.trtllm_modules_to_hf_modules,
lora_config.swap_gate_up_proj_lora_b_weight)
if isinstance(model_engine, PyTorchModelEngine):
model_engine._init_cuda_graph_lora_manager(lora_config)

resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
max_num_sequences)

@ -533,9 +533,6 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.py_decoding_iter = 0
self.is_attention_dp_dummy = False
self.is_cuda_graph_dummy = False
self.py_lora_task_layer_module_configs: list[
tensorrt_llm.bindings.internal.runtime.
TaskLayerModuleConfig] | None = None
self.py_kv_transfer_start_time = None
self.py_kv_transfer_timed_out = False


@ -16,6 +16,7 @@ import torch._dynamo.config
import tensorrt_llm.bindings.internal.userbuffers as ub
from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
torch_dtype_to_str, trace_func)
from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig
from tensorrt_llm.inputs.multimodal import (MultimodalParams,
MultimodalRuntimeData)
from tensorrt_llm.inputs.registry import (create_input_processor,
@ -44,6 +45,7 @@ from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer,
MoeLoadBalancerIterContext)
from ..peft.lora.cuda_graph_lora_manager import CudaGraphLoraManager
from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
get_spec_metadata,
update_spec_config_from_model_config)
@ -62,7 +64,8 @@ from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .llm_request import LlmRequest, get_draft_token_length
from .model_loader import ModelLoader, _construct_checkpoint_loader
from .resource_manager import (BaseResourceManager, KVCacheManager,
ResourceManager, ResourceManagerType)
PeftCacheManager, ResourceManager,
ResourceManagerType)
from .sampler import SampleStateTensors
from .scheduler import ScheduledRequests

@ -449,6 +452,9 @@ class PyTorchModelEngine(ModelEngine):
)
self.cuda_graph_runner = CUDAGraphRunner(cuda_graph_runner_config)

# Initialize CUDA Graph LoRA manager if LoRA is enabled
self.cuda_graph_lora_manager: Optional[CudaGraphLoraManager] = None

# Setup the local cache indirection buffer only once and reuse it.
# This way it can also be used for CUDA graphs.
if self.use_beam_search:
@ -493,6 +499,26 @@ class PyTorchModelEngine(ModelEngine):
dtype=torch_dtype_to_str(self.model.config.torch_dtype),
swap_gate_up_proj_lora_b_weight=swap_gate_up_proj_lora_b_weight)

def _init_cuda_graph_lora_manager(self, lora_config: LoraConfig):
"""Initialize CUDA Graph LoRA manager with model configuration."""
# Get model configuration
if self.cuda_graph_runner.enabled:
max_lora_size = lora_config.max_loras or 8 # Default fallback
max_batch_size = self.batch_size # Use engine's max batch size

self.cuda_graph_lora_manager = CudaGraphLoraManager(
max_lora_size=max_lora_size,
max_batch_size=max_batch_size,
max_lora_rank=lora_config.max_lora_rank,
model=self.model,
lora_model_config=self.lora_model_config,
device='cuda')

logger.info(
f"Initialized CUDA Graph LoRA manager, "
f"max {max_lora_size} adapters, max rank {lora_config.max_lora_rank}"
)

def set_guided_decoder(self,
guided_decoder: CapturableGuidedDecoder) -> bool:
if hasattr(self.model, "set_guided_decoder"):
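For context on how this manager gets exercised end to end, the user-facing setup looks roughly like the harnesses in the tests touched by this change: a LoraConfig plus a CudaGraphConfig passed to the LLM API, with adapters selected per request. The paths and the adapter name below are placeholders.

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.executor.request import LoRARequest
    from tensorrt_llm.llmapi.llm_args import CudaGraphConfig
    from tensorrt_llm.lora_helper import LoraConfig

    lora_config = LoraConfig(lora_dir=["/path/to/lora-adapter"],   # placeholder path
                             max_lora_rank=8,
                             max_loras=2,
                             max_cpu_loras=2)
    llm = LLM(model="/path/to/base-model",                         # placeholder path
              lora_config=lora_config,
              cuda_graph_config=CudaGraphConfig(max_batch_size=10))  # CUDA graphs left enabled
    outputs = llm.generate(
        ["The capital of France is"],
        SamplingParams(max_tokens=16),
        lora_request=[LoRARequest("adapter-0", 0, "/path/to/lora-adapter")])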
@ -1936,7 +1962,8 @@ class PyTorchModelEngine(ModelEngine):
cache_indirection_buffer: Optional[torch.Tensor] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None,
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None,
resource_manager: Optional[ResourceManager] = None):
resource_manager: Optional[ResourceManager] = None,
maybe_graph: bool = False):
"""
Prepare inputs for Pytorch Model.
"""
@ -2673,8 +2700,10 @@ class PyTorchModelEngine(ModelEngine):

attn_metadata.prepare()

peft_cache_manager = resource_manager and resource_manager.get_resource_manager(
ResourceManagerType.PEFT_CACHE_MANAGER)
lora_params = self._get_lora_params_from_requests(
scheduled_requests, attn_metadata)
scheduled_requests, attn_metadata, peft_cache_manager, maybe_graph)

attn_all_rank_num_tokens = self._get_all_rank_num_tokens(attn_metadata)
padded_num_tokens, can_run_piecewise_cuda_graph, attn_all_rank_num_tokens = self._get_padding_params(
@ -3142,10 +3171,41 @@ class PyTorchModelEngine(ModelEngine):
'inputs_embeds': None
}, gather_ids if is_spec_decode else None

def _get_lora_params_from_requests(self,
scheduled_requests: ScheduledRequests,
attn_metadata: AttentionMetadata):
def _get_lora_params_from_requests(
self,
scheduled_requests: ScheduledRequests,
attn_metadata: AttentionMetadata,
peft_cache_manager: Optional[PeftCacheManager] = None,
maybe_graph: bool = False):
'''
Get LoRA parameters from scheduled requests.

Uses CUDA Graph compatible mode in decode only batch, otherwise falls back to eager mode.

Returns:
Dictionary containing LoRA parameters, or None if no LoRA requests
'''
use_cuda_graph_mode = self.cuda_graph_lora_manager is not None and maybe_graph

if use_cuda_graph_mode:
return self.cuda_graph_lora_manager.prepare_cuda_graph_lora_params(
scheduled_requests, attn_metadata, peft_cache_manager)
else:
if self.cuda_graph_lora_manager is not None:
self.cuda_graph_lora_manager.adapter_slot_manager.remove_evicted_slots_in_cpp(
peft_cache_manager)
peft_table = peft_cache_manager.get_and_reset_batch_peft_table(
) if peft_cache_manager is not None else None
return peft_table and self._get_eager_lora_params_from_requests(
scheduled_requests, attn_metadata, peft_table)

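One detail worth noting in the eager branch above: the final return leans on Python's short-circuiting `and`, so a missing or empty PEFT table is passed through unchanged and the eager helper is never invoked. A tiny stand-alone sketch of that contract:

    def get_params(peft_table):
        # mirrors `return peft_table and self._get_eager_lora_params_from_requests(...)`
        return peft_table and {"built_from_tasks": sorted(peft_table)}

    assert get_params(None) is None               # no PEFT cache manager in play
    assert get_params({}) == {}                   # empty table: falsy, helper not called
    assert get_params({7: ["cfg"]}) == {"built_from_tasks": [7]}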
def _get_eager_lora_params_from_requests(
self, scheduled_requests: ScheduledRequests,
attn_metadata: AttentionMetadata,
peft_table: Dict[int, list[TaskLayerModuleConfig]]):
'''
Eager mode LoRA parameter preparation logic.

lora_params: dict
{
layer_id: dict
@ -3165,10 +3225,12 @@ class PyTorchModelEngine(ModelEngine):

# trace all requests to get the union set of the lora params
for request in request_list:
if request.py_lora_task_layer_module_configs is None:
if request.lora_task_id is None:
continue

for module in request.py_lora_task_layer_module_configs:
layer_module_configs = peft_table[request.lora_task_id]

for module in layer_module_configs:
module_id = module.module_id
layer_id = module.layer_id

@ -3195,7 +3257,7 @@ class PyTorchModelEngine(ModelEngine):

for request in request_list:
# Need to set default values for this case
if request.py_lora_task_layer_module_configs is None:
if request.lora_task_id is None:
for layer_id in lora_params:
for module_id in lora_params[layer_id]:
current_lora_params = lora_params[layer_id][module_id]
@ -3245,7 +3307,8 @@ class PyTorchModelEngine(ModelEngine):
cache_indirection_buffer: Optional[torch.Tensor] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None,
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None,
resource_manager: Optional[ResourceManager] = None):
resource_manager: Optional[ResourceManager] = None,
maybe_graph: bool = False):
if self.mapping is not None and 'cp_type' in self.mapping.cp_config:
cp_type = self.mapping.cp_config['cp_type']
if CpType.STAR == cp_type:
@ -3258,12 +3321,11 @@ class PyTorchModelEngine(ModelEngine):
raise NotImplementedError(
f"Unsupported cp_type {getattr(cp_type, 'name', cp_type)}.")

return self._prepare_tp_inputs(scheduled_requests, kv_cache_manager,
attn_metadata, spec_metadata,
new_tensors_device,
cache_indirection_buffer,
num_accepted_tokens_device,
req_id_to_old_request, resource_manager)
return self._prepare_tp_inputs(
scheduled_requests, kv_cache_manager, attn_metadata, spec_metadata,
new_tensors_device, cache_indirection_buffer,
num_accepted_tokens_device, req_id_to_old_request, resource_manager,
maybe_graph)

@torch.inference_mode()
@with_model_extra_attrs(lambda self: self.model.extra_attrs)
@ -3347,12 +3409,11 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata = self.spec_metadata
else:
spec_metadata = None

inputs, gather_ids = self._prepare_inputs(
padded_requests, kv_cache_manager, attn_metadata, spec_metadata,
new_tensors_device, cache_indirection_buffer,
num_accepted_tokens_device, req_id_to_old_request,
resource_manager)
resource_manager, can_run_graph)

with with_shared_pool(self.cuda_graph_runner.get_graph_pool()):
if not can_run_graph:

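Both preparation paths hand the LoRA layers a dictionary keyed the way the call sites above consume it: a few batch-level tensors plus nested per-layer, per-module entries. The sketch below only shows that shape; the per-module field names are placeholders, not the exact keys used by the layer code.

    import torch

    num_seqs = 2
    lora_params_sketch = {
        'host_request_types': torch.zeros(num_seqs, dtype=torch.int32),  # sliced with [:num_seqs] above
        'prompt_lens_cpu': torch.tensor([5, 7], dtype=torch.int32),      # sliced with [:num_seqs] above
        0: {                                   # layer_id
            1: {                               # module_id
                'ranks': torch.tensor([8, 8]),                            # placeholder field name
                'weight_pointers': torch.zeros(2, 3, dtype=torch.int64),  # placeholder field name
            },
        },
    }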
@ -11,6 +11,7 @@ import torch
import tensorrt_llm
import tensorrt_llm.bindings
from tensorrt_llm._torch.distributed.communicator import Distributed, ReduceOp
from tensorrt_llm.bindings.internal.runtime import TaskLayerModuleConfig
from tensorrt_llm.llmapi.llm_args import (KvCacheConfig, PeftCacheConfig,
PybindMirror)
from tensorrt_llm.lora_helper import LoraConfig
@ -1554,6 +1555,9 @@ class PeftCacheManager(BaseResourceManager):
model_config=ModelConfigPython.from_model_config_cpp(model_config),
cpp_peft_cache_manager=self.impl)

self._batch_peft_table: Optional[Dict[int, list[
TaskLayerModuleConfig]]] = None # task_id -> layer-module-configs mapping for the current batch

def get_lora_manager(self) -> LoraManager:
return self._lora_manager

@ -1600,18 +1604,9 @@ class PeftCacheManager(BaseResourceManager):
for req in context_batch:
self.add_request_peft(req)

py_lora_task_layer_module_configs = self.impl.ensure_batch(
self._batch_peft_table, _ = self.impl.ensure_batch_map_task_id(
context_batch, generation_batch, False)

for req in context_batch:
req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[
req.
py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None
for req in generation_batch:
req.py_lora_task_layer_module_configs = py_lora_task_layer_module_configs[
req.
py_request_id] if req.py_request_id in py_lora_task_layer_module_configs else None

def update_resources(self, scheduled_batch: ScheduledRequests):
pass

@ -1620,3 +1615,12 @@ class PeftCacheManager(BaseResourceManager):

def shutdown(self):
pass

def get_and_reset_batch_peft_table(
self) -> Dict[int, list[TaskLayerModuleConfig]]:
batch_peft_table = self._batch_peft_table
self._batch_peft_table = None
return batch_peft_table

def is_task_cached_device(self, task_id: int) -> bool:
return self.impl.is_task_cached_device(task_id)

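Taken together, the new PeftCacheManager pieces implement a simple per-batch hand-off: prepare_resources stores a task-id-keyed table, the model engine drains it exactly once, and the eager path looks configs up by each request's lora_task_id. A minimal stand-in (not the real classes) showing that flow:

    class FakePeftCacheManager:
        """Toy stand-in for the table hand-off sketched above."""

        def __init__(self):
            self._batch_peft_table = None

        def prepare_resources(self, task_table):
            # In the real manager this tuple element comes from impl.ensure_batch_map_task_id(...).
            self._batch_peft_table = dict(task_table)

        def get_and_reset_batch_peft_table(self):
            table, self._batch_peft_table = self._batch_peft_table, None
            return table


    mgr = FakePeftCacheManager()
    mgr.prepare_resources({7: ["layer0-attn_q-config", "layer0-attn_v-config"]})
    peft_table = mgr.get_and_reset_batch_peft_table()
    assert peft_table[7][0] == "layer0-attn_q-config"     # eager path: peft_table[request.lora_task_id]
    assert mgr.get_and_reset_batch_peft_table() is None   # the table is consumed once per batch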
@ -59,6 +59,8 @@ def test_register_fake(custom_ops):
# TODO: add fake impl for these ops in follow-up PRs.
to_fix = {
"trtllm::lora_grouped_gemm",
"trtllm::lora_grouped_gemm_cuda_graph",
"trtllm::lora_group_gemm_param_fill_row_reorder_fusion",
"trtllm::mtp_relaxed_acceptance_op",
"trtllm::mtp_update_hidden_states_op",
"trtllm::mtp_prepare_drafter_inputs_op",

@ -1,18 +1,28 @@
import json
import tarfile
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import List, OrderedDict, Type
from typing import List, Optional, OrderedDict, Tuple, Type, Union

import pytest
import torch
from utils.llm_data import llm_models_root
from utils.util import duplicate_list_to_length, flatten_list, similar

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch.peft.lora.cuda_graph_lora_params import \
CudaGraphLoraParams
from tensorrt_llm._torch.peft.lora.layer import (GroupedGemmParamsInput,
GroupedGemmParamsOutput,
LoraLayer)
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi.llm import BaseLLM
from tensorrt_llm.llmapi.llm_args import CudaGraphConfig
from tensorrt_llm.lora_helper import LoraConfig

from .test_utils import DelayedAssert

_RU_LORA_ADAPTER_PROMPTS = [
"Назови главную площадь в центре Москвы.",
"Напиши полное предложение, описывающее, что в музее не хватает женских скульптур. Используй фразу \"не хватает\".",
@ -284,3 +294,233 @@ def create_mock_nemo_lora_checkpoint(
tar.add(config_path, arcname="model_config.yaml")

return nemo_path


@dataclass
class CUDAGraphLoRATestParams:
batch_slot_ids: List[int]
input_hidden_size: int
slot_ranks: List[int]
max_lora_rank: int
output_hidden_sizes: List[int]
layer_module_mask: Optional[Union[torch.Tensor, bool]]
dtype: torch.dtype
seed: int

def __post_init__(self):
assert self.layer_module_mask is None or isinstance(
self.layer_module_mask,
bool) or self.layer_module_mask.shape == (self.module_count,
self.slot_count)
assert all(0 <= idx <= self.slot_count for idx in self.batch_slot_ids)
assert all(0 <= rank <= self.max_lora_rank for rank in self.slot_ranks)
if isinstance(self.layer_module_mask, torch.Tensor):
self.layer_module_mask = self.layer_module_mask.to(dtype=torch.bool)
elif self.layer_module_mask is not None:
self.layer_module_mask = bool(self.layer_module_mask)
else:
self.layer_module_mask = True

@property
def module_count(self):
return len(self.output_hidden_sizes)

@property
def slot_count(self):
return len(self.slot_ranks)

@property
def batch_size(self):
return len(self.batch_slot_ids)

@property
def sum_output_hidden_size(self):
return sum(self.output_hidden_sizes)

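The builder below relies on CudaGraphLoraParams helpers for slot bookkeeping. As a rough, assumption-laden picture of what that bookkeeping amounts to (plain Python, not the real helpers): count how many batch rows map to each adapter slot, turn the counts into prefix-sum offsets, and reorder row indices so rows sharing a slot are contiguous; a slot id equal to slot_count is taken to mean the row has no LoRA adapter.

    batch_slot_ids = [0, 3, 3, 4, 5, 8]   # 8 == slot_count, i.e. the last row has no adapter
    slot_count = 8

    counts = [batch_slot_ids.count(s) for s in range(slot_count)]      # per-slot row counts
    offsets = [sum(counts[:s]) for s in range(slot_count + 1)]         # exclusive prefix sums
    sorted_ids = sorted(range(len(batch_slot_ids)), key=lambda i: batch_slot_ids[i])

    assert counts == [1, 0, 0, 2, 1, 1, 0, 0]
    assert offsets[-1] == 5               # number of rows that carry a real adapter
    assert sorted_ids == [0, 1, 2, 3, 4, 5]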
def create_grouped_gemm_params_filler_input(
test_params: Optional[CUDAGraphLoRATestParams] = None
) -> Tuple[GroupedGemmParamsInput, LoraLayer]:
if test_params is None:
test_params = CUDAGraphLoRATestParams(
batch_slot_ids=[0, 3, 3, 4, 5, 8],
input_hidden_size=4096,
slot_ranks=[8, 12, 4, 3] * 2,
max_lora_rank=64,
output_hidden_sizes=[4096, 4096],
layer_module_mask=None,
dtype=torch.bfloat16,
seed=42,
)

with torch.random.fork_rng():
torch.manual_seed(test_params.seed)
shape_2d = (test_params.module_count, test_params.slot_count)

x = torch.randn(test_params.batch_size,
test_params.input_hidden_size,
dtype=test_params.dtype,
device="cuda")
output_buffer = torch.randn(test_params.batch_size,
test_params.sum_output_hidden_size,
dtype=test_params.dtype,
device="cuda")
b_ptrs = torch.randint(1,
1000000,
shape_2d,
dtype=CudaGraphLoraParams.PTR_DTYPE)
b_prime_ptrs = torch.randint(1,
1000000,
shape_2d,
dtype=CudaGraphLoraParams.PTR_DTYPE)

b_ptrs *= test_params.layer_module_mask
b_prime_ptrs *= test_params.layer_module_mask

b_ptrs = b_ptrs.to(device="cuda")
b_prime_ptrs = b_prime_ptrs.to(device="cuda")
slot_ranks = torch.tensor(test_params.slot_ranks,
dtype=CudaGraphLoraParams.SIZES_DTYPE,
device="cuda")

intermediate_buffer = torch.randn(test_params.module_count,
test_params.batch_size,
test_params.max_lora_rank,
dtype=test_params.dtype,
device="cuda")
slot_counts = CudaGraphLoraParams.get_slot_counts(
test_params.batch_slot_ids, test_params.slot_count)
slot_offsets_full = CudaGraphLoraParams.get_offset_from_counts(
slot_counts, full=True)
sorted_ids = CudaGraphLoraParams.get_sorted_indices(
test_params.batch_slot_ids)

slot_offsets_full = slot_offsets_full.to(device="cuda",
dtype=torch.int64)
slot_counts = slot_counts.to(device="cuda", dtype=torch.int32)
sorted_ids = sorted_ids.to(device="cuda", dtype=torch.int64)

output_hidden_sizes = torch.tensor(
test_params.output_hidden_sizes,
dtype=CudaGraphLoraParams.SIZES_DTYPE,
device="cuda")
output_sizes_offset = CudaGraphLoraParams.get_offset_from_counts(
output_hidden_sizes).to(dtype=CudaGraphLoraParams.PTR_DTYPE,
device="cuda")

layer = LoraLayer([0] * test_params.module_count,
test_params.output_hidden_sizes)
inputs = GroupedGemmParamsInput(
x=x,
output_buffer=output_buffer,
intermediate_buffer=intermediate_buffer,
max_lora_size=test_params.slot_count,
max_rank=test_params.max_lora_rank,
slot_counts=slot_counts,
slot_ranks=slot_ranks,
slot_offsets_full=slot_offsets_full,
sorted_ids=sorted_ids,
b_ptrs=b_ptrs,
b_prime_ptrs=b_prime_ptrs,
output_hidden_sizes=output_hidden_sizes,
output_sizes_offset=output_sizes_offset,
)
return inputs, layer


def compare_grouped_gemm_params(
params: GroupedGemmParamsOutput,
ref: GroupedGemmParamsOutput,
params_input: GroupedGemmParamsInput,
params_to_store_msg: List[str] | None = ['splitk_offsets'],
params_exclude_msg: List[str] | None = None,
):
assert not (params_to_store_msg and params_exclude_msg)

bs, input_hidden_size = params.reordered_input.shape
asserter = DelayedAssert()
params_dict = asdict(params)
ref_dict = asdict(ref)

if not params_to_store_msg:
params_to_store_msg = set(params_dict.keys())
if params_exclude_msg:
for name in params_exclude_msg:
params_to_store_msg.discard(name)

def get_msg(name: str, v: torch.Tensor, ref_v: torch.Tensor):
is_get_msg = any(p in name or name in p for p in params_to_store_msg)
header = f"\n\n{name=}\n"
return f"{header} {v=}\n {ref_v=}\n diff:\n{v - ref_v}" if is_get_msg else header

for name in params_dict.keys():
v = params_dict[name]
ref_v = ref_dict[name]
if name not in ("reordered_input", "a_offset"):
asserter.add(
v.allclose(ref_v),
get_msg(name, v, ref_v),
)

# Test a_offset separately
offset = params.a_offset - params.reordered_input.data_ptr()
ref_offset = ref.a_offset - ref.reordered_input.data_ptr()
asserter.add(
(offset == ref_offset).all(),
# 'a_offset_fused',
get_msg("a_offset", offset, ref_offset))

# Test reordered_input separately
valid_row = params_input.slot_offsets_full[-1].cpu().item()
valid_rows = params.reordered_input[:valid_row]
ref_valid_rows = ref.reordered_input[:valid_row]
asserter.add(
valid_rows.allclose(ref_valid_rows),
get_msg(f"valid part({valid_row=}, {bs=}) of reordered_input",
valid_rows, ref_valid_rows))

# check intermediate buffer and output buffer are all zeros
asserter.add(
torch.all(params_input.intermediate_buffer == 0),
get_msg("intermediate buffer", params_input.intermediate_buffer, 0))
asserter.add(torch.all(params_input.output_buffer == 0),
get_msg("output buffer", params_input.output_buffer, 0))

if valid_row < bs:
invalid_rows = params.reordered_input[valid_row:]
ref_invalid_rows = ref.reordered_input[valid_row:]
asserter.add(
torch.all(invalid_rows == 0),
get_msg("invalid part of reordered_input", invalid_rows,
ref_invalid_rows))
else:
asserter.add(
True,
f"valid_row is full {valid_row=} v. bs: {params_dict['reordered_input'].shape[0]=}"
)
asserter.assert_all()


def compare_cuda_graph_lora_params_filler(test_params: CUDAGraphLoRATestParams):
grouped_gemm_params_filler_input, layer = create_grouped_gemm_params_filler_input(
test_params)
output_fused = layer._prepare_grouped_gemm_buffers_fused(
grouped_gemm_params_filler_input)

assert torch.all(
grouped_gemm_params_filler_input.intermediate_buffer == 0
), f"intermediate_buffer is not all zeros: {grouped_gemm_params_filler_input.intermediate_buffer}; non zero / zeros: {(grouped_gemm_params_filler_input.intermediate_buffer != 0).sum()} / {grouped_gemm_params_filler_input.intermediate_buffer.numel()}"
assert torch.all(
grouped_gemm_params_filler_input.output_buffer == 0
), f"output_buffer is not all zeros: {grouped_gemm_params_filler_input.output_buffer}; non zero / zeros: {(grouped_gemm_params_filler_input.output_buffer != 0).sum()} / {grouped_gemm_params_filler_input.output_buffer.numel()}"

output_pytorch = layer.prepare_grouped_gemm_buffers(
grouped_gemm_params_filler_input)
compare_grouped_gemm_params(output_fused,
output_pytorch,
grouped_gemm_params_filler_input,
params_to_store_msg=None)


test_lora_with_and_without_cuda_graph = pytest.mark.parametrize(
"cuda_graph_config", [CudaGraphConfig(max_batch_size=10), None])

@ -9,7 +9,8 @@ from tensorrt_llm.sampling_params import SamplingParams

from .lora_test_utils import (
check_llama_7b_multi_lora_from_request_test_harness,
check_phi3_lora_fused_modules_output_tp2_identical_to_tp1)
check_phi3_lora_fused_modules_output_tp2_identical_to_tp1,
test_lora_with_and_without_cuda_graph)
from .test_llm import (_test_llm_capture_request_error, llama_model_path,
llm_get_stats_test_harness,
llm_return_logprobs_test_harness,
@ -46,9 +47,10 @@ def test_llama_7b_lora_tp2():
kv_cache_config=global_kv_cache_config)


@pytest.mark.gpu2
@pytest.mark.skip(reason="https://nvbugs/5682551")
def test_llama_7b_multi_lora_tp2():
@pytest.mark.gpu4
@skip_ray # https://nvbugs/5682551
@test_lora_with_and_without_cuda_graph
def test_llama_7b_multi_lora_tp4(cuda_graph_config):
# For LoRA checkpoints without finetuned embedding and lm_head, we can either:
# (1) specify lora_target_modules, or
# (2) provide a lora_dir to infer the lora_target_modules.
@ -59,21 +61,18 @@ def test_llama_7b_multi_lora_tp2():
check_llama_7b_multi_lora_from_request_test_harness(
LLM,
lora_config=lora_config,
tensor_parallel_size=2,
tensor_parallel_size=4,
kv_cache_config=global_kv_cache_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
cuda_graph_config=cuda_graph_config)


@skip_ray # https://nvbugs/5727075
@pytest.mark.gpu2
def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None:
@test_lora_with_and_without_cuda_graph
def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1(
cuda_graph_config) -> None:
check_phi3_lora_fused_modules_output_tp2_identical_to_tp1(
LLM,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
LLM, cuda_graph_config=cuda_graph_config)


@skip_ray

@ -20,7 +20,8 @@ from tensorrt_llm.sampling_params import SamplingParams
from .lora_test_utils import (
check_llama_7b_multi_lora_from_request_test_harness,
check_llama_7b_multi_unique_lora_adapters_from_request,
create_mock_nemo_lora_checkpoint)
create_mock_nemo_lora_checkpoint, compare_cuda_graph_lora_params_filler,
CUDAGraphLoRATestParams, test_lora_with_and_without_cuda_graph)
from .test_llm import (_test_llm_capture_request_error, get_model_path,
global_kvcache_config, llama_model_path,
llm_get_stats_async_test_harness,
@ -42,6 +43,7 @@ import torch
from peft import LoraConfig as PeftLoraConfig
from peft import get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import replace

# isort: on

@ -283,19 +285,76 @@ def test_embedding_bias_with_torch_sampler_strategies():
)


def test_lora_cuda_graph_params_filling_kernel_special_cases():
torch.cuda.set_device(0)

# test all requests have the same LoRA id case
test_params = CUDAGraphLoRATestParams(
batch_slot_ids=[0] * 10,
input_hidden_size=4096,
slot_ranks=[64] * 10,
max_lora_rank=64,
output_hidden_sizes=[123],
layer_module_mask=None,
dtype=torch.bfloat16,
seed=42,
)
compare_cuda_graph_lora_params_filler(test_params)

# test no LoRA in a batch case
test_params2 = replace(test_params,
batch_slot_ids=[len(test_params.slot_ranks)] * 10)
compare_cuda_graph_lora_params_filler(test_params2)

# test all having three modules case
test_params3 = replace(test_params, output_hidden_sizes=[123, 456, 789])
compare_cuda_graph_lora_params_filler(test_params3)

# test some layer module have invalid weight pointers case
mask = torch.full((test_params3.module_count, test_params3.slot_count),
True,
dtype=torch.bool)
mask[0, 0] = False
mask[1, 7] = False
mask[2, 3] = False
test_params4 = replace(test_params3, layer_module_mask=mask)
compare_cuda_graph_lora_params_filler(test_params4)

# test mixed slot ids case
test_params5 = CUDAGraphLoRATestParams(
batch_slot_ids=[6, 2, 0, 1, 1, 1, 5, 6],
input_hidden_size=512,
slot_ranks=[8, 12, 4] * 2,
max_lora_rank=15,
output_hidden_sizes=[123, 456, 789],
layer_module_mask=None,
dtype=torch.bfloat16,
seed=42,
)
compare_cuda_graph_lora_params_filler(test_params5)

# test mixed slot with invalid weight pointers
mask = torch.full((test_params5.module_count, test_params5.slot_count),
True,
dtype=torch.bool)
mask[1, 3] = False
mask[2, 5] = False
mask[-1, -4] = False
test_params6 = replace(test_params5, layer_module_mask=mask)
compare_cuda_graph_lora_params_filler(test_params6)


def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:
lora_config = LoraConfig(
lora_dir=[f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"],
max_lora_rank=8,
max_loras=2,
max_cpu_loras=2)
llm = LLM(
model=f"{llm_models_root()}/llama-models/llama-7b-hf",
lora_config=lora_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None,
**llm_kwargs)
if "cuda_graph_config" not in llm_kwargs:
llm_kwargs["cuda_graph_config"] = None
llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf",
lora_config=lora_config,
**llm_kwargs)
try:
prompts = [
"美国的首都在哪里? \n答案:",
@ -318,22 +377,21 @@ def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:

@skip_gpu_memory_less_than_40gb
@pytest.mark.part0
def test_llama_7b_lora():
llama_7b_lora_from_dir_test_harness()
@test_lora_with_and_without_cuda_graph
def test_llama_7b_lora(cuda_graph_config):
llama_7b_lora_from_dir_test_harness(cuda_graph_config=cuda_graph_config)


@skip_gpu_memory_less_than_40gb
def test_llama_7b_lora_default_modules() -> None:
@test_lora_with_and_without_cuda_graph
def test_llama_7b_lora_default_modules(cuda_graph_config) -> None:
lora_config = LoraConfig(max_lora_rank=64, max_loras=2, max_cpu_loras=2)

hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"

llm = LLM(
model=hf_model_dir,
lora_config=lora_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
llm = LLM(model=hf_model_dir,
lora_config=lora_config,
cuda_graph_config=cuda_graph_config)

hf_lora_dir = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"
try:
@ -359,7 +417,8 @@ def test_llama_7b_lora_default_modules() -> None:

def _check_llama_7b_multi_lora_evict_load_new_adapters(
lora_adapter_count_per_call: list[int], max_loras: int,
max_cpu_loras: int, repeat_calls: int, repeats_per_call: int):
max_cpu_loras: int, repeat_calls: int, repeats_per_call: int,
**llm_kwargs):
# For LoRA checkpoints without finetuned embedding and lm_head, we can either:
# (1) specify lora_target_modules, or
# (2) provide a lora_dir to infer the lora_target_modules.
@ -373,15 +432,14 @@ def _check_llama_7b_multi_lora_evict_load_new_adapters(
repeats_per_call,
LLM,
lora_config=lora_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
**llm_kwargs)


@skip_gpu_memory_less_than_40gb
@skip_ray # https://nvbugs/5682551
@pytest.mark.part3
def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache():
@test_lora_with_and_without_cuda_graph
def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache(cuda_graph_config):
"""Test eviction and re-loading a previously evicted adapter from the LoRA GPU cache, within a single
llm.generate call, that's repeated twice.
""" # noqa: D205
@ -390,12 +448,15 @@ def test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache():
max_loras=1,
max_cpu_loras=2,
repeat_calls=2,
repeats_per_call=3)
repeats_per_call=3,
cuda_graph_config=cuda_graph_config)


@skip_gpu_memory_less_than_40gb
@pytest.mark.part1
def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache():
@test_lora_with_and_without_cuda_graph
def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache(
cuda_graph_config):
"""Test eviction and loading of new adapters in the evicted space, over several llm.generate calls, with LoRA GPU
cache size < LoRA CPU cache size.
""" # noqa: D205
@ -404,25 +465,29 @@ def test_llama_7b_multi_lora_evict_and_load_new_adapters_in_cpu_and_gpu_cache():
max_loras=1,
max_cpu_loras=3,
repeat_calls=1,
repeats_per_call=1)
repeats_per_call=1,
cuda_graph_config=cuda_graph_config)


@skip_gpu_memory_less_than_40gb
@pytest.mark.part0
def test_llama_7b_multi_lora_read_from_cache_after_insert():
@test_lora_with_and_without_cuda_graph
def test_llama_7b_multi_lora_read_from_cache_after_insert(cuda_graph_config):
"""Test that loading and then using the same adapters loaded in cache works."""
_check_llama_7b_multi_lora_evict_load_new_adapters(
lora_adapter_count_per_call=[3],
max_loras=3,
max_cpu_loras=3,
repeat_calls=2,
repeats_per_call=1)
repeats_per_call=1,
cuda_graph_config=cuda_graph_config)


@skip_gpu_memory_less_than_40gb
@pytest.mark.part3
@test_lora_with_and_without_cuda_graph
def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_cache(
):
cuda_graph_config):
"""Test eviction, reloading new adapters and reloading previously evicted adapters from the LoRA CPU cache & GPU
cache over multiple llm.generate call repeated twice (two calls with the same requests):
At the end of the 1st llm.generate call:
@ -439,12 +504,14 @@ def test_llama_7b_multi_lora_evict_and_reload_evicted_adapters_in_cpu_and_gpu_ca
max_loras=2,
max_cpu_loras=2,
repeat_calls=2,
repeats_per_call=1)
repeats_per_call=1,
cuda_graph_config=cuda_graph_config)


@skip_gpu_memory_less_than_40gb
@pytest.mark.part2
def test_llama_7b_peft_cache_config_affects_peft_cache_size():
@test_lora_with_and_without_cuda_graph
def test_llama_7b_peft_cache_config_affects_peft_cache_size(cuda_graph_config):
"""Tests that LLM arg of peft_cache_config affects the peft cache sizes.

NOTE: The caller can't get the actual LoRA cache sizes, so we instead we
@ -464,9 +531,7 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size():
lora_config=lora_config_no_cache_size_values,
peft_cache_config=PeftCacheConfig(
host_cache_size=1), # size in bytes
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
cuda_graph_config=cuda_graph_config)

# Test that too small PeftCacheConfig.device_cache_percent causes failure
with pytest.raises(RuntimeError):
@ -474,15 +539,14 @@ def test_llama_7b_peft_cache_config_affects_peft_cache_size():
LLM,
lora_config=lora_config_no_cache_size_values,
peft_cache_config=PeftCacheConfig(device_cache_percent=0.0000001),
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
cuda_graph_config=cuda_graph_config)


@skip_ray # https://nvbugs/5682551
@skip_gpu_memory_less_than_40gb
@pytest.mark.part1
def test_llama_7b_lora_config_overrides_peft_cache_config():
@test_lora_with_and_without_cuda_graph
def test_llama_7b_lora_config_overrides_peft_cache_config(cuda_graph_config):
"""Tests that cache size args in lora_config LLM arg override the cache size
parameters in peft_cache_config LLM arg.
""" # noqa: D205
@ -496,9 +560,7 @@ def test_llama_7b_lora_config_overrides_peft_cache_config():
peft_cache_config=PeftCacheConfig(
host_cache_size=1, # size in bytes
device_cache_percent=0.0000001),
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
cuda_graph_config=cuda_graph_config)


# TODO smor: currently Nemotron-Super-49B-v1 with LoRA memory consumption is overly high
@ -506,7 +568,8 @@ def test_llama_7b_lora_config_overrides_peft_cache_config():
@pytest.mark.skip(reason="https://nvbugs/5448464")
@skip_gpu_memory_less_than_138gb
@pytest.mark.part1
def test_nemotron_nas_lora() -> None:
@test_lora_with_and_without_cuda_graph
def test_nemotron_nas_lora(cuda_graph_config) -> None:
lora_config = LoraConfig(lora_dir=[
f"{llm_models_root()}/nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1-lora-adapter_r64"
],
@ -518,7 +581,7 @@ def test_nemotron_nas_lora() -> None:
model=
f"{llm_models_root()}/nemotron-nas/Llama-3_3-Nemotron-Super-49B-v1",
lora_config=lora_config,
)
cuda_graph_config=cuda_graph_config)

prompts = [
"Hello, how are you?",
@ -539,7 +602,8 @@ def test_nemotron_nas_lora() -> None:

@skip_gpu_memory_less_than_80gb
@pytest.mark.part0
def test_llama_3_1_8b_fp8_with_bf16_lora() -> None:
@test_lora_with_and_without_cuda_graph
def test_llama_3_1_8b_fp8_with_bf16_lora(cuda_graph_config) -> None:
skip_fp8_pre_ada(use_fp8=True)
model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
lora_dir = f"{llm_models_root()}/lora/llama-3-chinese-8b-instruct-v2-lora"
@ -552,12 +616,9 @@ def test_llama_3_1_8b_fp8_with_bf16_lora() -> None:
max_cpu_loras=2)
lora_req = LoRARequest("lora-chinese", 0, lora_dir)

llm = LLM(
model_dir,
lora_config=lora_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
llm = LLM(model_dir,
lora_config=lora_config,
cuda_graph_config=cuda_graph_config)

try:
output = llm.generate(prompt,
@ -570,7 +631,8 @@ def test_llama_3_1_8b_fp8_with_bf16_lora() -> None:

@skip_gpu_memory_less_than_80gb
@pytest.mark.part2
def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
@test_lora_with_and_without_cuda_graph
def test_bielik_11b_v2_2_instruct_multi_lora(cuda_graph_config) -> None:
model_dir = f"{llm_models_root()}/Bielik-11B-v2.2-Instruct"

target_modules = ['attn_q', 'attn_k', 'attn_v']
@ -600,12 +662,9 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:
max_lora_rank=8,
max_loras=2,
max_cpu_loras=2)
llm = LLM(
model_dir,
lora_config=trtllm_lora_config,
# Disable CUDA graph
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
cuda_graph_config=None)
llm = LLM(model_dir,
lora_config=trtllm_lora_config,
cuda_graph_config=cuda_graph_config)

prompts = [
"Kim był Mikołaj Kopernik i z czego zasłynął?",
@ -624,7 +683,8 @@ def test_bielik_11b_v2_2_instruct_multi_lora() -> None:


@pytest.mark.part2
def test_gemma3_1b_instruct_multi_lora() -> None:
@test_lora_with_and_without_cuda_graph
def test_gemma3_1b_instruct_multi_lora(cuda_graph_config) -> None:
model_dir = f"{llm_models_root()}/gemma/gemma-3-1b-it"

target_modules = ['attn_q', 'attn_k', 'attn_v']
@ -662,7 +722,8 @@ def test_gemma3_1b_instruct_multi_lora() -> None:
)
llm = LLM(model_dir,
lora_config=trtllm_lora_config,
kv_cache_config=kv_cache_config)
kv_cache_config=kv_cache_config,
cuda_graph_config=cuda_graph_config)

prompts = [
"Is it ok to fill diesel in a petrol car?",
@ -745,7 +806,8 @@ def test_nemo_lora_unsupported_modules_validation(tmp_path):

@force_ampere
@pytest.mark.part1
def test_gqa_nemo_lora(tmp_path):
@test_lora_with_and_without_cuda_graph
def test_gqa_nemo_lora(tmp_path, cuda_graph_config):
"""
Test NeMo-format LoRA checkpoint loading and GQA support in TinyLlama.

@ -790,6 +852,7 @@ def test_gqa_nemo_lora(tmp_path):
model=model_path,
lora_config=lora_config,
kv_cache_config=global_kvcache_config,
cuda_graph_config=cuda_graph_config,
)

try:

@ -32,3 +32,33 @@ def test_generate_api_docs_as_docstring():
doc = generate_api_docs_as_docstring(LlmArgs)
assert ":tag:`beta`" in doc, "the label is not generated"
print(doc)


class DelayedAssert:

def __init__(self, store_stack: bool = False):
self.assertions = []
self.store_stack = store_stack

def add(self, result: bool, msg: str):
import traceback
self.assertions.append(
(bool(result), str(msg), traceback.format_stack()))

def get_msg(self):
ret = ['Some assertions failed:']
for result, msg, stack in self.assertions:
ret.append('\n'.join([
f'Assert result: {result}', msg,
''.join(stack) if self.store_stack else ''
]))
ret = '\n-----------------------------------------\n'.join(ret)
ret = 'Some assertions failed:\n' + ret
return ret

def clear(self):
self.assertions.clear()

def assert_all(self):
assert all(ret[0] for ret in self.assertions), self.get_msg()
self.clear()

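A short usage sketch of DelayedAssert: results are recorded instead of raised immediately, and assert_all raises a single AssertionError carrying every recorded message.

    asserter = DelayedAssert()
    asserter.add(1 + 1 == 2, "arithmetic check")
    asserter.add(isinstance("x", str), "type check")
    asserter.assert_all()                      # passes and clears the recorded entries

    asserter.add(False, "this failure is reported")
    asserter.add(True, "this one passed but is still listed")
    try:
        asserter.assert_all()                  # raises with the combined message
    except AssertionError as err:
        assert "Some assertions failed" in str(err)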