Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[feat] Adds optional module cache for TRT-LLM Gen Gemm interfaces (#5743)
Signed-off-by: David Clark <215764518+davidclark-nv@users.noreply.github.com>
Co-authored-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
parent 1191555cce
commit a1235ee978
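What the change does, in one place: each TRT-LLM Gen interface gains a ModuleCache (an unordered_map from a string key to a (CUmodule, CUfunction) tuple), and run() takes an optional reference to one. When a cache is supplied, the cubin is loaded with cuModuleLoadData only on a cache miss, keyed by the current CUDA context id plus the kernel function name; without a cache, the previous load/launch/unload path is kept. Below is a minimal, self-contained sketch of that pattern, not the patch itself; loadOrGetFunction and its parameters are illustrative names.

// Sketch of the module-cache lookup pattern introduced by this commit.
// Assumes CUDA 12+ for cuCtxGetId; error checking omitted for brevity.
#include <cuda.h>

#include <functional>
#include <optional>
#include <string>
#include <tuple>
#include <unordered_map>

using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;

// Returns the CUfunction for (current context, kernel name), loading the cubin only on a
// cache miss. With no cache, the caller is expected to cuModuleUnload() after the launch,
// exactly as the patched run() functions do.
inline CUfunction loadOrGetFunction(std::optional<std::reference_wrapper<ModuleCache>> cache,
    void const* cubinData, char const* functionName, CUmodule& moduleOut)
{
    CUfunction func{};
    if (cache.has_value())
    {
        // Modules are tied to a CUDA context, so the context id is part of the key.
        CUcontext ctx;
        unsigned long long ctxId;
        cuCtxGetCurrent(&ctx);
        cuCtxGetId(ctx, &ctxId);
        // The raw bytes of the id avoid a custom hash; the fixed width keeps keys unambiguous.
        std::string key(reinterpret_cast<char*>(&ctxId), sizeof(ctxId));
        key += functionName;

        ModuleCache& cacheRef = cache.value().get();
        auto it = cacheRef.find(key);
        if (it != cacheRef.end())
        {
            // Cache hit: reuse the previously loaded module and function.
            std::tie(moduleOut, func) = it->second;
            return func;
        }
        // Cache miss: load once and remember the handles for later launches.
        cuModuleLoadData(&moduleOut, cubinData);
        cuModuleGetFunction(&func, moduleOut, functionName);
        cacheRef.emplace(key, std::make_tuple(moduleOut, func));
        return func;
    }
    // No cache: load fresh; the caller should unload moduleOut after the launch.
    cuModuleLoadData(&moduleOut, cubinData);
    cuModuleGetFunction(&func, moduleOut, functionName);
    return func;
}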
@@ -30,6 +30,8 @@ using namespace batchedGemm::batchedGemm;
 using namespace batchedGemm::gemm;
 using namespace batchedGemm::trtllm::gen;
 
+static BatchedGemmInterface::ModuleCache globalTrtllmGenBatchedGemmModuleCache;
+
 std::vector<int64_t> prioritizePredefinedConfigs(int m, int n, int k, std::vector<int64_t> const& sortedIndices,
     batchedGemm::batchedGemm::BatchedGemmConfig const* configs)
 {
@@ -295,7 +297,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
     // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
     bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
 
-    auto const err = bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount);
+    auto const err = bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
+        globalTrtllmGenBatchedGemmModuleCache);
 
     TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
 }
@@ -17,6 +17,7 @@
 #pragma once
 
 #include <numeric>
+#include <optional>
 
 #include "BatchedGemmOptions.h"
 #include "KernelParams.h"
@@ -392,12 +393,14 @@ struct BatchedGemmData
 class BatchedGemmInterface
 {
 public:
+    using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;
+
     BatchedGemmInterface() {}
 
     // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers.
     // Provided config must be validated with isValidConfig before the call.
     int32_t run(BatchedGemmConfig const& config, void* workspace, BatchedGemmData const& options, void* cudaStream,
-        int32_t multiProcessorCount);
+        int32_t multiProcessorCount, std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt);
 
     // Initializes the buffers before the world sync. Must be called before run.
     int32_t runInitBeforeWorldSync(
@@ -579,9 +582,9 @@ std::vector<size_t> BatchedGemmInterface::getWorkspaceSizesInBytes(
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace,
-    BatchedGemmData const& batchedGemmData, void* cudaStream, int32_t /* multiProcessorCount */)
+    BatchedGemmData const& batchedGemmData, void* cudaStream, int32_t /* multiProcessorCount */,
+    std::optional<std::reference_wrapper<ModuleCache>> moduleCache)
 {
     // Get options from config and data.
     auto options = getOptionsFromConfigAndData(config, batchedGemmData);
@@ -652,8 +655,42 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
 #ifdef TLLM_GEN_EXPORT_INTERFACE
     CUmodule cuModule;
     CUfunction cuFunction;
-    cuModuleLoadData(&cuModule, config.mData);
-    cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    if (moduleCache.has_value())
+    {
+        ModuleCache& moduleCacheRef = moduleCache.value().get();
+
+        // Modules are associated with a specific context so include the ctxId in the key
+        CUcontext ctx;
+        unsigned long long ctxId;
+        cuCtxGetCurrent(&ctx);
+        cuCtxGetId(ctx, &ctxId);
+
+        // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
+        // representation.
+        std::string const ctxName
+            = std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
+        std::string const funcName = std::string(config.mFunctionName);
+        // As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
+        auto const moduleKey = ctxName + funcName;
+        auto module = moduleCacheRef.find(moduleKey);
+
+        // Check if module exists in cache. Otherwise, load it
+        if (module != moduleCacheRef.end())
+        {
+            cuFunction = std::get<1>(module->second);
+        }
+        else
+        {
+            cuModuleLoadData(&cuModule, config.mData);
+            cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+            moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
+        }
+    }
+    else
+    {
+        cuModuleLoadData(&cuModule, config.mData);
+        cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    }
 
     // Prepare the grid/block.
     dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)};
@@ -673,6 +710,11 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
     {
         return -1;
     }
+    // If a module cache has not been given, unload the module to avoid overflow
+    if (!moduleCache.has_value())
+    {
+        cuModuleUnload(cuModule);
+    }
 #else
     config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
 #endif
@@ -30,6 +30,8 @@ namespace kernels
 namespace tg = gemm::trtllm::gen;
 using namespace gemm::gemm;
 
+static GemmInterface::ModuleCache globalTrtllmGenGemmModuleCache;
+
 TrtllmGenGemmRunner::TrtllmGenGemmRunner(TrtllmGenGemmRunnerOptions const& options_)
     : mOptions(options_)
 {
@@ -111,7 +113,8 @@ void TrtllmGenGemmRunner::run(int32_t m, int32_t n, int32_t k, void const* a, fl
     // FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere
     gemm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
 
-    auto const err = gemm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount);
+    auto const err = gemm.run(
+        config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount, globalTrtllmGenGemmModuleCache);
 
     TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
 }
@@ -222,12 +222,15 @@ struct GemmData
 class GemmInterface
 {
 public:
+    using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;
+
     GemmInterface() {}
 
     // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers.
     // Provided config must be validated with isValidConfig before the call.
     int32_t run(GemmConfig const& config, void* workspace, GemmData const& options, void* cudaStream,
-        int32_t multiProcessorCount) const;
+        int32_t multiProcessorCount,
+        std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt) const;
 
     // Initializes the buffers before the world sync. Must be called before run.
     int32_t runInitBeforeWorldSync(GemmConfig const& config, GemmData const& data, void* cudaStream) const;
@@ -384,7 +387,7 @@ bool GemmInterface::isValidConfig(GemmConfig const& config, GemmData const& data
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData const& data, void* cudaStream,
-    int32_t multiProcessorCount) const
+    int32_t multiProcessorCount, std::optional<std::reference_wrapper<ModuleCache>> moduleCache) const
 {
     // Get options from config and data.
     auto options = getOptionsFromConfigAndData(config, data);
@@ -439,8 +442,42 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
 #ifdef TLLM_GEN_EXPORT_INTERFACE
     CUmodule cuModule;
     CUfunction cuFunction;
-    cuModuleLoadData(&cuModule, config.mData);
-    cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    if (moduleCache.has_value())
+    {
+        ModuleCache& moduleCacheRef = moduleCache.value().get();
+
+        // Modules are associated with a specific context so include the ctxId in the key
+        CUcontext ctx;
+        unsigned long long ctxId;
+        cuCtxGetCurrent(&ctx);
+        cuCtxGetId(ctx, &ctxId);
+
+        // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
+        // representation.
+        std::string const ctxName
+            = std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
+        std::string const funcName = std::string(config.mFunctionName);
+        // As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
+        auto const moduleKey = ctxName + funcName;
+        auto module = moduleCacheRef.find(moduleKey);
+
+        // Check if module exists in cache. Otherwise, load it
+        if (module != moduleCacheRef.end())
+        {
+            cuFunction = std::get<1>(module->second);
+        }
+        else
+        {
+            cuModuleLoadData(&cuModule, config.mData);
+            cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+            moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
+        }
+    }
+    else
+    {
+        cuModuleLoadData(&cuModule, config.mData);
+        cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    }
 
     // Prepare the grid/block.
     dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)};
@@ -460,6 +497,11 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
     {
         return -1;
     }
+    // If a module cache has not been given, unload the module to avoid leaking
+    if (!moduleCache.has_value())
+    {
+        cuModuleUnload(cuModule);
+    }
 #else
     config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
 #endif
@@ -27,6 +27,8 @@ namespace tensorrt_llm
 namespace kernels
 {
 
+static gemmGatedAct::GemmGatedActInterface::ModuleCache globalTrtllmGenGemmGatedActModuleCache;
+
 TrtllmGenGemmGatedActRunner::TrtllmGenGemmGatedActRunner(TrtllmGenGemmGatedActRunnerOptions const& options_)
     : mOptions(options_)
 {
@@ -104,7 +106,8 @@ void TrtllmGenGemmGatedActRunner::run(int32_t m, int32_t n, int32_t k, void cons
     // FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere
     gemm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
 
-    auto const err = gemm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount);
+    auto const err = gemm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
+        globalTrtllmGenGemmGatedActModuleCache);
 
     TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
 }
@@ -183,12 +183,15 @@ struct GemmGatedActData
 class GemmGatedActInterface
 {
 public:
+    using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;
+
     GemmGatedActInterface() {}
 
     // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers.
     // Provided config must be validated with isValidConfig before the call.
     int32_t run(GemmGatedActConfig const& config, void* workspace, GemmGatedActData const& data, void* cudaStream,
-        int32_t multiProcessorCount) const;
+        int32_t multiProcessorCount,
+        std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt) const;
 
     // Initializes the buffers before the world sync. Must be called before run.
     int32_t runInitBeforeWorldSync(
@@ -340,7 +343,7 @@ bool GemmGatedActInterface::isValidConfig(GemmGatedActConfig const& config, Gemm
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* workspace, GemmGatedActData const& data,
-    void* cudaStream, int32_t multiProcessorCount) const
+    void* cudaStream, int32_t multiProcessorCount, std::optional<std::reference_wrapper<ModuleCache>> moduleCache) const
 {
     // Get options from config and data.
     auto options = getOptionsFromConfigAndData(config, data);
@@ -392,8 +395,42 @@ int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* works
 #ifdef TLLM_GEN_EXPORT_INTERFACE
     CUmodule cuModule;
     CUfunction cuFunction;
-    cuModuleLoadData(&cuModule, config.mData);
-    cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    if (moduleCache.has_value())
+    {
+        ModuleCache& moduleCacheRef = moduleCache.value().get();
+
+        // Modules are associated with a specific context so include the ctxId in the key
+        CUcontext ctx;
+        unsigned long long ctxId;
+        cuCtxGetCurrent(&ctx);
+        cuCtxGetId(ctx, &ctxId);
+
+        // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
+        // representation.
+        std::string const ctxName
+            = std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
+        std::string const funcName = std::string(config.mFunctionName);
+        // As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
+        auto const moduleKey = ctxName + funcName;
+        auto module = moduleCacheRef.find(moduleKey);
+
+        // Check if module exists in cache. Otherwise, load it
+        if (module != moduleCacheRef.end())
+        {
+            cuFunction = std::get<1>(module->second);
+        }
+        else
+        {
+            cuModuleLoadData(&cuModule, config.mData);
+            cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+            moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
+        }
+    }
+    else
+    {
+        cuModuleLoadData(&cuModule, config.mData);
+        cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
+    }
 
     // Prepare the grid/block.
     dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)};
@@ -413,6 +450,11 @@ int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* works
     {
         return -1;
     }
+    // If a module cache has not been given, unload the module to avoid leaking
+    if (!moduleCache.has_value())
+    {
+        cuModuleUnload(cuModule);
+    }
 #else
     config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
 #endif
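Caller-side usage follows the runner changes above: keep a long-lived ModuleCache and pass it by reference, or omit the argument to keep the previous per-call load/unload behaviour. A hedged sketch against the patched GemmInterface header follows; launchGemmOnce and its parameter list are illustrative, not part of the patch.

#include <functional> // std::ref

// Hypothetical wrapper; GemmInterface, GemmConfig and GemmData come from the patched header.
int32_t launchGemmOnce(GemmInterface const& gemm, GemmConfig const& config, void* workspace,
    GemmData const& data, void* stream, int32_t multiProcessorCount)
{
    // One cache per process: each cubin is loaded at most once per (CUDA context, kernel name).
    static GemmInterface::ModuleCache sModuleCache;
    return gemm.run(config, workspace, data, stream, multiProcessorCount, std::ref(sModuleCache));
    // Dropping the last argument restores the old behaviour: load the cubin, launch, unload.
}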