Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][feat] Enable nanobind as the default binding library (#6608)
Signed-off-by: Linda-Stadter <57756729+Linda-Stadter@users.noreply.github.com>
Parent: a49cf684f8
Commit: 898f37faa0
@@ -69,7 +69,7 @@ add_compile_definitions("TLLM_GEN_EXPORT_INTERFACE")
 add_compile_definitions("TLLM_ENABLE_CUDA")
 
 set(BINDING_TYPE
-    "pybind"
+    "nanobind"
     CACHE STRING
     "Binding type of Python bindings for C++ runtime and batch manager")
 
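For orientation, the BINDING_TYPE cache variable above is the switch that selects which binding library is built; the wheel build script's --binding_type flag (changed later in this commit) is assumed to feed it. A minimal, hypothetical Python sketch of constructing that CMake override follows; the helper name and the forwarding mechanism are assumptions, not taken from this diff:

def cmake_binding_define(binding_type: str = "nanobind") -> str:
    # Hypothetical helper: build the -D override for the BINDING_TYPE cache
    # variable shown above; only the two supported values are accepted.
    if binding_type not in ("pybind", "nanobind"):
        raise ValueError(f"unsupported binding type: {binding_type}")
    return f"-DBINDING_TYPE={binding_type}"

print(cmake_binding_define())          # -DBINDING_TYPE=nanobind
print(cmake_binding_define("pybind"))  # -DBINDING_TYPE=pybind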
@@ -43,7 +43,7 @@ target_link_libraries(
   ${Python3_LIBRARIES}
   ${TORCH_LIBRARIES}
   torch_python
-  CUDA::cuda_driver
+  ${CUDA_DRV_LIB}
   ${CUDA_NVML_LIB}
   th_common)
 target_compile_definitions(
@@ -285,5 +285,35 @@ struct type_caster<std::vector<std::reference_wrapper<T const>>>
         return make_caster<std::vector<T>>::from_cpp(result, policy, cleanup);
     }
 };
+
+template <>
+struct type_caster<torch::ScalarType>
+{
+    NB_TYPE_CASTER(torch::ScalarType, const_name("torch.dtype"));
+
+    bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept
+    {
+        std::string dtype_name = nb::cast<std::string>(nb::str(src));
+        if (dtype_name.substr(0, 6) == "torch.")
+        {
+            dtype_name = dtype_name.substr(6);
+        }
+
+        auto const& dtype_map = c10::getStringToDtypeMap();
+        auto it = dtype_map.find(dtype_name);
+        if (it != dtype_map.end())
+        {
+            value = it->second;
+            return true;
+        }
+
+        return false;
+    }
+
+    static handle from_cpp(torch::ScalarType src, rv_policy policy, cleanup_list* cleanup)
+    {
+        throw std::runtime_error("from_cpp for torch::ScalarType is not implemented");
+    }
+};
 } // namespace detail
 } // namespace NB_NAMESPACE
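The added caster converts a Python torch.dtype by stringifying it, stripping the leading "torch." prefix, and looking the remaining name up in c10's string-to-dtype map; from_cpp is deliberately left unimplemented. A minimal Python sketch of how that lookup key is derived (illustrative only, not part of the commit):

import torch

def dtype_lookup_key(dtype: "torch.dtype") -> str:
    # Mirrors the from_python path above: str(torch.float16) yields
    # "torch.float16", and the "torch." prefix is stripped before the lookup.
    name = str(dtype)
    if name.startswith("torch."):
        name = name[len("torch."):]
    return name

assert dtype_lookup_key(torch.float16) == "float16"
assert dtype_lookup_key(torch.bfloat16) == "bfloat16"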
@@ -240,7 +240,8 @@ void initBindings(nb::module_& m)
     nb::class_<tle::KVCacheEvent>(executor_kv_cache, "KVCacheEvent")
         .def_ro("event_id", &tle::KVCacheEvent::eventId)
         .def_ro("data", &tle::KVCacheEvent::data)
-        .def_ro("window_size", &tle::KVCacheEvent::windowSize);
+        .def_ro("window_size", &tle::KVCacheEvent::windowSize)
+        .def_ro("attention_dp_rank", &tle::KVCacheEvent::attentionDpRank);
 
     nb::class_<tle::KVCacheEventManager>(executor_kv_cache, "KVCacheEventManager")
         .def(
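With the new read-only attribute, a Python consumer can read the attention DP rank alongside the existing event fields. A hedged sketch; how the events are obtained (here just an iterable) is assumed, since the retrieval API is not part of this hunk:

def summarize_kv_cache_events(events):
    # `events` is assumed to be an iterable of bound KVCacheEvent objects.
    for event in events:
        print(f"id={event.event_id} window={event.window_size} "
              f"attention_dp_rank={event.attention_dp_rank}")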
@@ -27,6 +27,7 @@
 
#include <nanobind/nanobind.h>
#include <nanobind/stl/chrono.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/list.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
@@ -279,7 +279,7 @@ void initBindings(nb::module_& m)
         .def(nb::init<tr::GptDecoderBatched::CudaStreamPtr>(), nb::arg("stream"))
         .def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_num_sequences"),
             nb::arg("max_beam_width"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"))
-        .def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("output"), nb::arg("input"))
+        .def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("decoder_state"), nb::arg("input"))
         .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, nb::rv_policy::reference)
         .def("finalize", &tr::GptDecoderBatched::finalize, nb::arg("decoder_state"), nb::arg("batch_idx"),
             nb::arg("sampling_config"), nb::arg("streaming"))
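After the rename above, the first argument of the bound forward_async is exposed as decoder_state instead of output. A small illustrative wrapper; the decoder, state, and inputs objects are assumed to be constructed elsewhere, this only shows the keyword usage:

def run_decode_step(decoder, state, inputs):
    # `decoder` is assumed to be a bound GptDecoderBatched instance; after this
    # change the first keyword is `decoder_state` (previously `output`).
    decoder.forward_async(decoder_state=state, input=inputs)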
@@ -44,7 +44,7 @@ target_link_libraries(
   ${Python3_LIBRARIES}
   ${TORCH_LIBRARIES}
   torch_python
-  CUDA::cuda_driver
+  ${CUDA_DRV_LIB}
   ${CUDA_NVML_LIB}
   th_common)
 target_compile_definitions(
@@ -48,10 +48,10 @@ CONFIG_LINUX_AARCH64 = "linux_aarch64"
 def CONFIG_LINUX_AARCH64_LLVM = "linux_aarch64_LLVM"
 
 @Field
-def CONFIG_LINUX_X86_64_NANOBIND = "linux_x86_64_Nanobind"
+def CONFIG_LINUX_X86_64_PYBIND = "linux_x86_64_Pybind"
 
 @Field
-def CONFIG_LINUX_AARCH64_NANOBIND = "linux_aarch64_Nanobind"
+def CONFIG_LINUX_AARCH64_PYBIND = "linux_aarch64_Pybind"
 
 @Field
 def BUILD_CONFIGS = [
@@ -62,9 +62,9 @@ def BUILD_CONFIGS = [
     (TARNAME) : "TensorRT-LLM.tar.gz",
     (WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real",
   ],
-  (CONFIG_LINUX_X86_64_NANOBIND) : [
-    (WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks",
-    (TARNAME) : "nanobind-TensorRT-LLM.tar.gz",
+  (CONFIG_LINUX_X86_64_PYBIND) : [
+    (WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks",
+    (TARNAME) : "pybind-TensorRT-LLM.tar.gz",
     (WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real",
   ],
   (CONFIG_LINUX_X86_64_SINGLE_DEVICE) : [
@@ -82,9 +82,9 @@ def BUILD_CONFIGS = [
     (TARNAME) : "TensorRT-LLM-GH200.tar.gz",
     (WHEEL_ARCHS): "90-real;100-real;120-real",
   ],
-  (CONFIG_LINUX_AARCH64_NANOBIND): [
-    (WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars WARNING_IS_ERROR=ON",
-    (TARNAME) : "nanobind-TensorRT-LLM-GH200.tar.gz",
+  (CONFIG_LINUX_AARCH64_PYBIND): [
+    (WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON",
+    (TARNAME) : "pybind-TensorRT-LLM-GH200.tar.gz",
     (WHEEL_ARCHS): "90-real;100-real;120-real",
   ],
   (CONFIG_LINUX_AARCH64_LLVM) : [
@@ -542,8 +542,8 @@ def launchStages(pipeline, cpu_arch, enableFailFast, globalVars)
             pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64 : CONFIG_LINUX_X86_64_VANILLA),
         "Build TRT-LLM LLVM": [LLM_DOCKER_IMAGE] + prepareLLMBuild(
             pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_LLVM : CONFIG_LINUX_X86_64_LLVM),
-        "Build TRT-LLM Nanobind": [LLM_DOCKER_IMAGE] + prepareLLMBuild(
-            pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_NANOBIND : CONFIG_LINUX_X86_64_NANOBIND),
+        "Build TRT-LLM Pybind": [LLM_DOCKER_IMAGE] + prepareLLMBuild(
+            pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_PYBIND : CONFIG_LINUX_X86_64_PYBIND),
     ]
 
     if (cpu_arch == X86_64_TRIPLE) {
@@ -65,7 +65,7 @@ def LLVM_CONFIG = "LLVM"
 LINUX_AARCH64_CONFIG = "linux_aarch64"
 
 @Field
-def NANOBIND_CONFIG = "Nanobind"
+def PYBIND_CONFIG = "Pybind"
 
 @Field
 def BUILD_CONFIGS = [
@@ -74,7 +74,7 @@ def BUILD_CONFIGS = [
   (SINGLE_DEVICE_CONFIG) : [(TARNAME) : "single-device-TensorRT-LLM.tar.gz"],
   (LLVM_CONFIG) : [(TARNAME) : "llvm-TensorRT-LLM.tar.gz"],
   (LINUX_AARCH64_CONFIG) : [(TARNAME) : "TensorRT-LLM-GH200.tar.gz"],
-  (NANOBIND_CONFIG) : [(TARNAME) : "nanobind-TensorRT-LLM.tar.gz"],
+  (PYBIND_CONFIG) : [(TARNAME) : "pybind-TensorRT-LLM.tar.gz"],
 ]
 
 // TODO: Move common variables to an unified location
@@ -1775,7 +1775,7 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
         "A10-TensorRT-4": ["a10", "l0_a10", 4, 6],
         "A10-TensorRT-5": ["a10", "l0_a10", 5, 6],
         "A10-TensorRT-6": ["a10", "l0_a10", 6, 6],
-        "A10-Nanobind": ["a10", "l0_a10_nanobind", 1, 1],
+        "A10-Pybind": ["a10", "l0_a10_pybind", 1, 1],
         "A30-Triton-1": ["a30", "l0_a30", 1, 1],
         "A30-PyTorch-1": ["a30", "l0_a30", 1, 2],
         "A30-PyTorch-2": ["a30", "l0_a30", 2, 2],
@@ -1856,8 +1856,8 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
                 if (key.contains("llvm")) {
                     config = LLVM_CONFIG
                 }
-                if (key.contains("Nanobind")) {
-                    config = NANOBIND_CONFIG
+                if (key.contains("Pybind")) {
+                    config = PYBIND_CONFIG
                 }
                 runLLMTestlistOnPlatform(pipeline, values[0], values[1], config, key.contains("Perf"), key, values[2], values[3])
             }]]}
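An illustrative Python mirror of the Groovy selection above, showing how a test-stage key is assumed to pick its build config after the change; the fallback value is a placeholder, not something this diff defines:

def pick_build_config(key: str, fallback: str = "default") -> str:
    # Mirrors the `key.contains(...)` checks in the hunk above.
    if "llvm" in key:
        return "LLVM"
    if "Pybind" in key:
        return "Pybind"
    return fallback

assert pick_build_config("A10-Pybind") == "Pybind"
assert pick_build_config("A10-TensorRT-4") == "default"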
@@ -435,7 +435,7 @@ def main(*,
          install: bool = False,
          skip_building_wheel: bool = False,
          linking_install_binary: bool = False,
-         binding_type: str = "pybind",
+         binding_type: str = "nanobind",
          benchmarks: bool = False,
          micro_benchmarks: bool = False,
          nvtx: bool = False,
@@ -984,8 +984,8 @@ def add_arguments(parser: ArgumentParser):
     )
     parser.add_argument("--binding_type",
                         choices=["pybind", "nanobind"],
-                        default="pybind",
-                        help="Which binding type to build: pybind or nanobind")
+                        default="nanobind",
+                        help="Which binding library to use: pybind or nanobind")
     parser.add_argument("--benchmarks",
                         action="store_true",
                         help="Build the benchmarks for the C++ runtime")
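With nanobind now the default, a pybind build has to be requested explicitly. A minimal sketch of one way to do that through the wheel build script; the scripts/build_wheel.py path is an assumption about the repository layout, not stated in this diff:

import subprocess
import sys

# Assumed invocation, run from the repository root: everything is left at its
# defaults except --binding_type, whose definition is shown in the hunk above.
subprocess.check_call(
    [sys.executable, "scripts/build_wheel.py", "--binding_type", "pybind"])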
@@ -199,7 +199,7 @@ l0_a10:
   tests:
   - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test]
   - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test]
-l0_a10_nanobind:
+l0_a10_pybind:
 - condition:
     ranges:
       system_gpu_count:
@@ -211,6 +211,7 @@ l0_a10_nanobind:
      linux_distribution_name: ubuntu*
  terms:
    stage: pre_merge
    backend: tensorrt
  tests:
  - unittest/bindings
  - test_e2e.py::test_openai_chat_example[trt]
  - test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90)