diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp b/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp index 1bceceae80..41aba403f4 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp @@ -30,6 +30,8 @@ using namespace batchedGemm::batchedGemm; using namespace batchedGemm::gemm; using namespace batchedGemm::trtllm::gen; +static BatchedGemmInterface::ModuleCache globalTrtllmGenBatchedGemmModuleCache; + std::vector prioritizePredefinedConfigs(int m, int n, int k, std::vector const& sortedIndices, batchedGemm::batchedGemm::BatchedGemmConfig const* configs) { @@ -295,7 +297,8 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere bmm.runInitBeforeWorldSync(config, gemmData, static_cast(stream)); - auto const err = bmm.run(config, workspace, gemmData, static_cast(stream), multiProcessorCount); + auto const err = bmm.run(config, workspace, gemmData, static_cast(stream), multiProcessorCount, + globalTrtllmGenBatchedGemmModuleCache); TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!"); } diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h index 4251a333e0..92e0a4cf00 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include "BatchedGemmOptions.h" #include "KernelParams.h" @@ -392,12 +393,14 @@ struct BatchedGemmData class BatchedGemmInterface { public: + using ModuleCache = std::unordered_map>; + BatchedGemmInterface() {} // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. // Provided config must be validated with isValidConfig before the call. int32_t run(BatchedGemmConfig const& config, void* workspace, BatchedGemmData const& options, void* cudaStream, - int32_t multiProcessorCount); + int32_t multiProcessorCount, std::optional> moduleCache = std::nullopt); // Initializes the buffers before the world sync. Must be called before run. int32_t runInitBeforeWorldSync( @@ -579,9 +582,9 @@ std::vector BatchedGemmInterface::getWorkspaceSizesInBytes( } //////////////////////////////////////////////////////////////////////////////////////////////////// - int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace, - BatchedGemmData const& batchedGemmData, void* cudaStream, int32_t /* multiProcessorCount */) + BatchedGemmData const& batchedGemmData, void* cudaStream, int32_t /* multiProcessorCount */, + std::optional> moduleCache) { // Get options from config and data. auto options = getOptionsFromConfigAndData(config, batchedGemmData); @@ -652,8 +655,42 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa #ifdef TLLM_GEN_EXPORT_INTERFACE CUmodule cuModule; CUfunction cuFunction; - cuModuleLoadData(&cuModule, config.mData); - cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + if (moduleCache.has_value()) + { + ModuleCache& moduleCacheRef = moduleCache.value().get(); + + // Modules are associated with a specific context so include the ctxId in the key + CUcontext ctx; + unsigned long long ctxId; + cuCtxGetCurrent(&ctx); + cuCtxGetId(ctx, &ctxId); + + // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal + // representation. + std::string const ctxName + = std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); + std::string const funcName = std::string(config.mFunctionName); + // As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision + auto const moduleKey = ctxName + funcName; + auto module = moduleCacheRef.find(moduleKey); + + // Check if module exists in cache. Otherwise, load it + if (module != moduleCacheRef.end()) + { + cuFunction = std::get<1>(module->second); + } + else + { + cuModuleLoadData(&cuModule, config.mData); + cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); + } + } + else + { + cuModuleLoadData(&cuModule, config.mData); + cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + } // Prepare the grid/block. dim3 block3{static_cast(config.mNumThreadsPerCTA), static_cast(1), static_cast(1)}; @@ -673,6 +710,11 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa { return -1; } + // If a module cache has not been given, unload the module to avoid overflow + if (!moduleCache.has_value()) + { + cuModuleUnload(cuModule); + } #else config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid); #endif diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp b/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp index e8ae5e0715..761fb475de 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp @@ -30,6 +30,8 @@ namespace kernels namespace tg = gemm::trtllm::gen; using namespace gemm::gemm; +static GemmInterface::ModuleCache globalTrtllmGenGemmModuleCache; + TrtllmGenGemmRunner::TrtllmGenGemmRunner(TrtllmGenGemmRunnerOptions const& options_) : mOptions(options_) { @@ -111,7 +113,8 @@ void TrtllmGenGemmRunner::run(int32_t m, int32_t n, int32_t k, void const* a, fl // FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere gemm.runInitBeforeWorldSync(config, gemmData, static_cast(stream)); - auto const err = gemm.run(config, workspace, gemmData, static_cast(stream), multiProcessorCount); + auto const err = gemm.run( + config, workspace, gemmData, static_cast(stream), multiProcessorCount, globalTrtllmGenGemmModuleCache); TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!"); } diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/trtllmGen_gemm_export/GemmInterface.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/trtllmGen_gemm_export/GemmInterface.h index fc53d0ad29..4cb4c6538a 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/trtllmGen_gemm_export/GemmInterface.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/trtllmGen_gemm_export/GemmInterface.h @@ -222,12 +222,15 @@ struct GemmData class GemmInterface { public: + using ModuleCache = std::unordered_map>; + GemmInterface() {} // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. // Provided config must be validated with isValidConfig before the call. int32_t run(GemmConfig const& config, void* workspace, GemmData const& options, void* cudaStream, - int32_t multiProcessorCount) const; + int32_t multiProcessorCount, + std::optional> moduleCache = std::nullopt) const; // Initializes the buffers before the world sync. Must be called before run. int32_t runInitBeforeWorldSync(GemmConfig const& config, GemmData const& data, void* cudaStream) const; @@ -384,7 +387,7 @@ bool GemmInterface::isValidConfig(GemmConfig const& config, GemmData const& data //////////////////////////////////////////////////////////////////////////////////////////////////// int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData const& data, void* cudaStream, - int32_t multiProcessorCount) const + int32_t multiProcessorCount, std::optional> moduleCache) const { // Get options from config and data. auto options = getOptionsFromConfigAndData(config, data); @@ -439,8 +442,42 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c #ifdef TLLM_GEN_EXPORT_INTERFACE CUmodule cuModule; CUfunction cuFunction; - cuModuleLoadData(&cuModule, config.mData); - cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + if (moduleCache.has_value()) + { + ModuleCache& moduleCacheRef = moduleCache.value().get(); + + // Modules are associated with a specific context so include the ctxId in the key + CUcontext ctx; + unsigned long long ctxId; + cuCtxGetCurrent(&ctx); + cuCtxGetId(ctx, &ctxId); + + // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal + // representation. + std::string const ctxName + = std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); + std::string const funcName = std::string(config.mFunctionName); + // As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision + auto const moduleKey = ctxName + funcName; + auto module = moduleCacheRef.find(moduleKey); + + // Check if module exists in cache. Otherwise, load it + if (module != moduleCacheRef.end()) + { + cuFunction = std::get<1>(module->second); + } + else + { + cuModuleLoadData(&cuModule, config.mData); + cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); + } + } + else + { + cuModuleLoadData(&cuModule, config.mData); + cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + } // Prepare the grid/block. dim3 block3{static_cast(config.mNumThreadsPerCTA), static_cast(1), static_cast(1)}; @@ -460,6 +497,11 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c { return -1; } + // If a module cache has not been given, unload the module to avoid leaking + if (!moduleCache.has_value()) + { + cuModuleUnload(cuModule); + } #else config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid); #endif diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemmGatedAct/KernelRunner.cpp b/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemmGatedAct/KernelRunner.cpp index e37af78665..c5d5a18c00 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemmGatedAct/KernelRunner.cpp +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemmGatedAct/KernelRunner.cpp @@ -27,6 +27,8 @@ namespace tensorrt_llm namespace kernels { +static gemmGatedAct::GemmGatedActInterface::ModuleCache globalTrtllmGenGemmGatedActModuleCache; + TrtllmGenGemmGatedActRunner::TrtllmGenGemmGatedActRunner(TrtllmGenGemmGatedActRunnerOptions const& options_) : mOptions(options_) { @@ -104,7 +106,8 @@ void TrtllmGenGemmGatedActRunner::run(int32_t m, int32_t n, int32_t k, void cons // FIXME once we start using all-reduce in the epilogue of the gemm this can be moved elsewhere gemm.runInitBeforeWorldSync(config, gemmData, static_cast(stream)); - auto const err = gemm.run(config, workspace, gemmData, static_cast(stream), multiProcessorCount); + auto const err = gemm.run(config, workspace, gemmData, static_cast(stream), multiProcessorCount, + globalTrtllmGenGemmGatedActModuleCache); TLLM_CHECK_WITH_INFO(err == 0, "Error occurred when running GEMM!"); } diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemmGatedAct/trtllmGen_gatedAct_export/GemmGatedActInterface.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemmGatedAct/trtllmGen_gatedAct_export/GemmGatedActInterface.h index 7bd170f736..a8087dc59a 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemmGatedAct/trtllmGen_gatedAct_export/GemmGatedActInterface.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/gemmGatedAct/trtllmGen_gatedAct_export/GemmGatedActInterface.h @@ -183,12 +183,15 @@ struct GemmGatedActData class GemmGatedActInterface { public: + using ModuleCache = std::unordered_map>; + GemmGatedActInterface() {} // Launch the cubin from the provided config. It calls all necessary memsets for internal buffers. // Provided config must be validated with isValidConfig before the call. int32_t run(GemmGatedActConfig const& config, void* workspace, GemmGatedActData const& data, void* cudaStream, - int32_t multiProcessorCount) const; + int32_t multiProcessorCount, + std::optional> moduleCache = std::nullopt) const; // Initializes the buffers before the world sync. Must be called before run. int32_t runInitBeforeWorldSync( @@ -340,7 +343,7 @@ bool GemmGatedActInterface::isValidConfig(GemmGatedActConfig const& config, Gemm //////////////////////////////////////////////////////////////////////////////////////////////////// int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* workspace, GemmGatedActData const& data, - void* cudaStream, int32_t multiProcessorCount) const + void* cudaStream, int32_t multiProcessorCount, std::optional> moduleCache) const { // Get options from config and data. auto options = getOptionsFromConfigAndData(config, data); @@ -392,8 +395,42 @@ int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* works #ifdef TLLM_GEN_EXPORT_INTERFACE CUmodule cuModule; CUfunction cuFunction; - cuModuleLoadData(&cuModule, config.mData); - cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + if (moduleCache.has_value()) + { + ModuleCache& moduleCacheRef = moduleCache.value().get(); + + // Modules are associated with a specific context so include the ctxId in the key + CUcontext ctx; + unsigned long long ctxId; + cuCtxGetCurrent(&ctx); + cuCtxGetId(ctx, &ctxId); + + // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal + // representation. + std::string const ctxName + = std::string(reinterpret_cast(&ctxId), sizeof(unsigned long long) / sizeof(char)); + std::string const funcName = std::string(config.mFunctionName); + // As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision + auto const moduleKey = ctxName + funcName; + auto module = moduleCacheRef.find(moduleKey); + + // Check if module exists in cache. Otherwise, load it + if (module != moduleCacheRef.end()) + { + cuFunction = std::get<1>(module->second); + } + else + { + cuModuleLoadData(&cuModule, config.mData); + cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction))); + } + } + else + { + cuModuleLoadData(&cuModule, config.mData); + cuModuleGetFunction(&cuFunction, cuModule, config.mFunctionName); + } // Prepare the grid/block. dim3 block3{static_cast(config.mNumThreadsPerCTA), static_cast(1), static_cast(1)}; @@ -413,6 +450,11 @@ int32_t GemmGatedActInterface::run(GemmGatedActConfig const& config, void* works { return -1; } + // If a module cache has not been given, unload the module to avoid leaking + if (!moduleCache.has_value()) + { + cuModuleUnload(cuModule); + } #else config.mCudaRunner->run((void*) &kernelParams, (void*) cudaStream, grid); #endif