[TRTLLM-7263][fix] Prevent recreation of cublas handles in lora_grouped_gemm every call (#6968)

Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
amitz-nv 2025-08-19 15:39:56 +03:00 committed by GitHub
parent 19667304b5
commit a54c53652b


@@ -130,10 +130,14 @@ std::vector<th::Tensor> lora_grouped_gemm(th::Tensor const& input, th::Tensor co
         }
     }
-    auto cublasHandle = getCublasHandle();
-    auto cublasLtHandle = getCublasLtHandle();
-    auto cublasWraper
-        = std::make_shared<tensorrt_llm::common::CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
+    thread_local std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> cublasWrapper;
+    if (cublasWrapper == nullptr)
+    {
+        auto cublasHandle = getCublasHandle();
+        auto cublasLtHandle = getCublasLtHandle();
+        cublasWrapper
+            = std::make_shared<tensorrt_llm::common::CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
+    }
     int const inHiddenSize = input.sizes()[input.sizes().size() - 1];
@@ -151,7 +155,7 @@ std::vector<th::Tensor> lora_grouped_gemm(th::Tensor const& input, th::Tensor co
     }
     auto mLoraImpl = std::make_shared<tensorrt_llm::kernels::LoraImpl>(
-        inHiddenSize, outHiddenSizes, transA, transB, numLoraModules, loraRuntimeDataType, max_low_rank, cublasWraper);
+        inHiddenSize, outHiddenSizes, transA, transB, numLoraModules, loraRuntimeDataType, max_low_rank, cublasWrapper);
     // TODO (dafrimi): use Profiler to find the best tactic as used in lora_plugin
     mLoraImpl->setBestTactic(std::nullopt);
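
For reference, a minimal self-contained sketch of the thread_local lazy-initialization pattern this change applies. The names WrapperLike, acquireHandle, and getCachedWrapper are hypothetical stand-ins for illustration only; they are not part of the TensorRT-LLM code.

#include <memory>

// Hypothetical stand-in for tensorrt_llm::common::CublasMMWrapper: a resource
// that is expensive to construct and safe to reuse within a thread.
struct WrapperLike
{
    explicit WrapperLike(int handle) : handle(handle) {}
    int handle;
};

// Hypothetical stand-in for getCublasHandle()/getCublasLtHandle().
int acquireHandle()
{
    return 42;
}

WrapperLike& getCachedWrapper()
{
    // One instance per thread: constructed lazily on the first call from a
    // given thread, then reused by every later call on that thread, so the
    // underlying handles are not recreated on each invocation.
    thread_local std::shared_ptr<WrapperLike> wrapper;
    if (wrapper == nullptr)
    {
        wrapper = std::make_shared<WrapperLike>(acquireHandle());
    }
    return *wrapper;
}

int main()
{
    // Repeated calls on the same thread return the same cached instance.
    WrapperLike& first = getCachedWrapper();
    WrapperLike& second = getCachedWrapper();
    return (&first == &second) ? 0 : 1;
}

Because the cache is thread_local, each calling thread gets its own wrapper and no locking is needed; the trade-off is one wrapper per thread rather than one per process.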