[TRTLLM-7263][fix] Prevent recreation of cublas handles in lora_grouped_gemm every call (#6968)
Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
parent 19667304b5
commit a54c53652b
@@ -130,10 +130,14 @@ std::vector<th::Tensor> lora_grouped_gemm(th::Tensor const& input, th::Tensor co
         }
     }
-    auto cublasHandle = getCublasHandle();
-    auto cublasLtHandle = getCublasLtHandle();
-    auto cublasWraper
-        = std::make_shared<tensorrt_llm::common::CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
+    thread_local std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> cublasWrapper;
+    if (cublasWrapper == nullptr)
+    {
+        auto cublasHandle = getCublasHandle();
+        auto cublasLtHandle = getCublasLtHandle();
+        cublasWrapper
+            = std::make_shared<tensorrt_llm::common::CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
+    }
 
     int const inHiddenSize = input.sizes()[input.sizes().size() - 1];
 
@@ -151,7 +155,7 @@ std::vector<th::Tensor> lora_grouped_gemm(th::Tensor const& input, th::Tensor co
     }
 
     auto mLoraImpl = std::make_shared<tensorrt_llm::kernels::LoraImpl>(
-        inHiddenSize, outHiddenSizes, transA, transB, numLoraModules, loraRuntimeDataType, max_low_rank, cublasWraper);
+        inHiddenSize, outHiddenSizes, transA, transB, numLoraModules, loraRuntimeDataType, max_low_rank, cublasWrapper);
 
     // TODO (dafrimi): use Profiler to find the best tactic as used in lora_plugin
     mLoraImpl->setBestTactic(std::nullopt);
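
The core of the change is moving the CublasMMWrapper into a thread_local that is constructed lazily on first use, so repeated calls into lora_grouped_gemm on the same thread reuse the cached cuBLAS/cuBLASLt handles instead of rebuilding them on every call. Below is a minimal, self-contained C++ sketch of that lazy thread_local initialization pattern; ExpensiveWrapper, makeWrapper-style names, and getThreadLocalWrapper are illustrative stand-ins, not identifiers from the TensorRT-LLM codebase.

// Minimal sketch of lazy, per-thread initialization (illustrative names only).
#include <iostream>
#include <memory>
#include <thread>

struct ExpensiveWrapper
{
    ExpensiveWrapper() { std::cout << "constructed once per thread\n"; }
    void run() const { std::cout << "reusing cached wrapper\n"; }
};

ExpensiveWrapper& getThreadLocalWrapper()
{
    // thread_local: each thread constructs the wrapper at most once, on first
    // use, instead of rebuilding it on every call (the waste this commit
    // removes for the cuBLAS/cuBLASLt wrapper in lora_grouped_gemm).
    thread_local std::shared_ptr<ExpensiveWrapper> wrapper;
    if (wrapper == nullptr)
    {
        wrapper = std::make_shared<ExpensiveWrapper>();
    }
    return *wrapper;
}

int main()
{
    // Repeated calls on the same thread construct the wrapper only once.
    getThreadLocalWrapper().run();
    getThreadLocalWrapper().run();

    // A second thread lazily constructs its own instance.
    std::thread t([] { getThreadLocalWrapper().run(); });
    t.join();
    return 0;
}

Using thread_local rather than a single function-level static gives each calling thread its own cached wrapper, so the handle is never shared across threads without synchronization while still being created only once per thread.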