TensorRT-LLM/cpp/include/tensorrt_llm/deep_gemm/runtime.cuh
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef _WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif
#include "jit_utils.cuh"
#include "scheduler.cuh"
namespace deep_gemm::jit
{
/**
* C++ implementation of the Runtime class from runtime.py
* Loads and executes JIT-compiled kernels
*/
class Runtime
{
public:
Runtime(std::string const& path, deep_gemm::GemmType gemm_type);
~Runtime();
static bool isPathValid(std::string const& path);
template <typename... Args>
void operator()(Args&&... args)
{
// Load shared object if not already loaded
if (!lib_)
{
std::filesystem::path libPath = std::filesystem::path(path_);
#ifdef _WIN32
libPath /= "kernel.dll";
lib_ = LoadLibraryA(libPath.string().c_str());
if (!lib_)
{
throw std::runtime_error("Failed to load DLL: " + std::to_string(GetLastError()));
}
// Load launch function
switch (gemm_type_)
{
case deep_gemm::GemmType::Normal:
launchFuncNormal_ = reinterpret_cast<LaunchFuncNormal>(GetProcAddress((HMODULE) lib_, "launch"));
if (!launchFuncNormal_)
{
throw std::runtime_error("Failed to find launch function: " + std::to_string(GetLastError()));
}
break;
case deep_gemm::GemmType::GroupedWithOffset:
launchFuncGroupedWithOffset_
= reinterpret_cast<LaunchFuncGroupedWithOffset>(GetProcAddress((HMODULE) lib_, "launch"));
if (!launchFuncGroupedWithOffset_)
{
throw std::runtime_error("Failed to find launch function: " + std::to_string(GetLastError()));
}
break;
case deep_gemm::GemmType::StridedBatched:
launchFuncStridedBatched_
= reinterpret_cast<LaunchFuncStridedBatched>(GetProcAddress((HMODULE) lib_, "launch"));
if (!launchFuncStridedBatched_)
{
throw std::runtime_error("Failed to find launch function: " + std::to_string(GetLastError()));
}
break;
default: throw std::runtime_error("Unsupported gemm type: " + gemm_type_to_string(gemm_type_));
}
#else
libPath /= "kernel.so";
lib_ = dlopen(libPath.c_str(), RTLD_LAZY);
if (!lib_)
{
throw std::runtime_error("Failed to load shared object: " + std::string(dlerror()));
}
// Load launch function
switch (gemm_type_)
{
case deep_gemm::GemmType::Normal:
launchFuncNormal_ = reinterpret_cast<LaunchFuncNormal>(dlsym(lib_, "launch"));
if (!launchFuncNormal_)
{
throw std::runtime_error("Failed to find launch function: " + std::string(dlerror()));
}
break;
case deep_gemm::GemmType::GroupedWithOffset:
launchFuncGroupedWithOffset_ = reinterpret_cast<LaunchFuncGroupedWithOffset>(dlsym(lib_, "launch"));
if (!launchFuncGroupedWithOffset_)
{
throw std::runtime_error("Failed to find launch function: " + std::string(dlerror()));
}
break;
case deep_gemm::GemmType::StridedBatched:
launchFuncStridedBatched_ = reinterpret_cast<LaunchFuncStridedBatched>(dlsym(lib_, "launch"));
if (!launchFuncStridedBatched_)
{
throw std::runtime_error("Failed to find launch function: " + std::string(dlerror()));
}
break;
default: throw std::runtime_error("Unsupported gemm type: " + gemm_type_to_string(gemm_type_));
}
#endif
}
// Call the launch function with the provided arguments
switch (gemm_type_)
{
case deep_gemm::GemmType::Normal: callNormal(std::forward<Args>(args)...); break;
case deep_gemm::GemmType::GroupedWithOffset: callGroupedWithOffset(std::forward<Args>(args)...); break;
case deep_gemm::GemmType::StridedBatched: callStridedBatched(std::forward<Args>(args)...); break;
default: throw std::runtime_error("Unsupported gemm type: " + gemm_type_to_string(gemm_type_));
}
}
private:
using LaunchFuncNormal
= void (*)(void*, int, void*, int, void*, int, float*, float*, uint32_t, int*, cudaStream_t, int, uint32_t);
using LaunchFuncGroupedWithOffset = void (*)(
void*, int, void*, int, void*, int, float*, float*, int64_t*, int64_t*, cudaStream_t, int, uint32_t, uint32_t);
using LaunchFuncStridedBatched = void (*)(void*, uint64_t, uint64_t, void*, uint64_t, uint64_t, void*, uint64_t,
uint64_t, float*, float*, uint32_t, uint32_t, cudaStream_t, int, uint32_t);
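// The compiled kernel library is expected to export a single unmangled "launch"
// symbol whose signature matches one of the aliases above. A sketch of what the
// generated kernel.cu presumably declares for the Normal flavor (parameter names
// are illustrative assumptions; only the types come from LaunchFuncNormal):
//
//   extern "C" void launch(void* lhs, int ld_lhs, void* rhs, int ld_rhs,
//       void* out, int ld_out, float* scales_lhs, float* scales_rhs,
//       uint32_t shape_m, int* grouped_layout, cudaStream_t stream,
//       int num_sms, uint32_t smem_size);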
// Directory that holds the generated kernel.cu and the compiled kernel library
std::string path_;
// Library handle from dlopen/LoadLibraryA; populated lazily on the first call
void* lib_ = nullptr;
deep_gemm::GemmType gemm_type_;
// Resolved "launch" entry point; only the member matching gemm_type_ is set
LaunchFuncNormal launchFuncNormal_ = nullptr;
LaunchFuncGroupedWithOffset launchFuncGroupedWithOffset_ = nullptr;
LaunchFuncStridedBatched launchFuncStridedBatched_ = nullptr;
// Helper method for Normal GEMM - checks for correct number of arguments
template <typename... ArgsT>
void callNormal(ArgsT&&... args)
{
constexpr size_t expected_args = 13;
constexpr size_t actual_args = sizeof...(args);
if constexpr (actual_args != expected_args)
{
throw std::invalid_argument(
"Normal GEMM requires exactly 13 arguments, but " + std::to_string(actual_args) + " were provided");
}
else
{
launchFuncNormal_(std::forward<ArgsT>(args)...);
}
}
// Helper method for GroupedWithOffset GEMM - checks for correct number of arguments
template <typename... ArgsT>
void callGroupedWithOffset(ArgsT&&... args)
{
constexpr size_t expected_args = 14;
constexpr size_t actual_args = sizeof...(args);
if constexpr (actual_args != expected_args)
{
throw std::invalid_argument("GroupedWithOffset GEMM requires exactly 14 arguments, but "
+ std::to_string(actual_args) + " were provided");
}
else
{
launchFuncGroupedWithOffset_(std::forward<ArgsT>(args)...);
}
}
// Helper method for StridedBatched GEMM - checks for correct number of arguments
template <typename... ArgsT>
void callStridedBatched(ArgsT&&... args)
{
constexpr size_t expected_args = 16;
constexpr size_t actual_args = sizeof...(args);
if constexpr (actual_args != expected_args)
{
throw std::invalid_argument("StridedBatched GEMM requires exactly 16 arguments, but "
+ std::to_string(actual_args) + " were provided");
}
else
{
launchFuncStridedBatched_(std::forward<ArgsT>(args)...);
}
}
};
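/*
 * Illustrative use of Runtime (a minimal sketch; the directory name and the
 * launch arguments below are hypothetical -- the real argument list must match
 * the LaunchFunc* alias selected by the GemmType):
 *
 *   Runtime runtime("/tmp/deep_gemm_cache/kernel_fp8_Normal", deep_gemm::GemmType::Normal);
 *   runtime(lhs, ld_lhs, rhs, ld_rhs, out, ld_out, scales_lhs, scales_rhs,
 *       shape_m, grouped_layout, stream, num_sms, smem_size); // 13 args for Normal
 */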
/**
* C++ implementation of the RuntimeCache class from runtime.py
* Caches Runtime instances by path
*/
class RuntimeCache
{
public:
static RuntimeCache& getInstance();
Runtime* operator[](std::string const& path);
void set(std::string const& path, std::unique_ptr<Runtime> runtime);
private:
// Private constructor for singleton pattern
RuntimeCache() = default;
// Delete copy constructor and assignment operator
RuntimeCache(RuntimeCache const&) = delete;
RuntimeCache& operator=(RuntimeCache const&) = delete;
std::unordered_map<std::string, std::unique_ptr<Runtime>> cache_;
};
// Global function to access the singleton
RuntimeCache& getGlobalRuntimeCache();
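/*
 * Illustrative cache-based access (a minimal sketch; the path literal is
 * hypothetical, and operator[] returns nullptr when no compiled kernel exists
 * at that location yet):
 *
 *   auto& cache = getGlobalRuntimeCache();
 *   Runtime* runtime = cache["/tmp/deep_gemm_cache/kernel_fp8_GroupedWithOffset"];
 *   if (runtime == nullptr)
 *   {
 *       // Compile the kernel first, then register it:
 *       // cache.set(path, std::make_unique<Runtime>(path, deep_gemm::GemmType::GroupedWithOffset));
 *   }
 */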
} // namespace deep_gemm::jit
namespace deep_gemm::jit
{
// Runtime implementation
Runtime::Runtime(std::string const& path, deep_gemm::GemmType gemm_type)
: path_(path)
, lib_(nullptr)
, gemm_type_(gemm_type)
{
assert(isPathValid(path_));
}
Runtime::~Runtime()
{
if (lib_)
{
#ifdef _WIN32
FreeLibrary(static_cast<HMODULE>(lib_));
#else
dlclose(lib_);
#endif
}
}
bool Runtime::isPathValid(std::string const& path)
{
// Check if path exists and is a directory
if (!std::filesystem::exists(path) || !std::filesystem::is_directory(path))
{
return false;
}
// Check if all necessary files exist
#ifdef _WIN32
std::string soName = "kernel.dll";
#else
std::string soName = "kernel.so";
#endif
std::vector<std::string> requiredFiles = {"kernel.cu", soName};
for (auto const& file : requiredFiles)
{
if (!std::filesystem::exists(std::filesystem::path(path) / file))
{
return false;
}
}
return true;
}
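// A valid runtime directory therefore has the following layout (illustrative):
//   <path>/kernel.cu   -- the generated CUDA source
//   <path>/kernel.so   -- the compiled library ("kernel.dll" on Windows)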
// RuntimeCache implementation
RuntimeCache& RuntimeCache::getInstance()
{
static RuntimeCache instance;
return instance;
}
Runtime* RuntimeCache::operator[](std::string const& path)
{
// Check if already in cache
auto it = cache_.find(path);
if (it != cache_.end())
{
return it->second.get();
}
// Check if already compiled
if (Runtime::isPathValid(path))
{
// Parse path to get gemm type
std::string gemm_type_str = path.substr(path.find_last_of('_') + 1);
deep_gemm::GemmType gemm_type;
if (gemm_type_str == "Normal")
{
gemm_type = deep_gemm::GemmType::Normal;
}
else if (gemm_type_str == "GroupedWithOffset")
{
gemm_type = deep_gemm::GemmType::GroupedWithOffset;
}
else if (gemm_type_str == "StridedBatched")
{
gemm_type = deep_gemm::GemmType::StridedBatched;
}
else
{
throw std::runtime_error("Unsupported gemm type: " + gemm_type_str);
}
auto runtime = std::make_unique<Runtime>(path, gemm_type);
Runtime* result = runtime.get();
cache_[path] = std::move(runtime);
return result;
}
return nullptr;
}
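// Note: operator[] above infers the GemmType from the directory-name suffix after
// the last underscore, so a hypothetical path such as ".../kernel_fp8_StridedBatched"
// maps to deep_gemm::GemmType::StridedBatched.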
void RuntimeCache::set(std::string const& path, std::unique_ptr<Runtime> runtime)
{
cache_[path] = std::move(runtime);
}
// Global function to access the RuntimeCache singleton
RuntimeCache& getGlobalRuntimeCache()
{
return RuntimeCache::getInstance();
}
} // namespace deep_gemm::jit