mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-05 02:31:33 +08:00
Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com> KVCacheManagerV2 is a new python-based implementation of the KV cache manager, featuring cleaner API, better abstraction and better code quality without the accumulated legacy.
109 lines
5.3 KiB
C++
109 lines
5.3 KiB
C++
/*
|
|
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "kvCacheManagerV2Utils.h"
|
|
#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h"
|
|
#include <nanobind/nanobind.h>
|
|
#include <nanobind/stl/vector.h>
|
|
|
|
namespace nb = nanobind;
|
|
|
|
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
|
|
{
|
|
|
|
void KVCacheManagerV2UtilsBindings::initBindings(nb::module_& module)
|
|
{
|
|
// Bind DiskAddress struct
|
|
nb::class_<DiskAddress>(module, "DiskAddress")
|
|
.def(nb::init<int, ssize_t>(), nb::arg("fd"), nb::arg("pos"))
|
|
.def_rw("fd", &DiskAddress::fd)
|
|
.def_rw("pos", &DiskAddress::pos);
|
|
|
|
// Bind Task template instantiations
|
|
nb::class_<Task<DiskAddress, DiskAddress>>(module, "DiskToDiskTask")
|
|
.def(nb::init<DiskAddress, DiskAddress>(), nb::arg("dst"), nb::arg("src"))
|
|
.def_rw("dst", &Task<DiskAddress, DiskAddress>::dst)
|
|
.def_rw("src", &Task<DiskAddress, DiskAddress>::src);
|
|
|
|
nb::class_<Task<MemAddress, DiskAddress>>(module, "DiskToHostTask")
|
|
.def(nb::init<MemAddress, DiskAddress>(), nb::arg("dst"), nb::arg("src"))
|
|
.def_rw("dst", &Task<MemAddress, DiskAddress>::dst)
|
|
.def_rw("src", &Task<MemAddress, DiskAddress>::src);
|
|
|
|
nb::class_<Task<DiskAddress, MemAddress>>(module, "HostToDiskTask")
|
|
.def(nb::init<DiskAddress, MemAddress>(), nb::arg("dst"), nb::arg("src"))
|
|
.def_rw("dst", &Task<DiskAddress, MemAddress>::dst)
|
|
.def_rw("src", &Task<DiskAddress, MemAddress>::src);
|
|
|
|
nb::class_<Task<MemAddress, MemAddress>>(module, "MemToMemTask")
|
|
.def(nb::init<MemAddress, MemAddress>(), nb::arg("dst"), nb::arg("src"))
|
|
.def_rw("dst", &Task<MemAddress, MemAddress>::dst)
|
|
.def_rw("src", &Task<MemAddress, MemAddress>::src);
|
|
|
|
// Bind copy functions
|
|
module.def(
|
|
"copy_disk_to_disk",
|
|
[](std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
|
|
{ return copyDiskToDisk(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
|
|
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
|
|
"Copy data from disk to disk using CUDA host function");
|
|
|
|
module.def(
|
|
"copy_disk_to_host",
|
|
[](std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
|
|
{ return copyDiskToHost(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
|
|
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
|
|
"Copy data from disk to host using CUDA host function");
|
|
|
|
module.def(
|
|
"copy_host_to_disk",
|
|
[](std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
|
|
{ return copyHostToDisk(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
|
|
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
|
|
"Copy data from host to disk using CUDA host function");
|
|
|
|
module.def(
|
|
"copy_host_to_host",
|
|
[](std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
|
|
{ return copyHostToHost(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
|
|
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
|
|
"Copy data from host to host using CUDA host function");
|
|
|
|
module.def(
|
|
"copy_host_to_device",
|
|
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
|
|
{ return copyHostToDevice(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
|
|
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
|
|
"Copy data from host to device using CUDA kernels");
|
|
|
|
module.def(
|
|
"copy_device_to_host",
|
|
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
|
|
{ return copyDeviceToHost(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
|
|
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
|
|
"Copy data from device to host using CUDA kernels");
|
|
|
|
module.def(
|
|
"copy_device_to_device",
|
|
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
|
|
{ return copyDeviceToDevice(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
|
|
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
|
|
"Copy data from device to device using CUDA kernels");
|
|
}
|
|
|
|
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
|