[feat] Adds optional module cache for TRT-LLM Gen Gemm interfaces (#5743)

Signed-off-by: David Clark <215764518+davidclark-nv@users.noreply.github.com>
Co-authored-by: Nikita Korobov <14355239+nekorobov@users.noreply.github.com>
This commit is contained in:
davidclark-nv 2025-07-07 13:34:55 -07:00 committed by GitHub
parent 1191555cce
commit a1235ee978
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 151 additions and 16 deletions

View File

@@ -30,6 +30,8 @@ using namespace batchedGemm::batchedGemm;
using namespace batchedGemm::gemm; using namespace batchedGemm::gemm;
using namespace batchedGemm::trtllm::gen; using namespace batchedGemm::trtllm::gen;
static BatchedGemmInterface::ModuleCache globalTrtllmGenBatchedGemmModuleCache;
std::vector<int64_t> prioritizePredefinedConfigs(int m, int n, int k, std::vector<int64_t> const& sortedIndices, std::vector<int64_t> prioritizePredefinedConfigs(int m, int n, int k, std::vector<int64_t> const& sortedIndices,
batchedGemm::batchedGemm::BatchedGemmConfig const* configs) batchedGemm::batchedGemm::BatchedGemmConfig const* configs)
{ {
@@ -295,7 +297,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
// FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere
bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream)); bmm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
auto const err = bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount); auto const err = bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
globalTrtllmGenBatchedGemmModuleCache);
TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!"); TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
} }

View File

@@ -17,6 +17,7 @@
#pragma once #pragma once
#include <numeric> #include <numeric>
#include <optional>
#include "BatchedGemmOptions.h" #include "BatchedGemmOptions.h"
#include "KernelParams.h" #include "KernelParams.h"
@@ -392,12 +393,14 @@ struct BatchedGemmData
class BatchedGemmInterface class BatchedGemmInterface
{ {
public: public:
using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;
BatchedGemmInterface() {} BatchedGemmInterface() {}
// Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers.
// Provided config must be validated with isValidConfig before the call. // Provided config must be validated with isValidConfig before the call.
int32_t run(BatchedGemmConfig const& config, void* workspace, BatchedGemmData const& options, void* cudaStream, int32_t run(BatchedGemmConfig const& config, void* workspace, BatchedGemmData const& options, void* cudaStream,
int32_t multiProcessorCount); int32_t multiProcessorCount, std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt);
// Initializes the buffers before the world sync. Must be called before run. // Initializes the buffers before the world sync. Must be called before run.
int32_t runInitBeforeWorldSync( int32_t runInitBeforeWorldSync(
@@ -579,9 +582,9 @@ std::vector<size_t> BatchedGemmInterface::getWorkspaceSizesInBytes(
} }
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace, int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace,
BatchedGemmData const& batchedGemmData, void* cudaStream, int32_t /* multiProcessorCount */) BatchedGemmData const& batchedGemmData, void* cudaStream, int32_t /* multiProcessorCount */,
std::optional<std::reference_wrapper<ModuleCache>> moduleCache)
{ {
// Get options from config and data. // Get options from config and data.
auto options = getOptionsFromConfigAndData(config, batchedGemmData); auto options = getOptionsFromConfigAndData(config, batchedGemmData);
@@ -652,8 +655,42 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
#ifdef TLLM_GEN_EXPORT_INTERFACE #ifdef TLLM_GEN_EXPORT_INTERFACE
CUmodule cuModule; CUmodule cuModule;
CUfunction cuFunction; CUfunction cuFunction;
cuModuleLoadData(&cuModule, config.mData); if (moduleCache.has_value())
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); {
ModuleCache& moduleCacheRef = moduleCache.value().get();
// Modules are associated with a specific context so include the ctxId in the key
CUcontext ctx;
unsigned long long ctxId;
cuCtxGetCurrent(&ctx);
cuCtxGetId(ctx, &ctxId);
// Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
// representation.
std::string const ctxName
= std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
std::string const funcName = std::string(config.mFunctionName);
// As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
auto const moduleKey = ctxName + funcName;
auto module = moduleCacheRef.find(moduleKey);
// Check if module exists in cache. Otherwise, load it
if (module != moduleCacheRef.end())
{
cuFunction = std::get<1>(module->second);
}
else
{
cuModuleLoadData(&cuModule, config.mData);
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
}
}
else
{
cuModuleLoadData(&cuModule, config.mData);
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
}
// Prepare the grid/block. // Prepare the grid/block.
dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)}; dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)};
@@ -673,6 +710,11 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
{ {
return -1; return -1;
} }
// If a module cache has not been given, unload the module to avoid leaking
if (!moduleCache.has_value())
{
cuModuleUnload(cuModule);
}
#else #else
config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid); config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
#endif #endif

View File

@@ -30,6 +30,8 @@ namespace kernels
namespace tg = gemm::trtllm::gen; namespace tg = gemm::trtllm::gen;
using namespace gemm::gemm; using namespace gemm::gemm;
static GemmInterface::ModuleCache globalTrtllmGenGemmModuleCache;
TrtllmGenGemmRunner::TrtllmGenGemmRunner(TrtllmGenGemmRunnerOptions const& options_) TrtllmGenGemmRunner::TrtllmGenGemmRunner(TrtllmGenGemmRunnerOptions const& options_)
: mOptions(options_) : mOptions(options_)
{ {
@@ -111,7 +113,8 @@ void TrtllmGenGemmRunner::run(int32_t m, int32_t n, int32_t k, void const* a, fl
// FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere // FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere
gemm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream)); gemm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
auto const err = gemm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount); auto const err = gemm.run(
config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount, globalTrtllmGenGemmModuleCache);
TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!"); TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
} }

View File

@@ -222,12 +222,15 @@ struct GemmData
class GemmInterface class GemmInterface
{ {
public: public:
using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;
GemmInterface() {} GemmInterface() {}
// Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers.
// Provided config must be validated with isValidConfig before the call. // Provided config must be validated with isValidConfig before the call.
int32_t run(GemmConfig const& config, void* workspace, GemmData const& options, void* cudaStream, int32_t run(GemmConfig const& config, void* workspace, GemmData const& options, void* cudaStream,
int32_t multiProcessorCount) const; int32_t multiProcessorCount,
std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt) const;
// Initializes the buffers before the world sync. Must be called before run. // Initializes the buffers before the world sync. Must be called before run.
int32_t runInitBeforeWorldSync(GemmConfig const& config, GemmData const& data, void* cudaStream) const; int32_t runInitBeforeWorldSync(GemmConfig const& config, GemmData const& data, void* cudaStream) const;
@@ -384,7 +387,7 @@ bool GemmInterface::isValidConfig(GemmConfig const& config, GemmData const& data
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData const& data, void* cudaStream, int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData const& data, void* cudaStream,
int32_t multiProcessorCount) const int32_t multiProcessorCount, std::optional<std::reference_wrapper<ModuleCache>> moduleCache) const
{ {
// Get options from config and data. // Get options from config and data.
auto options = getOptionsFromConfigAndData(config, data); auto options = getOptionsFromConfigAndData(config, data);
@@ -439,8 +442,42 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
#ifdef TLLM_GEN_EXPORT_INTERFACE #ifdef TLLM_GEN_EXPORT_INTERFACE
CUmodule cuModule; CUmodule cuModule;
CUfunction cuFunction; CUfunction cuFunction;
cuModuleLoadData(&cuModule, config.mData); if (moduleCache.has_value())
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); {
ModuleCache& moduleCacheRef = moduleCache.value().get();
// Modules are associated with a specific context so include the ctxId in the key
CUcontext ctx;
unsigned long long ctxId;
cuCtxGetCurrent(&ctx);
cuCtxGetId(ctx, &ctxId);
// Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
// representation.
std::string const ctxName
= std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
std::string const funcName = std::string(config.mFunctionName);
// As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
auto const moduleKey = ctxName + funcName;
auto module = moduleCacheRef.find(moduleKey);
// Check if module exists in cache. Otherwise, load it
if (module != moduleCacheRef.end())
{
cuFunction = std::get<1>(module->second);
}
else
{
cuModuleLoadData(&cuModule, config.mData);
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
}
}
else
{
cuModuleLoadData(&cuModule, config.mData);
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
}
// Prepare the grid/block. // Prepare the grid/block.
dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)}; dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)};
@@ -460,6 +497,11 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c
{ {
return -1; return -1;
} }
// If a module cache has not been given, unload the module to avoid leaking
if (!moduleCache.has_value())
{
cuModuleUnload(cuModule);
}
#else #else
config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid); config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
#endif #endif

View File

@@ -27,6 +27,8 @@ namespace tensorrt_llm
namespace kernels namespace kernels
{ {
static gemmGatedAct::GemmGatedActInterface::ModuleCache globalTrtllmGenGemmGatedActModuleCache;
TrtllmGenGemmGatedActRunner::TrtllmGenGemmGatedActRunner(TrtllmGenGemmGatedActRunnerOptions const& options_) TrtllmGenGemmGatedActRunner::TrtllmGenGemmGatedActRunner(TrtllmGenGemmGatedActRunnerOptions const& options_)
: mOptions(options_) : mOptions(options_)
{ {
@@ -104,7 +106,8 @@ void TrtllmGenGemmGatedActRunner::run(int32_t m, int32_t n, int32_t k, void cons
// FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere // FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere
gemm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream)); gemm.runInitBeforeWorldSync(config, gemmData, static_cast<void*>(stream));
auto const err = gemm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount); auto const err = gemm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
globalTrtllmGenGemmGatedActModuleCache);
TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!"); TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!");
} }

View File

@@ -183,12 +183,15 @@ struct GemmGatedActData
class GemmGatedActInterface class GemmGatedActInterface
{ {
public: public:
using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;
GemmGatedActInterface() {} GemmGatedActInterface() {}
// Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers.
// Provided config must be validated with isValidConfig before the call. // Provided config must be validated with isValidConfig before the call.
int32_t run(GemmGatedActConfig const& config, void* workspace, GemmGatedActData const& data, void* cudaStream, int32_t run(GemmGatedActConfig const& config, void* workspace, GemmGatedActData const& data, void* cudaStream,
int32_t multiProcessorCount) const; int32_t multiProcessorCount,
std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt) const;
// Initializes the buffers before the world sync. Must be called before run. // Initializes the buffers before the world sync. Must be called before run.
int32_t runInitBeforeWorldSync( int32_t runInitBeforeWorldSync(
@@ -340,7 +343,7 @@ bool GemmGatedActInterface::isValidConfig(GemmGatedActConfig const& config, Gemm
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* workspace, GemmGatedActData const& data, int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* workspace, GemmGatedActData const& data,
void* cudaStream, int32_t multiProcessorCount) const void* cudaStream, int32_t multiProcessorCount, std::optional<std::reference_wrapper<ModuleCache>> moduleCache) const
{ {
// Get options from config and data. // Get options from config and data.
auto options = getOptionsFromConfigAndData(config, data); auto options = getOptionsFromConfigAndData(config, data);
@@ -392,8 +395,42 @@ int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* works
#ifdef TLLM_GEN_EXPORT_INTERFACE #ifdef TLLM_GEN_EXPORT_INTERFACE
CUmodule cuModule; CUmodule cuModule;
CUfunction cuFunction; CUfunction cuFunction;
cuModuleLoadData(&cuModule, config.mData); if (moduleCache.has_value())
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); {
ModuleCache& moduleCacheRef = moduleCache.value().get();
// Modules are associated with a specific context so include the ctxId in the key
CUcontext ctx;
unsigned long long ctxId;
cuCtxGetCurrent(&ctx);
cuCtxGetId(ctx, &ctxId);
// Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
// representation.
std::string const ctxName
= std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
std::string const funcName = std::string(config.mFunctionName);
// As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
auto const moduleKey = ctxName + funcName;
auto module = moduleCacheRef.find(moduleKey);
// Check if module exists in cache. Otherwise, load it
if (module != moduleCacheRef.end())
{
cuFunction = std::get<1>(module->second);
}
else
{
cuModuleLoadData(&cuModule, config.mData);
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
}
}
else
{
cuModuleLoadData(&cuModule, config.mData);
cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName);
}
// Prepare the grid/block. // Prepare the grid/block.
dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)}; dim3 block3{static_cast<uint32_t>(config.mNumThreadsPerCTA), static_cast<uint32_t>(1), static_cast<uint32_t>(1)};
@@ -413,6 +450,11 @@ int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* works
{ {
return -1; return -1;
} }
// If a module cache has not been given, unload the module to avoid leaking
if (!moduleCache.has_value())
{
cuModuleUnload(cuModule);
}
#else #else
config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid); config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid);
#endif #endif