From 6f07fa81d7cf13a425851607996abc8a69f92b6b Mon Sep 17 00:00:00 2001 From: Yao Yao Date: Sat, 24 Jan 2026 17:48:39 +0800 Subject: [PATCH] [TRTLLM-7738][feat] Adding implementation of KVCacheManagerV2 (#10736) Signed-off-by: Yao Yao KVCacheManagerV2 is a new python-based implementation of the KV cache manager, featuring cleaner API, better abstraction and better code quality without the accumulated legacy. --- .gitignore | 2 + cpp/tensorrt_llm/batch_manager/CMakeLists.txt | 2 + .../batch_manager/kvCacheManagerV2Utils.cpp | 163 +++ .../batch_manager/kvCacheManagerV2Utils.cu | 182 +++ .../batch_manager/kvCacheManagerV2Utils.h | 51 + cpp/tensorrt_llm/nanobind/CMakeLists.txt | 1 + .../batch_manager/kvCacheManagerV2Utils.cpp | 108 ++ .../batch_manager/kvCacheManagerV2Utils.h | 31 + cpp/tensorrt_llm/nanobind/bindings.cpp | 4 + .../nanobind/common/customCasters.h | 19 + cpp/tensorrt_llm/pybind/CMakeLists.txt | 1 + .../batch_manager/kvCacheManagerV2Utils.cpp | 111 ++ .../batch_manager/kvCacheManagerV2Utils.h | 29 + cpp/tensorrt_llm/pybind/bindings.cpp | 4 + requirements.txt | 3 + scripts/build_wheel.py | 54 +- setup.py | 93 +- tensorrt_llm/runtime/__init__.py | 14 + .../runtime/kv_cache_manager_v2/Makefile | 51 + .../runtime/kv_cache_manager_v2/__init__.py | 67 ++ .../runtime/kv_cache_manager_v2/__init__.pyi | 213 ++++ .../kv_cache_manager_v2/_block_radix_tree.py | 437 +++++++ .../runtime/kv_cache_manager_v2/_common.py | 85 ++ .../runtime/kv_cache_manager_v2/_config.py | 148 +++ .../kv_cache_manager_v2/_copy_engine.py | 374 ++++++ .../kv_cache_manager_v2/_core/__init__.py | 20 + .../kv_cache_manager_v2/_core/_kv_cache.py | 1048 +++++++++++++++++ .../_core/_kv_cache_manager.py | 218 ++++ .../kv_cache_manager_v2/_cuda_virt_mem.py | 184 +++ .../_eviction_controller/__init__.py | 23 + .../_eviction_controller.py | 228 ++++ .../kv_cache_manager_v2/_exceptions.py | 60 + .../_life_cycle_registry.py | 90 ++ .../runtime/kv_cache_manager_v2/_page.py | 463 ++++++++ .../kv_cache_manager_v2/_storage/__init__.py | 19 + .../kv_cache_manager_v2/_storage/_config.py | 225 ++++ .../kv_cache_manager_v2/_storage/_core.py | 936 +++++++++++++++ .../kv_cache_manager_v2/_storage_manager.py | 518 ++++++++ .../runtime/kv_cache_manager_v2/_utils.py | 945 +++++++++++++++ .../kv_cache_manager_v2/mypy_mypyc.ini | 51 + .../kv_cache_manager_v2/rawref/README.md | 140 +++ .../kv_cache_manager_v2/rawref/__init__.py | 35 + .../kv_cache_manager_v2/rawref/__init__.pyi | 74 ++ .../kv_cache_manager_v2/rawref/rawrefmodule.c | 238 ++++ .../kv_cache_manager_v2/rawref/setup.py | 28 + .../kv_cache_manager_v2/rawref/test_rawref.py | 269 +++++ .../kv_cache_manager_v2/setup_mypyc.py | 138 +++ tests/integration/defs/agg_unit_mem_df.csv | 6 + .../integration/test_lists/test-db/l0_a10.yml | 1 + .../test_lists/test-db/l0_b200.yml | 1 + .../test_lists/test-db/l0_h100.yml | 1 + .../kv_cache_manager_v2_tests/fake_engine.py | 206 ++++ .../kv_cache_manager_v2_tests/kernels.py | 336 ++++++ .../test_kv_cache_manager_v2.py | 711 +++++++++++ 54 files changed, 9442 insertions(+), 17 deletions(-) create mode 100644 cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp create mode 100644 cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu create mode 100644 cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.cpp create mode 100644 cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.h create mode 100644 
cpp/tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.cpp create mode 100644 cpp/tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.h create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/Makefile create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/__init__.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_block_radix_tree.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_common.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_config.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_copy_engine.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_core/__init__.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_cuda_virt_mem.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_eviction_controller/__init__.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_eviction_controller/_eviction_controller.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_exceptions.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_life_cycle_registry.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_page.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_storage/__init__.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_config.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_core.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_storage_manager.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/_utils.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/mypy_mypyc.ini create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/rawref/README.md create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/rawref/__init__.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/rawref/__init__.pyi create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/rawref/rawrefmodule.c create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/rawref/setup.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/rawref/test_rawref.py create mode 100644 tensorrt_llm/runtime/kv_cache_manager_v2/setup_mypyc.py create mode 100644 tests/unittest/kv_cache_manager_v2_tests/fake_engine.py create mode 100644 tests/unittest/kv_cache_manager_v2_tests/kernels.py create mode 100755 tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py diff --git a/.gitignore b/.gitignore index c588d39d9b..d409a49f48 100644 --- a/.gitignore +++ b/.gitignore @@ -52,6 +52,8 @@ tensorrt_llm/pg_utils_bindings.*.so tensorrt_llm/flash_mla/ tensorrt_llm/flash_mla_cpp_tllm.*.so tensorrt_llm/flash_mla_cpp_tllm.pyi +tensorrt_llm/runtime/kv_cache_manager_v2/**/*.so +**/*__mypyc*.so tensorrt_llm/scripts *docs/cpp_docs* *docs/source/_cpp_gen* diff --git a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt index 5dad6906ee..f62c2aaf7f 100644 --- a/cpp/tensorrt_llm/batch_manager/CMakeLists.txt +++ b/cpp/tensorrt_llm/batch_manager/CMakeLists.txt @@ -36,6 +36,8 @@ set(SRCS kvCacheManager.cpp kvCacheEventManager.cpp kvCacheTransferManager.cpp + kvCacheManagerV2Utils.cpp + kvCacheManagerV2Utils.cu llmRequest.cpp logitsPostProcessor.cpp loraBuffers.cpp diff --git 
a/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp new file mode 100644 index 0000000000..fb369b0f0f --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cpp @@ -0,0 +1,163 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h" +#include "tensorrt_llm/common/logger.h" +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 +{ + +template +bool loopedReadWrite(Func&& func, ssize_t size) noexcept +{ + ssize_t count = 0; + while (count < size) + { + ssize_t bytes = func(count); + if (bytes <= 0) + { + if (errno == EINTR) + { + continue; // Retry on interrupt + } + TLLM_LOG_ERROR("Disk read/write failed: %s\n", strerror(errno)); + return false; + } + count += bytes; + } + assert(count == size); + return true; +} + +bool writeAll(int fd, ssize_t pos, void const* data, ssize_t size) noexcept +{ + return loopedReadWrite([=](ssize_t finished) + { return pwrite(fd, static_cast(data) + finished, size - finished, pos + finished); }, + size); +} + +bool readAll(int fd, ssize_t pos, void* data, ssize_t size) noexcept +{ + return loopedReadWrite([=](ssize_t finished) + { return pread(fd, static_cast(data) + finished, size - finished, pos + finished); }, + size); +} + +template +struct UserData +{ + std::vector> tasks; + ssize_t numBytes; +}; + +CUDA_CB void hostFnDiskToDiskCopy(void* userData) noexcept +{ + // @TODO: enable multi-threading with a thread pool + using Data = UserData; + auto const data = std::unique_ptr(static_cast(userData)); + std::vector buffer(data->numBytes); + bool success = true; + for (auto const& t : data->tasks) + { + success = success && readAll(t.src.fd, t.src.pos, buffer.data(), data->numBytes); + success = success && writeAll(t.dst.fd, t.dst.pos, buffer.data(), data->numBytes); + } + if (!success) + { + TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnDiskToDiskCopy failed.\n"); + } +} + +CUDA_CB void hostFnDiskToHostCopy(void* userData) noexcept +{ + // @TODO: enable multi-threading with a thread pool + using Data = UserData; + auto const data = std::unique_ptr(static_cast(userData)); + bool success = true; + for (auto const& t : data->tasks) + { + success = success && readAll(t.src.fd, t.src.pos, reinterpret_cast(t.dst), data->numBytes); + } + if (!success) + { + TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnDiskToHostCopy failed.\n"); + } +} + +CUDA_CB void hostFnHostToDiskCopy(void* userData) noexcept +{ + // @TODO: enable multi-threading with a thread pool + using Data = UserData; + auto const data = std::unique_ptr(static_cast(userData)); + bool success = true; + for (auto const& t : data->tasks) + { + success = success && writeAll(t.dst.fd, t.dst.pos, reinterpret_cast(t.src), data->numBytes); + } 
+ if (!success) + { + TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnHostToDiskCopy failed.\n"); + } +} + +CUDA_CB void hostFnHostToHostCopy(void* userData) noexcept +{ + // @TODO: enable multi-threading with a thread pool + using Data = UserData; + auto const data = std::unique_ptr(static_cast(userData)); + for (auto const& t : data->tasks) + { + memcpy(reinterpret_cast(t.dst), reinterpret_cast(t.src), data->numBytes); + } +} + +CUresult copyDiskToDisk(std::vector> tasks, ssize_t numBytes, CUstream stream) noexcept +{ + using Data = UserData; + auto data = std::make_unique(Data{std::move(tasks), numBytes}); + return cuLaunchHostFunc(stream, hostFnDiskToDiskCopy, data.release()); +} + +CUresult copyDiskToHost(std::vector> tasks, ssize_t numBytes, CUstream stream) noexcept +{ + using Data = UserData; + auto data = std::make_unique(Data{std::move(tasks), numBytes}); + return cuLaunchHostFunc(stream, hostFnDiskToHostCopy, data.release()); +} + +CUresult copyHostToDisk(std::vector> tasks, ssize_t numBytes, CUstream stream) noexcept +{ + using Data = UserData; + auto data = std::make_unique(Data{std::move(tasks), numBytes}); + return cuLaunchHostFunc(stream, hostFnHostToDiskCopy, data.release()); +} + +CUresult copyHostToHost(std::vector> tasks, ssize_t numBytes, CUstream stream) noexcept +{ + using Data = UserData; + auto data = std::make_unique(Data{std::move(tasks), numBytes}); + return cuLaunchHostFunc(stream, hostFnHostToHostCopy, data.release()); +} + +} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu new file mode 100644 index 0000000000..4a134e5d08 --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu @@ -0,0 +1,182 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "kvCacheManagerV2Utils.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include +#include +#include +#include + +namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 +{ +using Grain = uint4; +constexpr uint32_t ctaSize = 128; +constexpr uint32_t nbBufs = 4; +constexpr uint32_t grainBytes = sizeof(Grain); + +using MMTask = Task; + +__device__ __host__ inline uint32_t divUp(uint32_t a, uint32_t b) +{ + return (a + b - 1) / b; +} + +template +__global__ void batchedCopy(std::array const __grid_constant__ tasks, uint32_t nbBytes) +{ +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("griddepcontrol.launch_dependents;\n"); +#endif + assert(nbBytes % sizeof(Grain) == 0); + __shared__ Grain data[nbBufs][ctaSize]; + + uint32_t const nbTasks = gridDim.y; + assert(nbTasks <= N); + auto const& task = tasks[blockIdx.y]; + uint32_t const nbSplits = gridDim.x; + uint32_t const idxSplit = blockIdx.x; + uint32_t const tid = threadIdx.x; + + constexpr uint32_t bytesPerIter = grainBytes * ctaSize; + + uint32_t const totalIters = divUp(nbBytes, bytesPerIter); + uint32_t const maxItersPerCta = divUp(totalIters, nbSplits); + uint32_t const idxGrainBeg = ctaSize * maxItersPerCta * idxSplit + tid; + uint32_t const idxGrainEnd = std::min(idxGrainBeg + ctaSize * maxItersPerCta, nbBytes / grainBytes); + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("griddepcontrol.wait;\n"); +#endif + for (uint32_t i = 0; i < maxItersPerCta + nbBufs; i++) + { + uint32_t const idxBuf = i % nbBufs; + if (i >= nbBufs) + { + uint32_t const stIter = i - nbBufs; + assert(idxBuf == (stIter % nbBufs)); + Grain const& src = data[idxBuf][tid]; + uint32_t const idxGrain = idxGrainBeg + ctaSize * stIter; + Grain& dst = reinterpret_cast(task.dst)[idxGrain]; + asm volatile("cp.async.wait_group %0;\n" ::"n"(nbBufs - 1) : "memory"); + if (idxGrain < idxGrainEnd) + { + dst = src; + } + } + uint32_t const ldIter = i; + Grain* const dst = &data[idxBuf][tid]; + uint32_t const idxGrain = idxGrainBeg + ctaSize * ldIter; + Grain const* const src = &reinterpret_cast(task.src)[idxGrain]; + if (idxGrain < idxGrainEnd) + { + uint32_t const size = grainBytes; + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)), + "l"(src), "n"(grainBytes), "r"(size) + : "memory"); + } + asm volatile("cp.async.commit_group;\n" : : : "memory"); + } +} + +template +CUresult launchBatchedCopyImpl( + bool lowBandwidth, MMTask const* tasks, uint32_t nbTasks, uint32_t nbBytes, cudaStream_t stream) +{ + TLLM_CHECK(nbTasks <= N); + TLLM_CHECK_WITH_INFO( + nbBytes % sizeof(Grain) == 0, "Not implemented case: nbBytes = %d must be a multiple of 16.", nbBytes); + std::array const* pTasks; + std::array tmp; + if (nbTasks < N) + { + std::copy_n(tasks, nbTasks, tmp.begin()); + pTasks = &tmp; + } + else + { + pTasks = reinterpret_cast const*>(tasks); + } + uint32_t const nbSplits = lowBandwidth ? 
1 : divUp(nbBytes, grainBytes * ctaSize * 2); + void* args[] = {(void*) pTasks, (void*) &nbBytes}; + static CUkernel const kernel = [] -> CUkernel + { + cudaKernel_t kernel = nullptr; + TLLM_CUDA_CHECK(cudaGetKernel(&kernel, reinterpret_cast(&batchedCopy))); + return kernel; + }(); + return common::CUDADriverWrapper::getInstance()->cuLaunchKernel(reinterpret_cast(kernel), nbSplits, + nbTasks, 1, // gridDimX, gridDimY, gridDimZ + ctaSize, 1, 1, // blockDimX, blockDimY, blockDimZ + 0, // sharedMemBytes + stream, args, nullptr); +} + +// When bandwidth is low, e.g. when host memory is involved, we avoid splitting as fewer CTAs should be enough to +// saturate the bandwidth. +CUresult launchBatchedCopy(bool lowBandwidth, std::vector const& tasks, uint32_t nbBytes, cudaStream_t stream) +{ + constexpr uint32_t maxN = 256; + uint32_t const nbWholeBatches = tasks.size() / maxN; + for (uint32_t i = 0; i < nbWholeBatches; i++) + { + CUresult const err = launchBatchedCopyImpl(lowBandwidth, tasks.data() + maxN * i, maxN, nbBytes, stream); + if (err != CUDA_SUCCESS) + { + return err; + } + } + { + auto const* const pTasks = tasks.data() + maxN * nbWholeBatches; + auto const batchSize = tasks.size() % maxN; + if (batchSize == 0) + { + return CUDA_SUCCESS; + } + if (batchSize > maxN / 2) + { + return launchBatchedCopyImpl(lowBandwidth, pTasks, batchSize, nbBytes, stream); + } + if (batchSize > maxN / 4) + { + return launchBatchedCopyImpl(lowBandwidth, pTasks, batchSize, nbBytes, stream); + } + if (batchSize > maxN / 8) + { + return launchBatchedCopyImpl(lowBandwidth, pTasks, batchSize, nbBytes, stream); + } + return launchBatchedCopyImpl(lowBandwidth, pTasks, batchSize, nbBytes, stream); + } +} + +CUresult copyHostToDevice(std::vector const& tasks, ssize_t numBytes, CUstream stream) noexcept +{ + return launchBatchedCopy(true, tasks, numBytes, stream); +} + +CUresult copyDeviceToHost(std::vector const& tasks, ssize_t numBytes, CUstream stream) noexcept +{ + return launchBatchedCopy(true, tasks, numBytes, stream); +} + +CUresult copyDeviceToDevice(std::vector const& tasks, ssize_t numBytes, CUstream stream) noexcept +{ + return launchBatchedCopy(false, tasks, numBytes, stream); +} + +} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h new file mode 100644 index 0000000000..2acb81e522 --- /dev/null +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
+{
+struct DiskAddress
+{
+    int fd;
+    ssize_t pos;
+};
+
+using MemAddress = std::uintptr_t;
+
+template <typename DstAddr, typename SrcAddr>
+struct Task
+{
+    DstAddr dst;
+    SrcAddr src;
+};
+
+CUresult copyDiskToDisk(std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
+CUresult copyDiskToHost(std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
+CUresult copyHostToDisk(std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
+CUresult copyHostToHost(std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
+CUresult copyHostToDevice(
+    std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;
+CUresult copyDeviceToHost(
+    std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;
+CUresult copyDeviceToDevice(
+    std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;
+} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
diff --git a/cpp/tensorrt_llm/nanobind/CMakeLists.txt b/cpp/tensorrt_llm/nanobind/CMakeLists.txt
index 9c64b3705d..a47f8dd1d4 100755
--- a/cpp/tensorrt_llm/nanobind/CMakeLists.txt
+++ b/cpp/tensorrt_llm/nanobind/CMakeLists.txt
@@ -10,6 +10,7 @@ set(SRCS
   batch_manager/cacheTransceiver.cpp
   batch_manager/kvCacheConnector.cpp
   batch_manager/kvCacheManager.cpp
+  batch_manager/kvCacheManagerV2Utils.cpp
   batch_manager/llmRequest.cpp
   common/tllmExceptions.cpp
   executor/bindings.cpp
diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.cpp
new file mode 100644
index 0000000000..0985d0299e
--- /dev/null
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.cpp
@@ -0,0 +1,108 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "kvCacheManagerV2Utils.h" +#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h" +#include +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 +{ + +void KVCacheManagerV2UtilsBindings::initBindings(nb::module_& module) +{ + // Bind DiskAddress struct + nb::class_(module, "DiskAddress") + .def(nb::init(), nb::arg("fd"), nb::arg("pos")) + .def_rw("fd", &DiskAddress::fd) + .def_rw("pos", &DiskAddress::pos); + + // Bind Task template instantiations + nb::class_>(module, "DiskToDiskTask") + .def(nb::init(), nb::arg("dst"), nb::arg("src")) + .def_rw("dst", &Task::dst) + .def_rw("src", &Task::src); + + nb::class_>(module, "DiskToHostTask") + .def(nb::init(), nb::arg("dst"), nb::arg("src")) + .def_rw("dst", &Task::dst) + .def_rw("src", &Task::src); + + nb::class_>(module, "HostToDiskTask") + .def(nb::init(), nb::arg("dst"), nb::arg("src")) + .def_rw("dst", &Task::dst) + .def_rw("src", &Task::src); + + nb::class_>(module, "MemToMemTask") + .def(nb::init(), nb::arg("dst"), nb::arg("src")) + .def_rw("dst", &Task::dst) + .def_rw("src", &Task::src); + + // Bind copy functions + module.def( + "copy_disk_to_disk", + [](std::vector> tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyDiskToDisk(std::move(tasks), numBytes, reinterpret_cast(stream)); }, + nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard(), + "Copy data from disk to disk using CUDA host function"); + + module.def( + "copy_disk_to_host", + [](std::vector> tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyDiskToHost(std::move(tasks), numBytes, reinterpret_cast(stream)); }, + nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard(), + "Copy data from disk to host using CUDA host function"); + + module.def( + "copy_host_to_disk", + [](std::vector> tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyHostToDisk(std::move(tasks), numBytes, reinterpret_cast(stream)); }, + nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard(), + "Copy data from host to disk using CUDA host function"); + + module.def( + "copy_host_to_host", + [](std::vector> tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyHostToHost(std::move(tasks), numBytes, reinterpret_cast(stream)); }, + nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard(), + "Copy data from host to host using CUDA host function"); + + module.def( + "copy_host_to_device", + [](std::vector> const& tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyHostToDevice(tasks, numBytes, reinterpret_cast(stream)); }, + nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard(), + "Copy data from host to device using CUDA kernels"); + + module.def( + "copy_device_to_host", + [](std::vector> const& tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyDeviceToHost(tasks, numBytes, reinterpret_cast(stream)); }, + nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard(), + "Copy data from device to host using CUDA kernels"); + + module.def( + "copy_device_to_device", + [](std::vector> const& tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyDeviceToDevice(tasks, numBytes, reinterpret_cast(stream)); }, + nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard(), + "Copy data from device to device using CUDA kernels"); +} + +} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 diff --git 
a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.h b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.h new file mode 100644 index 0000000000..ad10bf537d --- /dev/null +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.h @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace nb = nanobind; + +namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 +{ +class KVCacheManagerV2UtilsBindings +{ +public: + static void initBindings(nb::module_& module); +}; +} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 diff --git a/cpp/tensorrt_llm/nanobind/bindings.cpp b/cpp/tensorrt_llm/nanobind/bindings.cpp index 06ea1d7a2a..327d1f713d 100644 --- a/cpp/tensorrt_llm/nanobind/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/bindings.cpp @@ -38,6 +38,7 @@ #include "tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h" #include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h" #include "tensorrt_llm/nanobind/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.h" #include "tensorrt_llm/nanobind/batch_manager/llmRequest.h" #include "tensorrt_llm/nanobind/common/tllmExceptions.h" #include "tensorrt_llm/nanobind/executor/bindings.h" @@ -131,6 +132,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m) auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings"); auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings"); auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings"); + auto mInternalBatchManagerKvCacheV2Utils + = mInternalBatchManager.def_submodule("kv_cache_manager_v2_utils", "KV Cache Manager V2 Utils bindings"); auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings"); auto mExceptions = m.def_submodule("exceptions", "Exceptions internal bindings"); @@ -502,6 +505,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m) tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager); tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager); tb::CacheTransceiverBindings::initBindings(mInternalBatchManager); + tb::kv_cache_manager_v2::KVCacheManagerV2UtilsBindings::initBindings(mInternalBatchManagerKvCacheV2Utils); auto mInternalAlgorithms = mInternal.def_submodule("algorithms", "Algorithms internal bindings"); tpb::algorithms::initBindings(mInternalAlgorithms); diff --git a/cpp/tensorrt_llm/nanobind/common/customCasters.h b/cpp/tensorrt_llm/nanobind/common/customCasters.h index 432ce5c26b..8c202b9387 100644 --- a/cpp/tensorrt_llm/nanobind/common/customCasters.h +++ b/cpp/tensorrt_llm/nanobind/common/customCasters.h @@ -315,5 +315,24 @@ struct type_caster throw std::runtime_error("from_cpp for torch::ScalarType is not implemented"); } }; + +template <> +class 
type_caster<CUstream>
+{
+public:
+    NB_TYPE_CASTER(CUstream, const_name("int"));
+
+    bool from_python(handle src, uint8_t /* flags */, cleanup_list* /* cleanup */)
+    {
+        // PyLong_AsVoidPtr interprets the Python int as a raw CUstream handle.
+        value = reinterpret_cast<CUstream>(PyLong_AsVoidPtr(src.ptr()));
+        return true;
+    }
+
+    static handle from_cpp(CUstream const& src, rv_policy /* policy */, cleanup_list* /* cleanup */)
+    {
+        return PyLong_FromVoidPtr(src);
+    }
+};
 } // namespace detail
 } // namespace NB_NAMESPACE
diff --git a/cpp/tensorrt_llm/pybind/CMakeLists.txt b/cpp/tensorrt_llm/pybind/CMakeLists.txt
index 71c33a479e..9caa0a44c4 100755
--- a/cpp/tensorrt_llm/pybind/CMakeLists.txt
+++ b/cpp/tensorrt_llm/pybind/CMakeLists.txt
@@ -10,6 +10,7 @@ set(SRCS
   batch_manager/cacheTransceiver.cpp
   batch_manager/kvCacheConnector.cpp
   batch_manager/kvCacheManager.cpp
+  batch_manager/kvCacheManagerV2Utils.cpp
   batch_manager/llmRequest.cpp
   executor/bindings.cpp
   executor/executor.cpp
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.cpp
new file mode 100644
index 0000000000..e85868ccf3
--- /dev/null
+++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.cpp
@@ -0,0 +1,111 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "kvCacheManagerV2Utils.h" + +#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h" + +#include +#include +#include + +namespace py = pybind11; + +namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 +{ + +void KVCacheManagerV2UtilsBindings::initBindings(py::module_& module) +{ + // Bind DiskAddress struct + py::class_(module, "DiskAddress") + .def(py::init(), py::arg("fd"), py::arg("pos")) + .def_readwrite("fd", &DiskAddress::fd) + .def_readwrite("pos", &DiskAddress::pos); + + // Bind Task template instantiations + py::class_>(module, "DiskToDiskTask") + .def(py::init(), py::arg("dst"), py::arg("src")) + .def_readwrite("dst", &Task::dst) + .def_readwrite("src", &Task::src); + + py::class_>(module, "DiskToHostTask") + .def(py::init(), py::arg("dst"), py::arg("src")) + .def_readwrite("dst", &Task::dst) + .def_readwrite("src", &Task::src); + + py::class_>(module, "HostToDiskTask") + .def(py::init(), py::arg("dst"), py::arg("src")) + .def_readwrite("dst", &Task::dst) + .def_readwrite("src", &Task::src); + + py::class_>(module, "MemToMemTask") + .def(py::init(), py::arg("dst"), py::arg("src")) + .def_readwrite("dst", &Task::dst) + .def_readwrite("src", &Task::src); + + // Bind copy functions + module.def( + "copy_disk_to_disk", + [](std::vector> tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyDiskToDisk(std::move(tasks), numBytes, reinterpret_cast(stream)); }, + py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard(), + "Copy data from disk to disk using CUDA host function"); + + module.def( + "copy_disk_to_host", + [](std::vector> tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyDiskToHost(std::move(tasks), numBytes, reinterpret_cast(stream)); }, + py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard(), + "Copy data from disk to host using CUDA host function"); + + module.def( + "copy_host_to_disk", + [](std::vector> tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyHostToDisk(std::move(tasks), numBytes, reinterpret_cast(stream)); }, + py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard(), + "Copy data from host to disk using CUDA host function"); + + module.def( + "copy_host_to_host", + [](std::vector> tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyHostToHost(std::move(tasks), numBytes, reinterpret_cast(stream)); }, + py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard(), + "Copy data from host to host using CUDA host function"); + + module.def( + "copy_host_to_device", + [](std::vector> const& tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyHostToDevice(tasks, numBytes, reinterpret_cast(stream)); }, + py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard(), + "Copy data from host to device using CUDA kernels"); + + module.def( + "copy_device_to_host", + [](std::vector> const& tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyDeviceToHost(tasks, numBytes, reinterpret_cast(stream)); }, + py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard(), + "Copy data from device to host using CUDA kernels"); + + module.def( + "copy_device_to_device", + [](std::vector> const& tasks, ssize_t numBytes, uintptr_t stream) -> int + { return copyDeviceToDevice(tasks, numBytes, reinterpret_cast(stream)); }, + py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard(), + "Copy data from device to device using CUDA kernels"); +} + +} // namespace 
tensorrt_llm::batch_manager::kv_cache_manager_v2 diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.h b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.h new file mode 100644 index 0000000000..6dcc642479 --- /dev/null +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 +{ +class KVCacheManagerV2UtilsBindings +{ +public: + static void initBindings(pybind11::module_& module); +}; +} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2 diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index cb4f34b722..4dba98dba0 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -33,6 +33,7 @@ #include "tensorrt_llm/pybind/batch_manager/cacheTransceiver.h" #include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h" #include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.h" #include "tensorrt_llm/pybind/batch_manager/llmRequest.h" #include "tensorrt_llm/pybind/common/tllmExceptions.h" #include "tensorrt_llm/pybind/executor/bindings.h" @@ -124,6 +125,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings"); auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings"); auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings"); + auto mInternalBatchManagerKvCacheV2Utils + = mInternalBatchManager.def_submodule("kv_cache_manager_v2_utils", "KV Cache Manager V2 Utils bindings"); auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings"); auto mExceptions = m.def_submodule("exceptions", "Exceptions internal bindings"); @@ -490,6 +493,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager); tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager); tb::CacheTransceiverBindings::initBindings(mInternalBatchManager); + tb::kv_cache_manager_v2::KVCacheManagerV2UtilsBindings::initBindings(mInternalBatchManagerKvCacheV2Utils); auto mInternalAlgorithms = mInternal.def_submodule("algorithms", "Algorithms internal bindings"); tpb::algorithms::initBindings(mInternalAlgorithms); diff --git a/requirements.txt b/requirements.txt index 689d8c2b9d..da6c9a2107 100644 --- a/requirements.txt +++ b/requirements.txt @@ -78,3 +78,6 @@ apache-tvm-ffi==0.1.6 # used for reduce nvidia-cutlass-dsl host overhead torch-c-dlpack-ext==0.1.3 # used for reduce nvidia-cutlass-dsl host overhead, optional package for improved torch tensor calling perf mistral-common==1.8.6 
torchao>=0.14.1 +cuda-core +llist +dynamic_path_manager diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py index 5d973d5fe0..8e2f6fd1c5 100755 --- a/scripts/build_wheel.py +++ b/scripts/build_wheel.py @@ -457,6 +457,43 @@ def generate_python_stubs_windows(binding_type: str, venv_python: Path, (pkg_dir / stubgen).unlink() +def build_kv_cache_manager_v2(project_dir, venv_python, use_mypyc=False): + print("-- Building kv_cache_manager_v2...") + kv_cache_mgr_dir = project_dir / "tensorrt_llm/runtime/kv_cache_manager_v2" + runtime_dir = project_dir / "tensorrt_llm/runtime" + + # Clean up any existing mypyc artifacts in runtime directory to prevent stale inclusion + # when switching from --mypyc to standard build + if not use_mypyc: + for so_file in runtime_dir.glob("*__mypyc*.so"): + print(f"Removing stale mypyc artifact: {so_file}") + so_file.unlink() + + # Also clean up any .so files inside kv_cache_manager_v2 + for so_file in kv_cache_mgr_dir.rglob("*.so"): + print(f"Removing stale artifact: {so_file}") + so_file.unlink() + + # Build rawref + print("-- Building kv_cache_manager_v2 rawref extension...") + rawref_dir = kv_cache_mgr_dir / "rawref" + build_run(f'"{venv_python}" setup.py build_ext --inplace', cwd=rawref_dir) + + if use_mypyc: + # Build mypyc + print("-- Building kv_cache_manager_v2 mypyc extensions...") + # setup_mypyc.py is in kv_cache_manager_v2 but executed from runtime dir + setup_mypyc = kv_cache_mgr_dir / "setup_mypyc.py" + build_run(f'"{venv_python}" "{setup_mypyc}" build_ext --inplace', + cwd=runtime_dir) + + # Verify that the shared library was generated + if not list(runtime_dir.glob("*__mypyc*.so")): + raise RuntimeError( + "Failed to build kv_cache_manager_v2: no shared library generated." + ) + + def main(*, build_type: str = "Release", generator: str = "", @@ -487,7 +524,8 @@ def main(*, skip_stubs: bool = False, generate_fmha: bool = False, no_venv: bool = False, - nvrtc_dynamic_linking: bool = False): + nvrtc_dynamic_linking: bool = False, + mypyc: bool = False): if clean: clean_wheel = True @@ -967,6 +1005,8 @@ def main(*, nixl_root is not None or mooncake_root is not None, binding_lib_file_name) + build_kv_cache_manager_v2(project_dir, venv_python, use_mypyc=mypyc) + if not skip_building_wheel: if dist_dir is None: dist_dir = project_dir / "build" @@ -988,6 +1028,15 @@ def main(*, build_run( f'\"{venv_python}\" -m build {project_dir} --skip-dependency-check {extra_wheel_build_args} --no-isolation --wheel --outdir "{dist_dir}"' ) + env = os.environ.copy() + if mypyc: + env["TRTLLM_ENABLE_MYPYC"] = "1" + else: + env["TRTLLM_ENABLE_MYPYC"] = "0" + + build_run( + f'\"{venv_python}\" -m build {project_dir} --skip-dependency-check --no-isolation --wheel --outdir "{dist_dir}"', + env=env) if install: build_run(f"\"{sys.executable}\" -m pip install -e .[devel]") @@ -1135,6 +1184,9 @@ def add_arguments(parser: ArgumentParser): "--nvrtc_dynamic_linking", action="store_true", help="Link against dynamic NVRTC libraries instead of static ones") + parser.add_argument("--mypyc", + action="store_true", + help="Compile kv_cache_manager_v2 with mypyc") if __name__ == "__main__": diff --git a/setup.py b/setup.py index ae2be9d9bc..83aec22469 100644 --- a/setup.py +++ b/setup.py @@ -110,21 +110,43 @@ if on_windows: ] else: package_data = [ - 'bin/executorWorker', 'libs/libtensorrt_llm.so', 'libs/libth_common.so', + 'bin/executorWorker', + 'libs/libtensorrt_llm.so', + 'libs/libth_common.so', 'libs/libnvinfer_plugin_tensorrt_llm.so', - 'libs/libtensorrt_llm_ucx_wrapper.so', 
'libs/libdecoder_attention_0.so', - 'libs/libtensorrt_llm_nixl_wrapper.so', 'libs/nixl/**/*', + 'libs/libtensorrt_llm_ucx_wrapper.so', + 'libs/libdecoder_attention_0.so', + 'libs/libtensorrt_llm_nixl_wrapper.so', + 'libs/nixl/**/*', 'tensorrt_llm_transfer_agent_binding*.so', 'tensorrt_llm_transfer_agent_binding.pyi', - 'libs/libtensorrt_llm_mooncake_wrapper.so', 'libs/ucx/**/*', - 'libs/libpg_utils.so', 'libs/libdecoder_attention_1.so', - 'libs/nvshmem/License.txt', 'libs/nvshmem/nvshmem_bootstrap_uid.so.3', - 'libs/nvshmem/nvshmem_transport_ibgda.so.103', 'bindings.*.so', - 'deep_ep/LICENSE', 'deep_ep/*.py', 'deep_ep_cpp_tllm.*.so', - "include/**/*", 'deep_gemm/LICENSE', 'deep_gemm/include/**/*', - 'deep_gemm/*.py', 'deep_gemm_cpp_tllm.*.so', - 'scripts/install_tensorrt.sh', 'flash_mla/LICENSE', 'flash_mla/*.py', - 'flash_mla_cpp_tllm.*.so' + 'libs/libtensorrt_llm_mooncake_wrapper.so', + 'libs/ucx/**/*', + 'libs/libpg_utils.so', + 'libs/libdecoder_attention_1.so', + 'libs/nvshmem/License.txt', + 'libs/nvshmem/nvshmem_bootstrap_uid.so.3', + 'libs/nvshmem/nvshmem_transport_ibgda.so.103', + 'bindings.*.so', + 'deep_ep/LICENSE', + 'deep_ep/*.py', + 'deep_ep_cpp_tllm.*.so', + "include/**/*", + 'deep_gemm/LICENSE', + 'deep_gemm/include/**/*', + 'deep_gemm/*.py', + 'deep_gemm_cpp_tllm.*.so', + 'scripts/install_tensorrt.sh', + 'flash_mla/LICENSE', + 'flash_mla/*.py', + 'flash_mla_cpp_tllm.*.so', + 'runtime/kv_cache_manager_v2/*.so', + 'runtime/kv_cache_manager_v2/**/*.so', + 'runtime/kv_cache_manager_v2/*.pyi', + 'runtime/kv_cache_manager_v2/**/*.pyi', + 'runtime/kv_cache_manager_v2/rawref/*.py', + 'runtime/kv_cache_manager_v2/rawref/*.pyi', + 'runtime/*__mypyc*.so', ] package_data += [ @@ -268,10 +290,16 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str], # Skip .py files EXCEPT for generated C++ extension wrappers # (deep_gemm, deep_ep, flash_mla Python files are generated during build) if file.filename.endswith(".py"): - allowed_dirs = ("tensorrt_llm/deep_gemm/", - "tensorrt_llm/deep_ep/", - "tensorrt_llm/flash_mla/") + allowed_dirs = ( + "tensorrt_llm/deep_gemm/", "tensorrt_llm/deep_ep/", + "tensorrt_llm/flash_mla/", + "tensorrt_llm/runtime/kv_cache_manager_v2/rawref/__init__.py" + ) if not any(file.filename.startswith(d) for d in allowed_dirs): + # Exclude all .py files in kv_cache_manager_v2 except rawref/__init__.py + if file.filename.startswith("tensorrt_llm/runtime/kv_cache_manager_v2/") and \ + not file.filename.endswith("rawref/__init__.py"): + continue continue for filename_pattern in package_data: @@ -305,6 +333,38 @@ sanity_check() with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() + # We use find_packages with a custom exclude filter to handle the mypyc compiled modules. + # We want to exclude the .py source files for modules that are compiled to .so. + # We exclude the kv_cache_manager_v2 package entirely from the source list, + # but explicitly add back the rawref subpackage (which is not compiled by mypyc). + # The .so and .pyi files for kv_cache_manager_v2 are added via package_data. 
+enable_mypyc = os.getenv("TRTLLM_ENABLE_MYPYC", "0") == "1" +if enable_mypyc: + packages = find_packages(exclude=[ + "tensorrt_llm.runtime.kv_cache_manager_v2", + "tensorrt_llm.runtime.kv_cache_manager_v2.*", + ]) + ["tensorrt_llm.runtime.kv_cache_manager_v2.rawref"] + exclude_package_data = { + "tensorrt_llm": [ + "runtime/kv_cache_manager_v2/*.py", + "runtime/kv_cache_manager_v2/**/*.py" + ], + "tensorrt_llm.runtime.kv_cache_manager_v2": ["*.py", "**/*.py"], + } +else: + packages = find_packages() + exclude_package_data = {} + + # Remove mypyc shared objects from package_data to avoid packaging stale files + package_data = [ + p for p in package_data if p not in [ + 'runtime/kv_cache_manager_v2/*.so', + 'runtime/kv_cache_manager_v2/**/*.so', 'runtime/*__mypyc*.so' + ] + ] + # Ensure rawref is included + package_data.append('runtime/kv_cache_manager_v2/rawref/*.so') + # https://setuptools.pypa.io/en/latest/references/keywords.html setup( name='tensorrt_llm', @@ -318,7 +378,8 @@ setup( author="NVIDIA Corporation", url="https://github.com/NVIDIA/TensorRT-LLM", download_url="https://github.com/NVIDIA/TensorRT-LLM/tags", - packages=find_packages(), + packages=packages, + exclude_package_data=exclude_package_data, # TODO Add windows support for python bindings. classifiers=[ "Development Status :: 4 - Beta", diff --git a/tensorrt_llm/runtime/__init__.py b/tensorrt_llm/runtime/__init__.py index d432a152db..d6b57b38e6 100644 --- a/tensorrt_llm/runtime/__init__.py +++ b/tensorrt_llm/runtime/__init__.py @@ -12,6 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os + +from dynamic_path_manager import DynamicPathManager + +# Add current directory to sys.path so kv_cache_manager_v2 can be imported as top-level package. +# This is required because when kv_cache_manager_v2 is compiled with mypyc, it is compiled as +# a top-level package (to avoid complex build paths), but at runtime it is used as a submodule. +# The compiled extension might try to import its submodules using absolute imports based on its +# compiled name. +with DynamicPathManager(os.path.dirname(os.path.abspath(__file__)), + clear_cache=False): + import kv_cache_manager_v2 + from .enc_dec_model_runner import EncDecModelRunner from .generation import SamplingConfig # autoflake: skip from .generation import (ChatGLMGenerationSession, GenerationSession, @@ -52,4 +65,5 @@ __all__ = [ 'EncDecModelRunner', 'MultimodalModelRunner', 'PYTHON_BINDINGS', + 'kv_cache_manager_v2', ] diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/Makefile b/tensorrt_llm/runtime/kv_cache_manager_v2/Makefile new file mode 100644 index 0000000000..853ab20d5d --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/Makefile @@ -0,0 +1,51 @@ +# ################################################################################################## +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ################################################################################################## + +# Makefile for kv_cache_manager_v2 +# Replaces build_mypyc.sh and rawref/build.sh + +PYTHON ?= python3 +RUNTIME_DIR ?= .. + +.PHONY: all mypyc rawref clean clean_mypyc clean_rawref + +# Default target +all: rawref mypyc + +# Build the rawref C extension +rawref: + cd rawref && $(PYTHON) setup.py build_ext --inplace + +# Build the mypyc extension +# Must be run from the parent directory (runtime) as setup_mypyc.py expects +# kv_cache_manager_v2/ prefixes in module names. +mypyc: + cd $(RUNTIME_DIR) && $(PYTHON) kv_cache_manager_v2/setup_mypyc.py build_ext --inplace + +# Clean everything +clean: clean_rawref clean_mypyc + +# Clean rawref build artifacts +clean_rawref: + cd rawref && rm -rf build + find rawref -name "*.so" -type f -delete + +# Clean mypyc build artifacts +# Cleans build/ directory in runtime (parent) and .so files in this directory +clean_mypyc: + rm -rf $(RUNTIME_DIR)/build + find . -name "*.so" -type f ! -path "./rawref/*" -delete diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.py b/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.py new file mode 100644 index 0000000000..883c92be47 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . 
import rawref +from ._block_radix_tree import gen_multi_modal_tokens +from ._common import ( + NDEBUG, + CacheLevel, + CacheTier, + CudaStream, + LayerId, + MemAddress, + Priority, + TokenId, + TokenIdExt, +) +from ._config import ( + AttentionLayerConfig, + BufferConfig, + CacheTierConfig, + DataRole, + DiskCacheTierConfig, + GpuCacheTierConfig, + HostCacheTierConfig, + KVCacheManagerConfig, +) +from ._core import BeamIndex, KVCacheManager, _KVCache +from ._life_cycle_registry import LayerGroupId, LifeCycleId + +__all__ = [ + "LifeCycleId", + "LayerGroupId", + "TokenId", + "TokenIdExt", + "KVCacheManager", + "_KVCache", + "BeamIndex", + "LayerId", + "Priority", + "CacheLevel", + "CacheTier", + "CudaStream", + "MemAddress", + "NDEBUG", + "KVCacheManagerConfig", + "AttentionLayerConfig", + "BufferConfig", + "DataRole", + "DiskCacheTierConfig", + "GpuCacheTierConfig", + "HostCacheTierConfig", + "CacheTierConfig", + "gen_multi_modal_tokens", + "rawref", +] diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi b/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi new file mode 100644 index 0000000000..7183561915 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import array +import enum +from dataclasses import dataclass +from typing import ( + Any, + Callable, + ClassVar, + Final, + Iterable, + Iterator, + NewType, + Protocol, + Sequence, + Type, + TypeAlias, + Union, +) + +# From _common.py +NDEBUG: Final[int] + +class CacheTier(enum.IntEnum): + GPU_MEM = 0 + HOST_MEM = 1 + DISK = 2 + +LifeCycleId = NewType("LifeCycleId", int) +LayerGroupId: TypeAlias = LifeCycleId +CacheLevel = NewType("CacheLevel", int) +TokenId = NewType("TokenId", int) +TokenIdExt = Union[TokenId, bytes] +LayerId = NewType("LayerId", int) +CudaStream = NewType("CudaStream", int) +BeamIndex = NewType("BeamIndex", int) +MemAddress = NewType("MemAddress", int) +Priority = NewType("Priority", int) + +# From _config.py +DataRole = NewType("DataRole", str) + +class CacheTierConfig(Protocol): + quota: int + @property + def tier(self) -> CacheTier: ... + def assert_valid(self) -> None: ... + +@dataclass(slots=True) +class GpuCacheTierConfig: + quota: int + @property + def tier(self) -> CacheTier: ... + def assert_valid(self) -> None: ... + +@dataclass(slots=True) +class HostCacheTierConfig: + quota: int + @property + def tier(self) -> CacheTier: ... + def assert_valid(self) -> None: ... + +@dataclass(slots=True) +class DiskCacheTierConfig: + quota: int + path: str + @property + def tier(self) -> CacheTier: ... + def assert_valid(self) -> None: ... 
+ +@dataclass(slots=True) +class BufferConfig: + role: DataRole + size: int + +@dataclass(slots=True) +class HelixConfig: + helix_group_size: int + helix_gpu_rank: int + helix_shard_size: int + shared_comm_port: int + +@dataclass(slots=True) +class AttentionLayerConfig: + layer_id: LayerId + buffers: list[BufferConfig] + sliding_window_size: int | None = None + num_sink_tokens: int | None = None + @property + def window_size(self) -> int | None: ... + +@dataclass(slots=True) +class KVCacheManagerConfig: + tokens_per_block: int + vocab_size: int + cache_tiers: list[CacheTierConfig] + layers: list[AttentionLayerConfig] + max_util_for_resume: float = ... + helix_config: HelixConfig | None = None + +# From _block_radix_tree.py +def gen_multi_modal_tokens( + id_offset: int, multi_modal_data_digest: bytes, num_tokens: int +) -> list[TokenIdExt]: ... + +# From _core/_kv_cache.py +class _Status(enum.Enum): + ACTIVE = enum.auto() + SUSPENDED = enum.auto() + CLOSED = enum.auto() + +IndexSeq = array.array[int] | memoryview[int] + +class _KVCache: + Status: ClassVar[Type[_Status]] + id: Any + def __init__( + self, + manager: "KVCacheManager", + lora_task_id: int | None, + input_tokens: Sequence[TokenIdExt] | None, + id: Any, + custom_priority_callback: Callable[[int, Any], Priority], + ) -> None: ... + def set_page_index_buf( + self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None + ) -> None: ... + @property + def manager(self) -> "KVCacheManager": ... + @property + def cuda_stream(self) -> CudaStream: ... + @cuda_stream.setter + def cuda_stream(self, cuda_stream: CudaStream) -> None: ... + @property + def finish_event(self) -> Any: ... + @property + def num_blocks(self) -> int: ... + def close(self) -> None: ... + @property + def beam_width(self) -> BeamIndex: ... + @beam_width.setter + def beam_width(self, beam_width: BeamIndex) -> None: ... + def get_page_indices(self, layer_group_id: int, beam_id: BeamIndex = ...) -> IndexSeq: ... + def get_all_page_indices( + self, beam_id: BeamIndex, buf_ids: Iterable[tuple[LayerId, DataRole]] + ) -> Iterator[IndexSeq]: ... + def resize(self, capacity: int | None, history_length: int | None = None) -> bool: ... + @property + def capacity(self) -> int: ... + @capacity.setter + def capacity(self, capacity: int) -> None: ... + @property + def history_length(self) -> int: ... + @history_length.setter + def history_length(self, history_length: int) -> None: ... + def commit( + self, + accepted_input_tokens: Sequence[TokenIdExt], + beam_search_indices: Sequence[int] | None = None, + ) -> None: ... + @property + def num_committed_tokens(self) -> int: ... + def stop_committing(self) -> None: ... + def suspend(self) -> None: ... + def resume(self, cuda_stream: CudaStream | None = None) -> bool: ... + @property + def status(self) -> _Status: ... + @property + def is_active(self) -> bool: ... + @property + def tokens_per_block(self) -> int: ... + +# From _core/_kv_cache_manager.py +class KVCacheManager: + def __init__(self, config: KVCacheManagerConfig) -> None: ... + def clear_reusable_blocks(self) -> None: ... + def get_mem_pool_base_address(self, layer_id: LayerId, data_role: DataRole) -> MemAddress: ... + def get_page_stride(self, layer_id: LayerId, data_role: DataRole) -> int: ... + def get_page_index_upper_bound(self, layer_id: LayerId, data_role: DataRole) -> int: ... 
+ def create_kv_cache( + self, + lora_task_id: int | None = None, + input_tokens: Sequence[TokenIdExt] | None = None, + id: Any = None, + custom_priority_callback: Callable[[int, Any], Priority] = ..., + ) -> _KVCache: ... + def resize(self, cache_level: CacheLevel, quota: int, best_efforts: bool = False) -> bool: ... + def get_quota(self, cache_level: CacheLevel) -> int: ... + @property + def cache_tier_list(self) -> tuple[CacheTier, ...]: ... + @property + def tokens_per_block(self) -> int: ... + @property + def allow_seq_rebasing(self) -> bool: ... + @property + def enable_partial_match(self) -> bool: ... + def get_layer_group_id(self, layer_id: LayerId) -> int: ... + @property + def layer_grouping(self) -> tuple[tuple[LayerId, ...], ...]: ... + def clamp_max_seq_len_for_mem(self, batch_size: int, model_max_seq_len: int) -> int: ... diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_block_radix_tree.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_block_radix_tree.py new file mode 100644 index 0000000000..9b974706ce --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_block_radix_tree.py @@ -0,0 +1,437 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import hashlib +from typing import TYPE_CHECKING, Iterator, Sequence, TypeVar, cast + +from . import rawref +from ._common import NDEBUG, BlockOrdinal, PageStatus, TokenId, TokenIdExt +from ._life_cycle_registry import LifeCycle, LifeCycleId, LifeCycleRegistry +from ._utils import TypedIndexList, chunked, filled_list, unwrap_rawref + +if TYPE_CHECKING: + from ._page import CommittedPage + +BlockKey = bytes + + +# id_offset is usually vocab_size +def gen_multi_modal_tokens( + id_offset: int, multi_modal_data_digest: bytes, num_tokens: int +) -> list[TokenIdExt]: + assert num_tokens > 0 + # Alternatively, we could also use (multi_modal_data_digest + i.to_bytes(8, 'little')) or its hash + # digest as token id. + # The implementation below is faster and also works because KV cache reuse of a token is with a + # precondition that all previous tokens also match. So only the first multi-modal token id needs to + # be unique. + return [ + multi_modal_data_digest if i == 0 else TokenId(id_offset + i) for i in range(num_tokens) + ] + + +class Hasher: + __slots__ = "_hasher" + _hasher: "hashlib._Hash" + + def __init__(self, data: int | bytes | None | Sequence[int | bytes] = None) -> None: + self._hasher = hashlib.sha256() + if data is not None: + self.update(data) + + # This function is perf-critical. Expect compromised code quality. 
+ def update(self, data: int | bytes | Sequence[int | bytes]) -> "Hasher": + if type(data) is int: + assert NDEBUG or (data >= 0 and data < (1 << 64)) + self._hasher.update(data.to_bytes(8, "little")) + elif type(data) is bytes: + self._hasher.update(data) + else: + assert isinstance(data, Sequence) + for item in data: + assert ( + NDEBUG or (type(item) is int and (0 <= item < (1 << 64))) or type(item) is bytes + ) + self._hasher.update(item.to_bytes(8, "little") if isinstance(item, int) else item) + return self + + @property + def digest(self) -> bytes: + return self._hasher.digest() + + +TokenBlock = list[TokenIdExt] + + +def sequence_to_blockchain_keys( + tokens_per_block: int, lora_task_id: int | None, tokens: Sequence[TokenIdExt] +) -> Iterator[tuple[TokenBlock, BlockKey]]: + digest = Hasher(lora_task_id).digest + yield [], digest + for token_block in chunked(tokens, tokens_per_block): + digest = Hasher(digest).update(token_block).digest + yield token_block, digest + + +Child = TypeVar("Child", bound="Block | RootBlock") +Children = dict[BlockKey, Child] + + +def get_tree(block: "RootBlock | Block") -> "BlockRadixTree": + node = block + while not isinstance(node, BlockRadixTree): + node = node.prev + return node + + +def remove_subtree(root: "RootBlock | Block") -> list[rawref.ref["CommittedPage"]]: + # taking O(1) space + # remove leaf blocks one by one, in post-order + ret: list[rawref.ref["CommittedPage"]] = [] + block: "RootBlock | Block" = root + while True: + if block.next: + block = next(iter(block.next.values())) + else: + if isinstance(block, Block): + ret.extend(p for p in block.storage if p is not None) + block.storage = filled_list(None, block.num_life_cycles) + assert isinstance(block, RootBlock) or all(page is None for page in block.storage), ( + "Storage is not cleared, yet" + ) + if block._prev() is None: + assert block is root + break + prev_block: Block | RootBlock | BlockRadixTree = block.prev + # Because Block.__del__() may remove RootBlock from BlockRadixTree, we need to check here. + # It may not be in prev_block.next when block is RootBlock. + if block.key in prev_block.next: + prev_block.next.pop(block.key) + if block is root: + break + assert not isinstance(prev_block, BlockRadixTree) + block = prev_block + return ret + + +def traverse_post_order(root: "Block") -> Iterator["Block"]: + "post-order traversal of the subtree rooted at root" + stack: list[Iterator[Block]] = [] + block: Block | None = root + while True: + assert block is not None + if block.next: + child_iter = iter(block.next.values()) + stack.append(child_iter) + block = next(child_iter) + else: + yield (last_yielded := block) + while stack and (block := next(stack[-1], None)) is None: + yield (last_yielded := cast(Block, last_yielded.prev)) + stack.pop() + if not stack: + break + + +def find_best_partial_match_in_next_nodes( + block: "Block | RootBlock", tokens: TokenBlock +) -> tuple["Block | None", int]: + """ + Among all child nodes (self.next), finds the one whose tokens have the longest leading match with the given tokens. + Returns a tuple of (best_block, num_matched_tokens). + If no child matches any tokens, returns (None, 0). + """ + if len(block.next) >= 32: + # TODO: build a database to accelerate partial matching. (TRTLLM-7784) + # For now, it might be too slow to iterate over all children, so let's just skip. 
+ return None, 0 + best_block = None + best_match_len = 0 + for b in block.next.values(): + match_len = b._partial_match_this_node(tokens) + if match_len > best_match_len: + best_match_len = match_len + best_block = b + return best_block, best_match_len + + +class DuplicateKeyError(Exception): + "Another block with the same key already exists" + + key: BlockKey + + def __init__(self, key: BlockKey) -> None: + super().__init__(f"Block with key {key.hex()} already exists") + self.key = key + + +class UselessBlockError(Exception): + block: "Block" + + def __init__(self, block: "Block") -> None: + super().__init__( + f"Block is useless because all its tokens are covered by another block with key = {block.key.hex()}" + ) + self.block = block + + +def _add_or_get_existing( + parent: "RootBlock | Block", tokens: Sequence[TokenIdExt] +) -> "Block | None": + try: + return Block(tokens, parent) + except DuplicateKeyError as e: + return parent.next[e.key] + except UselessBlockError: + return None + + +class RootBlock: + __slots__ = ("_prev", "key", "next", "lora_task_id", "__rawref__") + key: BlockKey + lora_task_id: int | None + _prev: rawref.ref["BlockRadixTree"] + next: Children["Block"] + __rawref__: rawref.ref["RootBlock"] + + def __init__(self, lora_task_id: int | None, prev: "BlockRadixTree") -> None: + self.key = self.make_key(lora_task_id) + assert self.key not in prev.next, "Root block already exists" + self.lora_task_id = lora_task_id + self._prev = rawref.ref(prev) + self.next = {} + self.__rawref__ = rawref.NULL + prev.next[self.key] = self + + def __del__(self) -> None: + self.__rawref__.invalidate() + + @property + def ordinal(self) -> BlockOrdinal: + return BlockOrdinal(-1) + + @property + def prev(self) -> "BlockRadixTree": + return unwrap_rawref(self._prev) + + @property + def num_life_cycles(self) -> LifeCycleId: + return self.prev.num_life_cycles + + @property + def tokens_per_block(self) -> int: + return self.prev.tokens_per_block + + @staticmethod + def make_key(lora_task_id: int | None) -> BlockKey: + return Hasher(lora_task_id).digest + + +class Block: + """ + A block of tokens. Manages data for all layers. + """ + + __slots__ = ("key", "tokens", "ordinal", "_prev", "next", "storage", "__rawref__") + key: BlockKey + tokens: Sequence[TokenIdExt] + ordinal: BlockOrdinal + _prev: rawref.ref["Block | RootBlock"] + next: Children["Block"] + __rawref__: rawref.ref["Block"] + + # indexed with LifeCycleId + storage: TypedIndexList[LifeCycleId, rawref.ref["CommittedPage"] | None] + + @staticmethod + def make_key(prev_key: BlockKey, tokens: Sequence[TokenIdExt]) -> BlockKey: + return Hasher(prev_key).update(tokens).digest + + def __init__(self, tokens: Sequence[TokenIdExt], prev: "Block | RootBlock") -> None: + assert prev.tokens_per_block == prev.prev.tokens_per_block, "prev must be a full block" + self.key = self.make_key(prev.key, tokens) + self.tokens = tokens + self.ordinal = BlockOrdinal(prev.ordinal + 1) + self._prev = rawref.ref(prev) + self.next = {} + self.storage = filled_list(None, prev.num_life_cycles) + self.__rawref__ = rawref.NULL + # a Block is useless if all its tokens are covered by a sibling block. Raise UselessBlockError if so. + if self.key in prev.next: + raise UselessBlockError(prev.next[self.key]) + if len(tokens) < self.tokens_per_block: + # @TODO: when we have the database for find_best_partial_match_in_next_nodes, we may use + # that for faster check. 
+ for b in prev.next.values(): + if b.tokens[: len(tokens)] == tokens: + raise UselessBlockError(b) + # If there are sibling blocks fully covered by this block, remove them. + to_remove = [] + for k, b in prev.next.items(): + if len(b.tokens) < len(tokens) and tokens[: len(b.tokens)] == b.tokens: + assert NDEBUG or (not b.is_full and b is not self and b.key == k and not b.next) + to_remove.append(k) + for k in to_remove: + b = prev.next.pop(k) + assert b.is_orphan # _KVCache may still hold it. + # prev.next keeps a strong ref to this _Block, so no need to remove self from prev.next in __del__(). + prev.next[self.key] = self + + def __del__(self) -> None: + for ref in self.storage: + if ref is not None and ref() is not None: + page = unwrap_rawref(ref) + if page.status == PageStatus.DROPPABLE: + if page.scheduled_for_eviction: + page.manager.exclude_from_eviction(page) + if self._prev() is not None and isinstance(self.prev, RootBlock) and not self.prev.next: + self.prev.prev.next.pop(self.prev.key) + self.__rawref__.invalidate() + + def _partial_match_this_node(self, tokens: TokenBlock) -> int: + """ + Returns the number of leading tokens that match between the given tokens and this block's tokens. + """ + for i, (a, b) in enumerate(zip(tokens, self.tokens)): + if a != b: + return i + return min(len(tokens), len(self.tokens)) + + @property + def num_life_cycles(self) -> LifeCycleId: + return LifeCycleId(len(self.storage)) + + @property + def prev(self) -> "Block | RootBlock": + return unwrap_rawref(self._prev) + + def unset_page(self, lc_idx: LifeCycleId, lc: LifeCycle) -> None: + if self.storage[lc_idx] is None: + return + ordinal = self.ordinal + self.storage[lc_idx] = None + if lc.window_size is None or ordinal < lc.num_sink_blocks: + pages = remove_subtree(self) + for r in pages: + if r() is not None: + page = unwrap_rawref(r) + assert page.status == PageStatus.DROPPABLE + if page.scheduled_for_eviction: + page.manager.exclude_from_eviction(page) + # It's possible to implement more sophisticated logic to remove useless blocks for SWA, e.g. + # check if consecutive available blocks is sufficient for window_size. (TRTLLM-8802) + # But for simplicity, we leave it for now. + curr = self + while ( + (isinstance(curr, Block) and curr.storage[lc_idx] is None) + and not curr.next + and curr._prev() is not None + ): + if curr.key in curr.prev.next: + curr.prev.next.pop(curr.key) + curr = curr.prev + + @property + def tokens_per_block(self) -> int: + # we assume non-leaf blocks are always full. 
+ prev = self.prev + return prev.tokens_per_block if isinstance(prev, RootBlock) else len(prev.tokens) + + @property + def is_full(self) -> bool: + return len(self.tokens) == self.tokens_per_block + + @property + def is_orphan(self) -> bool: + return self.key not in self.prev.next or self.prev.next[self.key] is not self + + +class BlockRadixTree: + __slots__ = ("_life_cycles", "_tokens_per_block", "next", "__rawref__") + _life_cycles: LifeCycleRegistry + _tokens_per_block: int + next: Children[RootBlock] + __rawref__: rawref.ref["BlockRadixTree"] + + def __init__(self, life_cycles: LifeCycleRegistry, tokens_per_block: int) -> None: + self._life_cycles = life_cycles + self._tokens_per_block = tokens_per_block + self.next = {} + self.__rawref__ = rawref.NULL + + def __del__(self) -> None: + self.__rawref__.invalidate() + + def add_or_get_existing(self, lora_task_id: int | None) -> RootBlock: + key = RootBlock.make_key(lora_task_id) + if key in self.next: + return self.next[key] + return RootBlock(lora_task_id, self) + + @property + def tokens_per_block(self) -> int: + return self._tokens_per_block + + @property + def life_cycles(self) -> LifeCycleRegistry: + return self._life_cycles + + @property + def num_life_cycles(self) -> LifeCycleId: + return self.life_cycles.size + + def clear(self) -> list[rawref.ref["CommittedPage"]]: + # taking O(1) space + # remove leaf blocks one by one, in post-order + ret: list[rawref.ref["CommittedPage"]] = [] + while self.next: + block = next(iter(self.next.values())) + ret.extend(remove_subtree(block)) + assert not self.next + return ret + + # yields tuples of (block, num_matched_tokens). num_matched_tokens should be equal to + # tokens_per_block except the last one. + def match( + self, + lora_task_id: int | None, + tokens: Sequence[TokenIdExt], + enable_partial_match: bool = False, + ) -> Iterator[tuple[Block, int]]: + block: Block | RootBlock | BlockRadixTree = self + mismatched_token_block: TokenBlock = [] + for token_block, key in sequence_to_blockchain_keys( + self._tokens_per_block, lora_task_id, tokens + ): + if key in block.next: + block = block.next[key] + if token_block: + assert isinstance(block, Block) + yield block, len(token_block) + else: + mismatched_token_block = token_block + break + if mismatched_token_block and enable_partial_match: + partial_block, match_len = find_best_partial_match_in_next_nodes( + cast(Block | RootBlock, block), mismatched_token_block + ) + if partial_block is not None: + block = partial_block + yield block, match_len + + def _check_sanity(self) -> bool: + raise NotImplementedError( + "[KVCacheManager] Check if there are any unusable blocks that should have been removed." + ) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_common.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_common.py new file mode 100644 index 0000000000..320d8e4a54 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_common.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import os +from dataclasses import dataclass +from typing import Final, NewType + +NDEBUG: Final[int] = int(os.environ.get("TLLM_KV_CACHE_MANAGER_V2_DEBUG", "0")) == 0 + + +class PageStatus(enum.IntEnum): + LOCKED = 0 # Required in GPU. Eviction/dropping not allowed + HELD = 1 # Allow eviction but not dropping + DROPPABLE = 2 # Allow eviction and dropping + + +# Can extend to more tiers in the future, e.g. object storage like AWS S3. +class CacheTier(enum.IntEnum): + GPU_MEM = 0 + HOST_MEM = 1 + DISK = 2 + + +CacheLevel = NewType("CacheLevel", int) + +GPU_LEVEL: Final[CacheLevel] = CacheLevel(0) + +# Normal token id that falls in the tokenizer vocabulary. +TokenId = NewType("TokenId", int) + +# For multi-modal tokens, we can handle it in either of the following ways: +# 1. Hash combine image digest and local_token_id, then use digest for every multi-modal token. +# 2. Use digest only for the first multi-modal token, and use int(vocab_size + local_token_id) for the rest. +# 3. Hash the multi-modal token embedding data and use the digest as TokenIdExt for every multi-modal token. +# If we do this, we can't skip the encoder. +TokenIdExt = TokenId | bytes + +BlockOrdinal = NewType("BlockOrdinal", int) +BlockOrdinalT = type(BlockOrdinal(0)) + +LayerId = NewType("LayerId", int) + +CudaStream = NewType("CudaStream", int) + +BeamIndex = NewType("BeamIndex", int) + +UserId = NewType("UserId", int) + +MemAddress = NewType("MemAddress", int) + +FileDescriptor = NewType("FileDescriptor", int) + +BAD_FILE_DESCRIPTOR: Final[FileDescriptor] = FileDescriptor(-1) + +PageIndex = NewType("PageIndex", int) +BAD_PAGE_INDEX: Final[PageIndex] = PageIndex(-1) + + +@dataclass(slots=True, frozen=True) +class DiskAddress: + fd: FileDescriptor + pos: int + + +Address = MemAddress | DiskAddress + +SlidingWindowSize = int | None + +Priority = NewType("Priority", int) +PRIORITY_MIN: Final[Priority] = Priority(0) +PRIORITY_MAX: Final[Priority] = Priority(100) +PRIORITY_DEFAULT: Final[Priority] = Priority(35) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_config.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_config.py new file mode 100644 index 0000000000..02716cfe58 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_config.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Currently, our nvfp4 kernels require that KV data and its corresponding KV block scale use the same +# block index, but different base address. +# As the ratio between KV data size and KV block scale size is fixed, we can simply use a pool with +# smaller block size and the same number of blocks for block scale. 
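# NOTE: illustrative sketch only, not part of this patch; buffer sizes and the exact role set are
# hypothetical. In practice the note above means a layer declares one BufferConfig per data role,
# and buffers of different sizes are grouped into different pools that still share page indices:
#
#     buffers = [
#         BufferConfig(role=DataRole("key"), size=kv_page_bytes),
#         BufferConfig(role=DataRole("value"), size=kv_page_bytes),
#         BufferConfig(role=DataRole("key_block_quant"), size=scale_page_bytes),
#         BufferConfig(role=DataRole("value_block_quant"), size=scale_page_bytes),
#     ]
#
# where scale_page_bytes is smaller than kv_page_bytes by the fixed data-to-scale ratio. The role
# names follow the DataRole examples below.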
+import os +from dataclasses import dataclass, field +from typing import NewType, Protocol + +from ._common import CacheTier, LayerId + +# The data role of a buffer inside one layer. +# Must be unique for each buffer inside a layer. +# Examples: "key", "value", "key_block_quant", "value_block_quant". +DataRole = NewType("DataRole", str) + + +class CacheTierConfig(Protocol): + """Protocol for cache tier configuration.""" + + quota: int # in bytes + + @property + def tier(self) -> CacheTier: ... + + def assert_valid(self) -> None: ... + + +@dataclass(slots=True) +class GpuCacheTierConfig: + quota: int # in bytes + + @property + def tier(self) -> CacheTier: + return CacheTier.GPU_MEM + + def assert_valid(self) -> None: + assert self.quota > 0, "Quota must be positive" + + +@dataclass(slots=True) +class HostCacheTierConfig: + quota: int # in bytes + + @property + def tier(self) -> CacheTier: + return CacheTier.HOST_MEM + + def assert_valid(self) -> None: + assert self.quota > 0, "Quota must be positive" + + +@dataclass(slots=True) +class DiskCacheTierConfig: + quota: int # in bytes + path: str # a folder where we will store data as files + + @property + def tier(self) -> CacheTier: + return CacheTier.DISK + + def assert_valid(self) -> None: + assert self.quota > 0, "Quota must be positive" + assert os.path.isdir(self.path), ( + f"Disk path {self.path} does not exist or is not a directory" + ) + + +@dataclass(slots=True) +class BufferConfig: + role: DataRole + size: int + + +@dataclass(slots=True) +class AttentionLayerConfig: + layer_id: LayerId + # Each page can have multiple sub-pages, e.g. separate K and V data, block quantization scales for K and/or V, etc. + # KV cache manager will automatically group sub-pages of the same size, and redirect pages of different sizes to + # different memory pools + + # BufferConfig.role should not duplicate + buffers: list[BufferConfig] + # Note that we use None to represent "no sliding window". Sink tokens are excluded. + sliding_window_size: int | None = None + num_sink_tokens: int | None = None + + @property + def window_size(self) -> int | None: + return self.sliding_window_size + + def __post_init__(self) -> None: + assert len(set(buffer.role for buffer in self.buffers)) == len(self.buffers), ( + "duplicate buffer role" + ) + + +@dataclass(slots=True) +class HelixConfig: + helix_group_size: int + helix_gpu_rank: int + # number of tokens in one helix shard + helix_shard_size: int + # must be the same for all ranks in the same helix group and different for different helix groups. + shared_comm_port: int + + +@dataclass(slots=True) +class KVCacheManagerConfig: + """ + Configuration for the KV cache manager. + """ + + tokens_per_block: int + # if you have p-tuning tokens, include them. Only needed for multi-modal. + vocab_size: int + # cache tiers are sorted from warm to cold. The first one must be GPU memory. + cache_tiers: list[CacheTierConfig] + + # AttentionLayerConfig.layer_id should not duplicate + layers: list[AttentionLayerConfig] + + # When memory utilization is above this threshold, KV cache resuming will fail. This helps + # reserving some memory for KVCache growth and avoids frequent suspend/resume for dynamic batch size. 
+ max_util_for_resume: float = field(default=0.9) + + # unsupported yet + helix_config: HelixConfig | None = field(default=None) + + def __post_init__(self) -> None: + assert self.cache_tiers and self.cache_tiers[0].tier == CacheTier.GPU_MEM + assert len(set(layer.layer_id for layer in self.layers)) == len(self.layers), ( + "duplicate layer id" + ) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_copy_engine.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_copy_engine.py new file mode 100644 index 0000000000..d3a7ef97c9 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_copy_engine.py @@ -0,0 +1,374 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import sys +import threading +from _thread import LockType +from collections.abc import Callable, Iterator +from dataclasses import dataclass + +# avoid importing the whole tensorrt_llm module, which takes time during debugging. +from importlib.util import find_spec +from pathlib import Path +from typing import ClassVar, NamedTuple, Sequence, cast + +import cuda.bindings.driver as drv +from dynamic_path_manager import DynamicPathManager + +from ._common import Address, CacheTier, CudaStream, MemAddress +from ._utils import CachedCudaEvent, HomoTuple, HostMem, _unwrap, div_up, stream_wait_events + +if "tensorrt_llm" in sys.modules: + from tensorrt_llm.bindings.internal.batch_manager.kv_cache_manager_v2_utils import ( # noqa # type: ignore + DiskAddress, + DiskToDiskTask, + DiskToHostTask, + HostToDiskTask, + MemToMemTask, + copy_device_to_device, + copy_device_to_host, + copy_disk_to_disk, + copy_disk_to_host, + copy_host_to_device, + copy_host_to_disk, + copy_host_to_host, + ) +else: + # fast path for dev, avoids importing the whole tensorrt_llm module + spec = find_spec("kv_cache_manager_v2") + assert spec is not None and spec.origin is not None + with DynamicPathManager(str(Path(spec.origin).parent.parent.parent), clear_cache=False): + from bindings.internal.batch_manager.kv_cache_manager_v2_utils import ( # noqa + DiskAddress, + DiskToDiskTask, + DiskToHostTask, + HostToDiskTask, + MemToMemTask, + copy_device_to_device, + copy_device_to_host, + copy_disk_to_disk, + copy_disk_to_host, + copy_host_to_device, + copy_host_to_disk, + copy_host_to_host, + ) + + +class CopyTask(NamedTuple): + dst: Address + src: Address + + +def _copy_gpu_to_gpu(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream): + _unwrap( + drv.CUresult( + copy_device_to_device([MemToMemTask(dst, src) for dst, src in tasks], num_bytes, stream) + ) + ) + + +def _copy_host_to_host(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream): + _unwrap( + drv.CUresult( + copy_host_to_host([MemToMemTask(dst, src) for dst, src in tasks], num_bytes, stream) + ) + ) + + +def _copy_disk_to_disk(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream): + _unwrap( + drv.CUresult( + copy_disk_to_disk( + [ 
+ DiskToDiskTask( + DiskAddress( + cast(DiskAddress, dst).fd, + cast(DiskAddress, dst).pos, + ), + DiskAddress( + cast(DiskAddress, src).fd, + cast(DiskAddress, src).pos, + ), + ) + for dst, src in tasks + ], + num_bytes, + stream, + ) + ) + ) + + +def _copy_gpu_to_host(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream): + _unwrap( + drv.CUresult( + copy_device_to_host([MemToMemTask(dst, src) for dst, src in tasks], num_bytes, stream) + ) + ) + + +def _copy_host_to_gpu(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream): + _unwrap( + drv.CUresult( + copy_host_to_device([MemToMemTask(dst, src) for dst, src in tasks], num_bytes, stream) + ) + ) + + +def _copy_disk_to_host(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream): + _unwrap( + drv.CUresult( + copy_disk_to_host( + [ + DiskToHostTask( + cast(MemAddress, dst), + DiskAddress(cast(DiskAddress, src).fd, cast(DiskAddress, src).pos), + ) + for dst, src in tasks + ], + num_bytes, + stream, + ) + ) + ) + + +def _copy_host_to_disk(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream): + _unwrap( + drv.CUresult( + copy_host_to_disk( + [ + HostToDiskTask( + DiskAddress( + cast(DiskAddress, dst).fd, + cast(DiskAddress, dst).pos, + ), + cast(MemAddress, src), + ) + for dst, src in tasks + ], + num_bytes, + stream, + ) + ) + ) + + +Copier = Callable[[Sequence[CopyTask], int, CudaStream], None] + + +def get_copier(dst: CacheTier, src: CacheTier) -> Copier | HomoTuple[Copier]: + copiers: HomoTuple[HomoTuple[Copier | HomoTuple[Copier]]] = ( + # dst = GPU_MEM + ( + _copy_gpu_to_gpu, # src = GPU_MEM + _copy_host_to_gpu, # src = HOST_MEM + (_copy_disk_to_host, _copy_host_to_gpu), # src = DISK + ), + # dst = HOST_MEM + ( + _copy_gpu_to_host, # src = GPU_MEM + _copy_host_to_host, # src = HOST_MEM + _copy_disk_to_host, # src = DISK + ), + # dst = DISK + ( + (_copy_gpu_to_host, _copy_host_to_disk), # src = GPU_MEM + _copy_host_to_disk, # src = HOST_MEM + _copy_disk_to_disk, # src = DISK + ), + ) + return copiers[dst][src] + + +@dataclass(slots=True) +class GrainMetadata: + mutex: LockType + ready_event: CachedCudaEvent # protects the buffer grain. 
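# NOTE: illustrative sketch only, not part of this patch. get_copier() above returns either a
# single copier or a 2-tuple of copiers; a tuple means there is no direct path and the transfer
# must be staged through host memory, which CopyEngine.transfer() below implements using the
# staging-buffer machinery. For example:
#
#     get_copier(CacheTier.HOST_MEM, CacheTier.GPU_MEM)  # -> _copy_gpu_to_host (direct)
#     get_copier(CacheTier.GPU_MEM, CacheTier.DISK)      # -> (_copy_disk_to_host, _copy_host_to_gpu)
#     get_copier(CacheTier.DISK, CacheTier.GPU_MEM)      # -> (_copy_gpu_to_host, _copy_host_to_disk)
#
# i.e. disk<->GPU copies hop through a StagingBuffer, batching as many tasks per round as fit in
# the staging buffer.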
+ + +class StagingBuffer: + __slots__ = ("manager", "min_size", "max_size", "_size", "start_grain", "stream") + manager: "StagingBufferManager" + min_size: int + max_size: int + _size: int + start_grain: int + stream: CudaStream + + def __init__( + self, manager: "StagingBufferManager", min_size: int, max_size: int, stream: CudaStream + ): + self.manager = manager + self.min_size = min_size + self.max_size = max_size + self.stream = stream + + @property + def address(self) -> MemAddress: + return MemAddress(self.manager.buffer.address + self.start_grain * self.manager.GRANULARITY) + + @property + def size(self) -> int: + return self._size + + @property + def num_grains(self) -> int: + return div_up(self._size, self.manager.GRANULARITY) + + @property + def grains(self) -> list[GrainMetadata]: + return self.manager.grains[self.start_grain : self.start_grain + self.num_grains] + + def __enter__(self) -> "StagingBuffer": + manager = self.manager + if self.min_size > manager.size: + raise ValueError(f"Requested min_size {self.min_size} is too large for the manager") + with manager.mutex: + self._size = min(self.max_size, manager._suggest_next_max_size_unsafe()) + self.start_grain = manager.next + manager.next += self.num_grains + assert manager.next <= manager.num_grains + if manager.next == manager.num_grains: + manager.next = 0 + + def lock_and_consume_events() -> Iterator[CachedCudaEvent]: + for grain in self.grains: + grain.mutex.acquire() + yield grain.ready_event + grain.ready_event = CachedCudaEvent.NULL + + stream_wait_events(self.stream, lock_and_consume_events()) + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + event = CachedCudaEvent(self.stream) + for grain in reversed(self.grains): + grain.ready_event = event + grain.mutex.release() + + +class StagingBufferManager: + __slots__ = ("mutex", "buffer", "grains", "next") + GRANULARITY: ClassVar[int] = 1 << 20 + + mutex: LockType + buffer: HostMem + grains: list[GrainMetadata] + next: int + + def __init__(self, size: int) -> None: + assert size % self.GRANULARITY == 0 + self.mutex = threading.Lock() + num_grains = size // self.GRANULARITY + self.buffer = HostMem(size) + self.grains = [ + GrainMetadata(threading.Lock(), CachedCudaEvent.NULL) for _ in range(num_grains) + ] + self.next = 0 + + @property + def size(self) -> int: + "Requesting more than this will fail." + assert len(self.grains) * self.GRANULARITY == self.buffer.size + return self.buffer.size + + @property + def num_grains(self) -> int: + return len(self.grains) + + def _suggest_next_max_size_unsafe(self) -> int: + "Requesting more than this may degrade performance. Must be called with self.mutex held." + return self.GRANULARITY * (self.num_grains - self.next) + + # max_size is just a hint, the actual size may be smaller. + def new(self, min_size: int, max_size: int, stream: CudaStream) -> StagingBuffer: + """ + min_size is the min required size. max_size is for best efforts. Your should query the actual + size after entering the context. 
+ """ + return StagingBuffer(self, min_size, max_size, stream) + + +class CopyEngine: + __slots__ = ("_staging_buffer_manager",) + _staging_buffer_manager: StagingBufferManager | None + + def __init__(self) -> None: + self._staging_buffer_manager = None + + def close(self) -> None: + self._staging_buffer_manager = None + + @property + def staging_buffer_manager(self) -> StagingBufferManager: + if self._staging_buffer_manager is None: + self._staging_buffer_manager = StagingBufferManager(64 << 20) + return self._staging_buffer_manager + + # @TODO: Use a dedicated stream for each different Copier, take set[CachedCudaEvent] instead of + # stream, and return a new CachedCudaEvent. + def transfer( + self, + dst_cache_tier: CacheTier, + src_cache_tier: CacheTier, + num_bytes: int, + tasks: Sequence[CopyTask], + stream: CudaStream, + ) -> None: + copier = get_copier(dst_cache_tier, src_cache_tier) + if not isinstance(copier, tuple): + return copier(tasks, num_bytes, stream) + assert len(copier) == 2, "for now, we only support 2 copiers via host memory" + manager = self.staging_buffer_manager + remaining = tasks + while remaining: + with manager.new(num_bytes, num_bytes * len(remaining), stream) as buf: + addr = buf.address + n = buf.size // num_bytes + assert n <= len(remaining) + batch = remaining[:n] + copier[0]( + [ + CopyTask(MemAddress(addr + num_bytes * i), t.src) + for i, t in enumerate(batch) + ], + num_bytes, + buf.stream, + ) + copier[1]( + [ + CopyTask(t.dst, MemAddress(addr + num_bytes * i)) + for i, t in enumerate(batch) + ], + num_bytes, + buf.stream, + ) + remaining = remaining[n:] + + +_copy_engine = CopyEngine() +atexit.register(_copy_engine.close) + + +def batched_copy( + dst_cache_tier: CacheTier, + src_cache_tier: CacheTier, + num_bytes: int, + tasks: Sequence[CopyTask], + stream: CudaStream, +) -> None: + _copy_engine.transfer(dst_cache_tier, src_cache_tier, num_bytes, tasks, stream) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_core/__init__.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/__init__.py new file mode 100644 index 0000000000..51ce18d332 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .._common import BeamIndex +from ._kv_cache import _KVCache +from ._kv_cache_manager import KVCacheManager + +__all__ = ["KVCacheManager", "_KVCache", "BeamIndex"] diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py new file mode 100644 index 0000000000..42aee96b42 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py @@ -0,0 +1,1048 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import array +import enum +from collections.abc import Sequence +from contextlib import contextmanager +from dataclasses import dataclass +from itertools import chain +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Iterator, Type, cast + +from .. import rawref +from .._block_radix_tree import Block, RootBlock, UselessBlockError +from .._common import ( + BAD_PAGE_INDEX, + GPU_LEVEL, + NDEBUG, + BeamIndex, + BlockOrdinal, + BlockOrdinalT, + CacheLevel, + CudaStream, + PageIndex, + Priority, + SlidingWindowSize, + TokenIdExt, +) +from .._copy_engine import CopyTask, batched_copy +from .._exceptions import LogicError, OutOfPagesError +from .._life_cycle_registry import LayerGroupId, LifeCycle, LifeCycleId +from .._page import ( + BatchedLockTarget, + BlockPage, + CommittedPage, + UncommittedPage, + _PageHolder, + _SharedPageLock, + batched_lock_to_gpu, +) +from .._storage._config import BufferId +from .._storage_manager import StorageManager +from .._utils import ( + CachedCudaEvent, + TemporaryCudaStream, + TypedIndexList, + div_up, + expect_type, + filled_list, + find_index, + make_typed, + map_optional, + stream_wait_events, + to_typed, + typed_enumerate, + typed_len, + typed_map, + typed_range, + unwrap_optional, + unwrap_rawref, + value_or, +) + +if TYPE_CHECKING: + from ._kv_cache_manager import KVCacheManager + + +@dataclass(slots=True) +class SeqBlock: + pages: TypedIndexList[BeamIndex, TypedIndexList[LifeCycleId, BlockPage]] + # In rare cases, this may be the only strong reference to this block. Assume it's the last block we + # committed on stop_committing(), and it's partial. At the same time, we have another _KVCache + # generating same tokens plus some additional tokens. The block committed by the other _KVCache will + # fully cover tokens of this block. In that case, we will remove this block from the radix tree. + # Which means `tree_block not in tree_block.prev.next` will be True. + tree_block: Block | None + + @property + def is_committed(self) -> bool: + ret = self.tree_block is not None + assert NDEBUG or not ret or len(self.pages) == 1 + assert ( + NDEBUG + or not ret + or all( + p is None or isinstance(p.page, CommittedPage) + for p in chain.from_iterable(self.pages) + ) + ) + assert ( + NDEBUG + or ret + or all( + p is None or isinstance(p.page, UncommittedPage) + for p in chain.from_iterable(self.pages) + ) + ) + return ret + + def __del__(self) -> None: + self.tree_block = None + self.pages.clear() + + +class _Status(enum.Enum): + ACTIVE = enum.auto() + SUSPENDED = enum.auto() + CLOSED = enum.auto() + + +class _CommitState(enum.Enum): + ALLOWED = enum.auto() + # user did not stop but we can't commit any more due to conflict with other blocks + VIRTUAL_STOP = enum.auto() + # user called stop_committing() or close() + USER_STOP = enum.auto() + + +IndexSeq = array.array | memoryview + + +# The _KVCache holds unique/shared ownership of memory blocks. 
On deletion, the ownership is destroyed
+# and KVCacheManager takes control of them. A KV cache maintains three lengths:
+# 1. num_committed_tokens: the number of tokens that are finalized, immutable and ready for reuse.
+# 2. history_length: a cursor separating history and the space for next input tokens. History tokens
+#    are defined as tokens without query data for the next inference step. For SWA layers, it decides
+#    which blocks are out-of-window and can be evicted/dropped. In most cases, you don't need to touch
+#    history_length as it's automatically bumped by the increase of num_committed_tokens, except a few
+#    cases:
+#    a. Beam search where we can't commit tokens generated by the last step. But it still makes sense
+#       to evict uncommitted pages for SWA layers to save memory.
+#    b. Disaggregated serving with SWA and the reusable tokens are in the other server. We need to
+#       reserve space for history. Knowing history_length helps us accurately decide which blocks
+#       need to be allocated. Then users only transfer data for what is needed.
+#    c. Multi-round conversation with chain of thoughts (CoT) and excluding CoT tokens for the next
+#       round. In this case, users should not commit tokens starting from CoT. Then history_length
+#       needs to be explicitly bumped.
+# 3. capacity: the number of tokens that can be stored in the KV cache. It should include the number
+#    of both historical tokens and input tokens for the next inference step, no matter if it's prefill,
+#    chunked prefill or generation with/without speculative decoding. For tree-based speculative
+#    decoding, the number of input tokens here should be the flattened draft length. For beam search,
+#    multiple candidate tokens at the same position are counted as one.
+# num_committed_tokens <= history_length <= capacity always holds. A newly created KV cache has all
+# three lengths equal to the number of reused tokens.
+# TODO: in __del__, we should check if committed pages are usable for SWA cases, e.g. all pages are
+# dropped except the last one. The last one is not usable.
+class _KVCache:
+    __slots__ = (
+        "id",
+        "_manager",
+        "_lora_task_id",
+        "_get_priority",
+        "_cuda_stream",
+        "_status",
+        "_beam_width",
+        "_capacity",
+        "_history_length",
+        "_commit_state",
+        "_blocks",
+        "_page_indices",
+        "_committed_tokens",
+        "_num_committed_blocks",
+        "_finish_event",
+        "_tokens_per_block",
+        "__rawref__",
+    )
+
+    Status: ClassVar[Type[_Status]] = _Status
+    CommitState: ClassVar[Type[_CommitState]] = _CommitState
+
+    id: Any
+    _manager: "KVCacheManager"
+    _lora_task_id: int | None
+    _get_priority: Callable[[BlockOrdinal, LifeCycle], Priority]
+    _cuda_stream: CudaStream | None
+    _status: _Status
+    _beam_width: BeamIndex
+    _capacity: int
+    _history_length: int
+    _commit_state: _CommitState
+
+    _blocks: TypedIndexList[BlockOrdinal, SeqBlock]
+    # We maintain _page_indices to accelerate the get_page_indices() API. In principle it can be
+    # computed on the fly, but that would be slow due to Python.
+    _page_indices: TypedIndexList[BeamIndex, TypedIndexList[LifeCycleId, IndexSeq]]
+    _committed_tokens: list[TokenIdExt]
+    # Sometimes we can't commit a block because all its tokens are already covered by another block in
+    # the radix tree. But it's unsafe to just use the other block because: 1. the data may have numeric
+    # differences, 2. if our block is a partial block, we can't write to memory of the other blocks.
+    # Internally, we stop committing from such a block, but still give the user the illusion that the block is
+    # committed.
In such cases, _committed_tokens contains what users have fed with commit(), while + # _num_committed_blocks contains the number of blocks that are actually committed. + _num_committed_blocks: BlockOrdinal + # set when switch away from ACTIVE, cleared when switching to ACTIVE. + _finish_event: CachedCudaEvent | None + + _tokens_per_block: int + + def __init__( + self, + manager: "KVCacheManager", + lora_task_id: int | None, + input_tokens: Sequence[TokenIdExt] | None, + id: Any, + custom_priority_callback: Callable[[BlockOrdinal, LifeCycle], Priority], + ): + self.id = id + self._manager = manager + self._lora_task_id = lora_task_id + self._get_priority = custom_priority_callback + self._cuda_stream = None + self._status = self.Status.SUSPENDED + self._beam_width = BeamIndex(1) + self._capacity = 0 + self._history_length = 0 + self._commit_state = self.CommitState.ALLOWED + self._blocks = cast(TypedIndexList, []) + self._page_indices = make_typed( + lambda: make_typed(lambda: array.array("i"), self.manager._storage.num_life_cycles), + self.beam_width, + ) + self._committed_tokens = [] + self._num_committed_blocks = BlockOrdinal(0) + self._finish_event = None + self._tokens_per_block = manager.tokens_per_block + self.__rawref__ = rawref.NULL + if input_tokens is not None: + self._setup_for_reuse(input_tokens) + assert NDEBUG or self._check_sanity() + + def set_page_index_buf( + self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None + ) -> None: + """Set the buffer for page indices, so we directly update indices in user buffer to + avoid user-side copy. This is the zero-copy alternative of get_page_indices()""" + length = self.num_blocks + old_indices = self._page_indices[beam_idx][layer_group_id] + new_indices: IndexSeq + if buf is None: + new_indices = array.array("i", old_indices[:length]) + else: + assert buf.ndim == 1 and buf.format == "i" and len(buf) >= length + buf[:length] = old_indices[:length] + buf[length:] = array.array("i", [BAD_PAGE_INDEX]) * (len(buf) - length) + new_indices = buf + self._page_indices[beam_idx][layer_group_id] = new_indices + + @property + def manager(self) -> "KVCacheManager": + return self._manager + + @property + def cuda_stream(self) -> CudaStream: + return unwrap_optional(self._cuda_stream) + + @cuda_stream.setter + def cuda_stream(self, cuda_stream: CudaStream) -> None: + if self._cuda_stream is not None: + if self.is_active: + CachedCudaEvent(self._cuda_stream).wait_in_stream(cuda_stream) + else: + assert self.status == self.Status.SUSPENDED and self._finish_event is None + self._cuda_stream = cuda_stream + + @property + def finish_event(self) -> CachedCudaEvent: + "Event recorded when switching from active to suspended/closed state. Unavailable when active." + return unwrap_optional(self._finish_event) + + @property + def num_blocks(self) -> int: + return len(self._blocks) + + # destroy ownership of memory blocks, so KV cache manager can decide to evict or drop them. After + # close, uncommitted data in blocks for (beam_index >= beam_width) will be lost. + def close(self) -> None: + assert NDEBUG or self._check_sanity() + if self.status == self.Status.CLOSED: + return + self.stop_committing() + assert NDEBUG or self._check_sanity() + with self._record_event(): + self._clear_blocks() + self._status = self.Status.CLOSED + + def __del__(self) -> None: + self.close() + self.__rawref__.invalidate() + + @property + def beam_width(self) -> BeamIndex: + return self._beam_width + + # beam_width > 1 is only for generation. 
If decreasing beam_width, uncommitted data in blocks for
+    # (beam_index >= beam_width) will be lost.
+    @beam_width.setter
+    def beam_width(self, beam_width: BeamIndex) -> None:
+        raise NotImplementedError("Not implemented yet for beam search")
+
+    # Get the indices of memory blocks for each beam.
+    # Due to constraints of the current kernels, K/V data blocks and the corresponding quant scale blocks
+    # share the same indices, so the output for DataRole.KEY_DATA and DataRole.KEY_BLOCK_SCALE is the same.
+    def get_page_indices(
+        self, layer_group_id: LayerGroupId, beam_id: BeamIndex = BeamIndex(0)
+    ) -> IndexSeq:
+        indices = self._page_indices[beam_id][layer_group_id]
+        assert NDEBUG or all(
+            v == value_or(r, BAD_PAGE_INDEX)
+            for v, r in zip(indices, self._get_page_indices_ref(layer_group_id, beam_id))
+        )
+        return indices
+
+    def get_all_page_indices(
+        self, beam_id: BeamIndex, buf_ids: Iterable[BufferId]
+    ) -> Iterator[IndexSeq]:
+        layer_to_lc_ids = self.manager._storage._layer_to_life_cycle_ids
+        for layer_id, _ in buf_ids:
+            lc = layer_to_lc_ids[layer_id]
+            yield self._page_indices[beam_id][lc]
+
+    # Reserve space for the next inference. Request new blocks from KVCacheManager if necessary.
+    # If capacity is increased and beam_width > 1, blocks containing new tokens should be allocated for each beam.
+    # Decrease of capacity may destroy stale blocks (if not used by other requests).
+    # Decrease of capacity cannot remove historical or committed tokens.
+    # History length cannot be decreased.
+    # Increase of history length may trigger out-of-window block eviction/dropping for SWA layers.
+    # If we use two separate APIs for capacity and history length, sometimes we will need to increase
+    # capacity first to maintain capacity >= history_length. But then we may have a middle state (between
+    # two APIs) where we use more pages than necessary for SWA layers. So we use a single API to avoid
+    # this. Usually this is a concern only for the prefill phase where we create many tokens in one step. For
+    # other cases, we can just set the capacity and history_length properties instead.
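    # NOTE: illustrative sketch only, not part of this patch; names below are hypothetical. A
    # typical chunked-prefill / decode step under the semantics described above could look like:
    #
    #     ok = kv.resize(capacity=kv.history_length + num_new_input_tokens, history_length=None)
    #     if not ok:
    #         kv.suspend()            # out of GPU pages; let the scheduler retry this request later
    #     else:
    #         ...run one engine step...
    #         kv.commit(accepted_tokens)   # full blocks become reusable; history_length is bumped
    #
    # resize() returns False instead of raising when GPU pages run out, so schedulers can back off.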
+ def resize(self, capacity: int | None, history_length: int | None = None) -> bool: + assert self.status == self.Status.ACTIVE + tokens_per_block = self.tokens_per_block + assert div_up(self._capacity, tokens_per_block) == len(self._blocks) + capacity = value_or(capacity, self._capacity) + history_length = value_or(history_length, self._history_length) + if history_length < self._history_length: + raise ValueError("History length cannot be decreased") + if capacity < history_length: + raise ValueError("History length cannot be greater than capacity") + if self._shortcut_set_capacity(capacity) and self._shortcut_set_history_length( + history_length + ): + return True + backup_holders = self._unlock_stale_blocks(history_length) + old_num_blocks = BlockOrdinal(div_up(self._capacity, tokens_per_block)) + new_num_blocks = BlockOrdinal(div_up(capacity, tokens_per_block)) + beam_width = BeamIndex(self.beam_width) + num_life_cycles = self.manager._life_cycles.size + if new_num_blocks < old_num_blocks: + with self._record_event(): + del self._blocks[new_num_blocks:] + for beam_indices in self._page_indices: + for indices in beam_indices: + assert all(i == BAD_PAGE_INDEX for i in indices[new_num_blocks:]) + if type(indices) is array.array: + del indices[new_num_blocks:] + else: + indices[new_num_blocks:] = array.array("i", [BAD_PAGE_INDEX]) * ( + len(indices) - new_num_blocks + ) + elif new_num_blocks > old_num_blocks: + num_new_slots = filled_list(0, num_life_cycles) + stale_ranges = [ + _KVCache._get_stale_range(tokens_per_block, history_length, lc) + for _, lc in self.manager._life_cycles.items() + ] + for lc in typed_range(num_life_cycles): + stale_beg, stale_end = stale_ranges[lc] + if old_num_blocks < stale_beg: + assert new_num_blocks >= stale_end + num_new_blocks = (stale_beg - old_num_blocks) + (new_num_blocks - stale_end) + else: + num_new_blocks = new_num_blocks - max(stale_end, old_num_blocks) + num_new_slots[lc] = num_new_blocks * beam_width + try: + slots = self._storage.new_gpu_slots(num_new_slots) + except OutOfPagesError: + self._lock_held_blocks(backup_holders) + return False + for beam_indices in self._page_indices: + for indices in beam_indices: + if type(indices) is array.array: + assert len(indices) == old_num_blocks + indices.extend([BAD_PAGE_INDEX] * (new_num_blocks - old_num_blocks)) + else: + assert len(indices) >= new_num_blocks + stream_wait_events( + self.cuda_stream, (s.ready_event for s in chain.from_iterable(slots)) + ) + for ordinal in typed_range(old_num_blocks, new_num_blocks): + block = make_typed( + lambda: filled_list(cast(BlockPage, None), num_life_cycles), beam_width + ) + for beam_index in typed_range(beam_width): + for lc in typed_range(num_life_cycles): + stale_beg, stale_end = stale_ranges[lc] + if stale_beg <= ordinal < stale_end: + continue + slot = slots[lc].pop() + # We have already waited for ready_event of the slots. + block[beam_index][lc] = UncommittedPage( + self, ordinal, lc, GPU_LEVEL, slot, beam_index + ).lock(self, beam_index, ordinal, lc, skip_wait=True) + self._blocks.append(SeqBlock(block, None)) + assert all(len(slots[lc]) == 0 for lc in typed_range(num_life_cycles)) + self._capacity = capacity + self._history_length = history_length + assert NDEBUG or self._check_sanity() + return True + + @property + def capacity(self) -> int: + "Get the current capacity in number of tokens." + return self._capacity + + @capacity.setter + def capacity(self, capacity: int) -> None: + """ + Reserve space for next inference. 
Capacity cannot be smaller than history length. + Use resize() instead if you need to change both capacity and history length. If you use two + separate APIs, you may have a middle state (between two APIs) where we use more pages than + necessary for SWA layers. + Expect OutOfPagesError exception if there are not enough pages in GPU memory. + """ + success = self.resize(capacity, None) + if not success: + raise OutOfPagesError("Not enough pages in GPU memory") + + @property + def history_length(self) -> int: + """ + Get the current history length in number of tokens. history_length decides how many blocks + needs to be in GPU memory for SWA layers. + """ + return self._history_length + + @history_length.setter + def history_length(self, history_length: int) -> None: + "History length cannot be decreased. Increase may trigger out-of-window block eviction/dropping for SWA layers." + if self._shortcut_set_history_length(history_length): + return + success = self.resize(None, history_length) + assert success + + # notify KV cache manager that we have some finalized/accepted tokens. If a block becomes full, + # also commit the block for reuse. + # In case of beam search, this should be called only with finalized (converged) tokens, and the + # token data must be in the 0th beam. + # We'll destroy memory blocks for other beams if the whole block is full and committed. + # Committed tokens are always history, so history_length will be automatically updated to maintain + # (num_committed_tokens <= history_length). Note that history_length increase may trigger out-of-window + # block eviction/dropping for SWA layers. + # beam_search_indices: indices indicating which candidate to choose for each token. A block with all + # tokens committed will be unified to one memory page and the other memory pages are dropped. Only for + # beam search. + def commit( + self, + accepted_input_tokens: Sequence[TokenIdExt], + beam_search_indices: Sequence[int] | None = None, + ): + if self.beam_width != 1: + raise NotImplementedError("Not implemented yet for beam search") + if not accepted_input_tokens: + return + assert beam_search_indices is None + assert self.status == self.Status.ACTIVE + if self._commit_state == self.CommitState.USER_STOP: + raise LogicError("Cannot commit tokens after stop_committing()") + self._committed_tokens.extend(accepted_input_tokens) + if self._commit_state == self.CommitState.VIRTUAL_STOP: + return + num_committed_blocks = self._num_committed_blocks + new_num_full_blocks = BlockOrdinal(self.num_committed_tokens // self.tokens_per_block) + if new_num_full_blocks > num_committed_blocks: + with self._record_event(): + for ordinal in typed_range(num_committed_blocks, new_num_full_blocks): + self._commit_block(ordinal, False) + if self.history_length < self.num_committed_tokens: + self.history_length = self.num_committed_tokens + + # Note that the tokens may not be ready yet, if the event passed to the past commit() calls are not yet signaled. + @property + def num_committed_tokens(self) -> int: + return len(self._committed_tokens) + + # Users promise to not commit any more tokens. For cases where we shouldn't reuse generated tokens + # (eg. CoT), this helps us drop (instead of evict) out-of-window blocks for SWA layers. + # If there is a uncommitted block containing committed tokens, we will commit the block immediately. 
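    # NOTE: illustrative sketch only, not part of this patch; names below are hypothetical. For
    # the CoT case mentioned above, the expected calling pattern is roughly:
    #
    #     kv.commit(reusable_tokens)        # the non-CoT prefix that future rounds may reuse
    #     kv.stop_committing()              # the CoT tokens that follow are never committed
    #     kv.history_length = total_length  # still account for the CoT tokens as history
    #
    # so that out-of-window SWA blocks past the committed prefix can be dropped rather than evicted.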
+ def stop_committing(self) -> None: + assert self.status != self.Status.CLOSED + if self._commit_state == self.CommitState.USER_STOP: + return + assert NDEBUG or self._check_sanity() + if self._commit_state == self.CommitState.VIRTUAL_STOP: + self._commit_state = self.CommitState.USER_STOP + return + assert self._commit_state == self.CommitState.ALLOWED + if self.num_committed_tokens % self.tokens_per_block != 0: + ordinal = _KVCache._to_block_ordinal(self.tokens_per_block, self.num_committed_tokens) + with self._record_event(): + self._commit_block(ordinal, True) + else: + self._commit_state = self.CommitState.USER_STOP + self._on_stop_committing() + # TODO: check if the last committed pages are usable, in case some prior pages are already + # dropped. For SWA, this can be done only when we stop committing. (TRTLLM-8802) + assert self._commit_state == self.CommitState.USER_STOP + + # Suspend, allow the KV cache manager to evict buffers from GPU, but don't drop them. + # suspend+resume allows us to implement dynamic batch size. May also be used to support HSTU model. + def suspend(self) -> None: + assert self.status == self.Status.ACTIVE + assert self._check_sanity() + assert self._finish_event is None + for beam_idx, beam_indices in typed_enumerate(self._page_indices): + for lc, indices in typed_enumerate(beam_indices): + if type(indices) is memoryview: + self.set_page_index_buf(beam_idx, lc, None) + # used by _SharedPageLock.__del__ + with self._record_event(): + for ordinal, beam_idx, lc_idx in self._active_pages(): + beam_block = self._block(ordinal, beam_idx) + holder = expect_type(_SharedPageLock, beam_block[lc_idx]).holder + # after this assignment, __del__ of the original _SharedPageLock will use self.finish_event + # to indicate end of usage for the page. + beam_block[lc_idx] = holder + self._status = self.Status.SUSPENDED + + # Resume, migrate buffers to GPU memory. 
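    # NOTE: illustrative sketch only, not part of this patch. suspend()/resume() are the building
    # blocks for dynamic batch size: a scheduler can park a request, let its pages be evicted to a
    # colder tier, and bring it back when there is room again, e.g.:
    #
    #     kv.suspend()                        # pages become evictable (but not droppable)
    #     ...
    #     if kv.resume(cuda_stream=stream):   # migrates evicted pages back to GPU memory
    #         ...schedule the request again...
    #
    # resume() returns False when GPU utilization exceeds max_util_for_resume or when locking the
    # pages back into GPU memory runs out of pages, so the caller should retry later.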
+ def resume(self, cuda_stream: CudaStream | None = None) -> bool: + assert self.status == self.Status.SUSPENDED + utilization = max(self._storage.get_utilization(GPU_LEVEL)) + if utilization > self.manager._init_config.max_util_for_resume: + return False + if cuda_stream is not None: + self.cuda_stream = cuda_stream + assert self._cuda_stream is not None, "cuda_stream is never set" + assert self._finish_event is None + tasks = list[BatchedLockTarget]() + for ordinal, beam_idx, lc_idx in self._active_pages(): + beam_block = self._block(ordinal, beam_idx) + page = expect_type(_PageHolder, beam_block[lc_idx]).page + tasks.append(BatchedLockTarget(page, beam_idx, ordinal, lc_idx)) + try: + locks = batched_lock_to_gpu(self, tasks) + except OutOfPagesError: + return False + for (ordinal, beam_idx, lc_idx), lock in zip(self._active_pages(), locks): + beam_block = self._block(ordinal, beam_idx) + page = expect_type(_PageHolder, beam_block[lc_idx]).page + assert page is lock.page + beam_block[lc_idx] = lock + self._status = self.Status.ACTIVE + return True + + def _active_pages(self) -> Iterator[tuple[BlockOrdinal, BeamIndex, LifeCycleId]]: + for lc_idx, lc in self.manager._life_cycles.items(): + stale_start, stale_end = _KVCache._get_stale_range( + self.tokens_per_block, self.history_length, lc + ) + sink_blocks = typed_range(stale_start) + window_blocks = typed_range(stale_end, typed_len(self._blocks)) + for ordinal in chain(sink_blocks, window_blocks): + block = self._blocks[ordinal] + for beam_idx, _ in typed_enumerate(block.pages): + yield ordinal, beam_idx, lc_idx + + @property + def status(self) -> _Status: + return self._status + + @property + def is_active(self) -> bool: + return self.status == self.Status.ACTIVE + + @property + def tokens_per_block(self) -> int: + return self._tokens_per_block + + def _page( + self, block_ordinal: BlockOrdinal, beam_index: BeamIndex, life_cycle: LifeCycleId + ) -> BlockPage: + return self._blocks[block_ordinal].pages[beam_index][life_cycle] + + def _block( + self, block_ordinal: BlockOrdinal, beam_index: BeamIndex + ) -> TypedIndexList[LifeCycleId, BlockPage]: + return self._blocks[block_ordinal].pages[beam_index] + + def _commit_block(self, ordinal: BlockOrdinal, is_last: bool) -> None: + "Commit the block for reuse. Block must be full of tokens except for the last block." + assert self._commit_state == self.CommitState.ALLOWED + assert ( + ordinal == self._num_committed_blocks or self._commit_state != self.CommitState.ALLOWED + ) + seq_block = self._blocks[ordinal] + assert typed_len(seq_block.pages) == 1, "Must have 1 beam only" + beam_idx = BeamIndex(0) + beam_block = seq_block.pages[beam_idx] + tokens_per_block = self.tokens_per_block + start = ordinal * tokens_per_block + tokens = self._committed_tokens[start : start + tokens_per_block] + num_tokens = len(tokens) + is_full = num_tokens == tokens_per_block + if not is_last and not is_full: + raise LogicError("Cannot commit block that is not full except last block") + prev: RootBlock | Block + if ordinal == 0: + prev = self.manager._radix_tree.add_or_get_existing(self._lora_task_id) + else: + prev = self._get_tree_block(BlockOrdinal(ordinal - 1)) + try: + tree_block = Block(tokens, prev) + is_new = True + except UselessBlockError as e: + tree_block = e.block + assert tree_block.tokens[:num_tokens] == tokens + is_new = False + + assert tree_block.tokens_per_block == tokens_per_block + if is_new: + # We are the only writer to padding. Other _KVCache reusing it should make copies. 
+ uncommitted_pages = self._take_uncommitted_page(ordinal, beam_idx) + # convert uncommitted pages to committed pages and create a new block in the radix tree. + for lc, (page, locked) in typed_enumerate(uncommitted_pages): + if page is None: + continue + p = page.convert_to_committed(tree_block) + tree_block.storage[lc] = rawref.ref(p) + # The page comes from uncommitted page of self, so safe to skip wait. + beam_block[lc] = ( + p.lock(self, beam_idx, ordinal, lc, skip_wait=True) if locked else p.hold() + ) + seq_block.tree_block = tree_block + assert self._get_tree_block(ordinal) is tree_block + self._num_committed_blocks = BlockOrdinal(ordinal + 1) + elif tree_block.is_full and self.manager.allow_seq_rebasing: + # try to replace our pages with pages from the existing block. + reuse_list = list[tuple[LifeCycleId, CommittedPage]]() + for lc in typed_range(typed_len(beam_block)): + if beam_block[lc] is None: + continue + existing_page = map_optional(tree_block.storage[lc], lambda p: p()) + locked = isinstance(beam_block[lc], _SharedPageLock) + if existing_page is None: + # The reusable page is gone. We put our own page into the tree block. + page = cast(UncommittedPage, cast(_SharedPageLock, beam_block[lc]).page) + beam_block[lc] = None + p = page.convert_to_committed(tree_block) + # The page comes from uncommitted page of self, so safe to skip wait. + beam_block[lc] = ( + p.lock(self, beam_idx, ordinal, lc, skip_wait=True) if locked else p.hold() + ) + else: + if locked: + beam_block[lc] = cast(_SharedPageLock, beam_block[lc]).holder + reuse_list.append((lc, existing_page)) + locks = batched_lock_to_gpu( + self, [BatchedLockTarget(p, beam_idx, ordinal, lc) for lc, p in reuse_list] + ) + for (lc, _), lock in zip(reuse_list, locks): + beam_block[lc] = lock + seq_block.tree_block = tree_block + assert self._get_tree_block(ordinal) is tree_block + self._num_committed_blocks = BlockOrdinal(ordinal + 1) + else: + # We can't commit and can't reuse existing block. Just stop committing. + self._commit_state = self.CommitState.VIRTUAL_STOP + + if is_last or self._commit_state == self.CommitState.VIRTUAL_STOP: + self._commit_state = self.CommitState.USER_STOP + self._on_stop_committing() + + def _on_stop_committing(self) -> None: + # If there are stale held uncommitted pages, release them. + # @TODO: add test for this. + for lc_idx, lc in self.manager._life_cycles.items(): + start, end = _KVCache._get_stale_range(self.tokens_per_block, self.history_length, lc) + start = max(start, self._num_committed_blocks) + for ordinal in typed_range(start, end): + block = self._blocks[ordinal] + assert not block.is_committed + for beam_block in block.pages: + assert isinstance(beam_block[lc_idx], _PageHolder) + beam_block[lc_idx] = None + assert NDEBUG or self._check_sanity() + + def _unlock_stale_blocks( + self, new_history_length: int + ) -> list[tuple[BlockOrdinal, BeamIndex, LifeCycleId, _PageHolder]]: + "For SWA layers, unlock out-of-window blocks." 
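+ # Only the blocks that newly fall out of the window, i.e. the part of the new stale
+ # range beyond the old one, are unlocked. The released holders are returned so the
+ # caller can re-lock them via _lock_held_blocks() if the operation has to be reverted.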
+ if new_history_length == self.history_length: + return [] + with self._record_event(): + ret = list[tuple[BlockOrdinal, BeamIndex, LifeCycleId, _PageHolder]]() + for lc_idx, lc in self.manager._life_cycles.items(): + if lc.window_size is None: + continue + _, old_end = _KVCache._get_stale_range( + self.tokens_per_block, self.history_length, lc + ) + new_beg, new_end = _KVCache._get_stale_range( + self.tokens_per_block, new_history_length, lc + ) + for ordinal in typed_range( + max(old_end, new_beg), min(typed_len(self._blocks), new_end) + ): + block = self._blocks[ordinal] + is_committed = block.is_committed + hold_for_commit = ( + not is_committed and self._commit_state == self.CommitState.ALLOWED + ) + for beam_idx, beam_block in typed_enumerate(block.pages): + holder = expect_type(_SharedPageLock, beam_block[lc_idx]).holder + ret.append((ordinal, beam_idx, lc_idx, holder)) + beam_block[lc_idx] = holder if hold_for_commit else None + return ret + + def _lock_held_blocks( + self, backup_holders: list[tuple[BlockOrdinal, BeamIndex, LifeCycleId, _PageHolder]] + ): + "Revert _unlock_unused_blocks() by locking the held blocks." + locks = batched_lock_to_gpu( + self, + [ + BatchedLockTarget(holder.page, beam_idx, ordinal, lc) + for ordinal, beam_idx, lc, holder in backup_holders + ], + ) + for lock in locks: + user = lock._user + self._block(user.ordinal, user.beam_index)[user.life_cycle] = lock + + @property + def _storage(self) -> StorageManager: + return self.manager._storage + + @staticmethod + def _to_block_ordinal(tokens_per_block: int, token_ordinal: int) -> BlockOrdinal: + return BlockOrdinal(token_ordinal // tokens_per_block) + + def _get_tree_block(self, ordinal: BlockOrdinal) -> Block: + assert self._blocks[ordinal].is_committed + ret = unwrap_optional(self._blocks[ordinal].tree_block) + if not NDEBUG: + for b in self._block(ordinal, BeamIndex(0)): + assert b is None or (isinstance(b.page, CommittedPage) and b.page.block() is ret) + return ret + + def _take_uncommitted_page( + self, ordinal: BlockOrdinal, beam_idx: BeamIndex + ) -> TypedIndexList[LifeCycleId, tuple[UncommittedPage | None, bool]]: + """ + Take ownership of the uncommitted pages, together with bool flag indicating if it was locked. + And reset holders to None. + """ + holders = self._block(ordinal, beam_idx) + num_life_cycles = self.manager._life_cycles.size + ret: TypedIndexList[LifeCycleId, tuple[UncommittedPage | None, bool]] = filled_list( + (None, False), num_life_cycles + ) + for lc, holder in typed_enumerate(holders): + if holder is None: + continue + assert isinstance(holder.page, UncommittedPage) + locked = isinstance(holder, _SharedPageLock) + ret[lc] = (holder.page, locked) + # When using debugpy with breakpoints on exceptions enabled, the lock/holder is not GC'ed even + # after return from this function. That will likely lead to assertion failures later. 
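+ # Dropping the holder reference here is intentional: convert_to_committed() requires
+ # the page to be DROPPABLE, i.e. no live holder or lock may remain.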
+ holders[lc] = None + return ret + + def _check_sanity(self) -> bool: + is_closed = self.status == self.Status.CLOSED + if is_closed: + return self.num_blocks == 0 + assert self.num_committed_tokens <= self.history_length <= self.capacity + assert self.num_blocks == div_up(self.capacity, self.tokens_per_block) + + def get_range(lc: LifeCycle): + return _KVCache._get_stale_range(self.tokens_per_block, self.history_length, lc) + + stale_ranges = typed_map(self.manager._life_cycles.get(), get_range) + num_life_cycles = self.manager._life_cycles.size + for ordinal, block in typed_enumerate(self._blocks): + is_committed = ordinal < self._num_committed_blocks + assert is_committed == block.is_committed + for beam_block in block.pages: + assert typed_len(beam_block) == num_life_cycles + for lc in typed_range(num_life_cycles): + holder = beam_block[lc] + start, end = stale_ranges[lc] + if start <= ordinal < end: + if is_committed or self._commit_state != self.CommitState.ALLOWED: + assert holder is None + else: + # For the decoder-side disagg case, for the first step, we will skip the + # out-of-window blocks. + assert isinstance(holder, _PageHolder) or ( + holder is None and not self._committed_tokens + ) + else: + assert isinstance( + holder, (_SharedPageLock if self.is_active else _PageHolder) + ) + if holder is not None: + assert is_committed == isinstance(holder.page, CommittedPage) + return True + + @staticmethod + def _get_stale_range( + tokens_per_block: int, history_length: int, life_cycle: LifeCycle + ) -> tuple[BlockOrdinal, BlockOrdinal]: + """ + Range of the stale blocks. Stale blocks are no longer needed for inference. Stale pages should be + held if we may commit them later, or droppable otherwise. + """ + num_blocks = div_up(history_length, tokens_per_block) + start = BlockOrdinal(min(num_blocks, life_cycle.num_sink_blocks)) + window_size = life_cycle.window_size + if window_size is None: + return start, start + # +1 because the next input token will be in the window as well. 
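+ # Illustrative example (numbers are hypothetical): tokens_per_block=32, window_size=128,
+ # no sink blocks, history_length=200 -> stale range [0, (200 + 1 - 128) // 32) = [0, 2).
+ # Blocks 0 and 1 (tokens 0..63) lie entirely before the window [73, 200] and are stale,
+ # while block 2 (tokens 64..95) still overlaps it.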
+ return start, max( + start, _KVCache._to_block_ordinal(tokens_per_block, history_length + 1 - window_size) + ) + + def _setup_for_reuse(self, input_tokens: Sequence[TokenIdExt]) -> None: + manager = self.manager + lora_task_id = self._lora_task_id + matched = list( + manager._radix_tree.match( + lora_task_id, input_tokens or [], manager.enable_partial_match + ) + ) + tokens_per_block = manager.tokens_per_block + assert all(b[1] == tokens_per_block for b in matched[:-1]) + + def get_num_matched_tokens(_): + return tokens_per_block * (len(matched) - 1) + matched[-1][1] if matched else 0 + + life_cycles = manager._life_cycles + + def has_pages(block: Block, lc_list: Iterable[LifeCycleId]) -> bool: + return all(block.storage[lc] is not None for lc in lc_list) + + # check for full attention layers + if any(lc.window_size is None for lc in life_cycles): + lc_list = [lc_idx for lc_idx, lc in life_cycles.items() if lc.window_size is None] + + def check_no_pages(b: tuple[Block, int]): + return not has_pages(b[0], lc_list) + + n = find_index(matched, check_no_pages) + matched = matched[:n] + + def has_page(block: Block, lc: LifeCycleId) -> bool: + return block.storage[lc] is not None + + swa_life_cycles = tuple(lc for lc in life_cycles.items() if lc[1].window_size is not None) + # check for SWA sink + for lc_idx, lc in swa_life_cycles: + + def check_no_page_lc(b: tuple[Block, int]): + return not has_page(b[0], lc_idx) + + n = find_index(matched[: lc.num_sink_blocks], check_no_page_lc) + if n < lc.num_sink_blocks: + matched = matched[:n] + # check for SWA window + num_tokens = 0 + while matched: + num_tokens = get_num_matched_tokens(matched) + for lc_idx, lc in swa_life_cycles: + if lc.window_size is None: + continue + + def check_has_page_lc(b: tuple[Block, int]): + return has_page(b[0], lc_idx) + + n = find_index(reversed(matched), check_has_page_lc) + if n != 0: + matched = matched[:-n] + break + _, stale_end = _KVCache._get_stale_range(tokens_per_block, num_tokens, lc) + + def check_no_page_stale(b: tuple[Block, int]): + return not has_page(b[0], lc_idx) + + n = find_index(reversed(matched[stale_end:]), check_no_page_stale) + if len(matched) - n > stale_end: + matched = matched[: len(matched) - n] + break + else: + break + num_tokens = get_num_matched_tokens(matched) + self._committed_tokens = list(input_tokens[:num_tokens]) + self._history_length = num_tokens + self._capacity = num_tokens + # fill self._blocks + self._blocks = to_typed( + BlockOrdinalT, + [ + SeqBlock( + make_typed( + lambda: filled_list(cast(BlockPage, None), life_cycles.size), + self.beam_width, + ), + b[0] if b[1] == tokens_per_block else None, + ) + for b in matched + ], + ) + + beam_idx = BeamIndex(0) + for lc_idx, lc in life_cycles.items(): + stale_start, stale_end = _KVCache._get_stale_range( + tokens_per_block, get_num_matched_tokens(matched), lc + ) + for ordinal in chain( + typed_range(stale_start), typed_range(stale_end, BlockOrdinal(len(matched))) + ): + block = self._block(ordinal, beam_idx) + holder = unwrap_rawref(unwrap_optional(matched[ordinal][0].storage[lc_idx])).hold() + if matched[ordinal][1] == tokens_per_block: + block[lc_idx] = holder + continue + # make copy for partial blocks. 
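+ # A partially matched last block cannot be shared in place: the new sequence will keep
+ # appending tokens to it. The loop below therefore allocates a fresh uncommitted slot on
+ # the first cache level that has room (starting at the matched page's own level) and
+ # copies the matched contents into it, leaving the reusable committed page untouched.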
+ assert ordinal == len(matched) - 1 and self._blocks[ordinal].tree_block is None + page = holder.page + assert page.manager is manager._storage + storage = manager._storage + num_slots = filled_list(0, life_cycles.size) + num_slots[lc_idx] = 1 + pg_idx = storage.get_pool_group_index(lc_idx) + # try to fine one slot in any cache level + for i in range(manager._storage.num_cache_levels): + lvl = CacheLevel(i + page.cache_level) + try: + slot = storage.new_slots(lvl, num_slots)[lc_idx][0] + except OutOfPagesError: + continue + except Exception: + raise + dst_tier = storage.cache_tiers[lvl] + src_tier = storage.cache_tiers[page.cache_level] + with TemporaryCudaStream((slot.ready_event, page.ready_event)) as stream: + slot_size = storage.slot_size(pg_idx) + for p in typed_range(storage.num_pools(pg_idx)): + dst = storage.slot_address(lvl, pg_idx, slot.slot_id, p) + src = storage.slot_address(page.cache_level, pg_idx, page.slot_id, p) + batched_copy( + dst_tier, src_tier, slot_size[p], [CopyTask(dst, src)], stream.get() + ) + ready_event = stream.take_finish_event() + page.ready_event = ready_event + slot.ready_event = ready_event + block[lc_idx] = UncommittedPage( + self, ordinal, lc_idx, lvl, slot, beam_idx + ).hold() + break # success + else: # failed + self._clear_blocks() + raise RuntimeError( + "We need to copy a block for partial match but we can't find enough pages in " + "any cache level. Did you set up a secondary / third level of cache storage? " + "Do you have too many instances of suspended KV cache? You can also avoid this " + "failure by disallowing partial matching." + ) + self._num_committed_blocks = BlockOrdinal(len(self._committed_tokens) // tokens_per_block) + for beam_indices in self._page_indices: + for indices in beam_indices: + if type(indices) is array.array: + indices.extend([BAD_PAGE_INDEX] * (self.num_blocks - len(indices))) + else: + assert len(indices) >= self.num_blocks + + def _clear_blocks(self) -> None: + # drop the last block first + while self._blocks: + self._blocks.pop() + + @contextmanager + def _record_event(self) -> Iterator[None]: + assert self._finish_event is None + self._finish_event = CachedCudaEvent(self.cuda_stream) + try: + yield + finally: + self._finish_event = None + + def _update_page_index( + self, beam_idx: BeamIndex, ordinal: BlockOrdinal, lc: LifeCycleId, page_index: PageIndex + ) -> PageIndex: + indices = self._page_indices[beam_idx][lc] + old = PageIndex(indices[ordinal]) + indices[ordinal] = page_index + return old + + def _get_page_indices_ref( + self, lc: LifeCycleId, beam_id: BeamIndex = BeamIndex(0) + ) -> Iterator[int | None]: + assert beam_id < self.beam_width + assert self.is_active + pages = ( + map_optional( + b.pages[beam_id][lc] if beam_id < len(b.pages) else None, + lambda h: cast(_PageHolder | _SharedPageLock, h).page, + ) + for b in self._blocks + ) + return self._storage.get_page_indices_ref(lc, pages) + + def _shortcut_set_capacity(self, capacity: int) -> bool: + "Shortcut for cases without side effects. Just for better performance." + tokens_per_block = self.tokens_per_block + if div_up(capacity, tokens_per_block) == div_up(self._capacity, tokens_per_block): + self._capacity = capacity + return True + return False + + def _shortcut_set_history_length(self, history_length: int) -> bool: + "Shortcut for cases without side effects. Just for better performance." 
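+ # Example (hypothetical numbers): with tokens_per_block=32 and one SWA life cycle with
+ # window_size=128, growing history_length from 200 to 201 keeps
+ # (history_length + 1 - 128) // 32 == 2 unchanged, so the shortcut applies; growing to
+ # 224 changes it to 3 (a block crosses the window boundary) and the shortcut returns False.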
+ tokens_per_block = self.tokens_per_block + + def no_side_effect(window: SlidingWindowSize): + return window is None or ( + (history_length + 1 - window) // tokens_per_block + == (self._history_length + 1 - window) // tokens_per_block + ) + + if all(no_side_effect(lc.window_size) for lc in self.manager._life_cycles): + self._history_length = history_length + return True + return False diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py new file mode 100644 index 0000000000..12b8282772 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache_manager.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Sequence +from typing import Any, cast + +from .._block_radix_tree import BlockRadixTree +from .._common import ( + GPU_LEVEL, + PRIORITY_DEFAULT, + BlockOrdinal, + CacheLevel, + CacheTier, + LayerId, + MemAddress, + PageStatus, + Priority, + TokenIdExt, +) +from .._config import DataRole, KVCacheManagerConfig +from .._life_cycle_registry import LayerGroupId, LifeCycle, LifeCycleId, LifeCycleRegistry +from .._storage._config import create_storage_config +from .._storage._core import PoolGroupIndex +from .._storage_manager import StorageManager +from .._utils import ( + HomoTuple, + TypedIndexList, + div_up, + exact_div, + filled_list, + init_cuda_once, + typed_enumerate, + typed_range, + unwrap_rawref, +) +from ._kv_cache import _KVCache + + +class KVCacheManager: + __slots__ = ("_init_config", "_life_cycles", "_radix_tree", "_storage") + _init_config: KVCacheManagerConfig + _life_cycles: LifeCycleRegistry + _radix_tree: BlockRadixTree + _storage: StorageManager + + def __init__(self, config: KVCacheManagerConfig) -> None: + init_cuda_once() + self._init_config = config + self._life_cycles = LifeCycleRegistry(config) + self._radix_tree = BlockRadixTree(self._life_cycles, config.tokens_per_block) + storage_config = create_storage_config(config) + self._storage = StorageManager(self._life_cycles, storage_config) + + def __del__(self) -> None: + self.clear_reusable_blocks() + + def clear_reusable_blocks(self) -> None: + for ref in self._radix_tree.clear(): + assert unwrap_rawref(ref).status == PageStatus.DROPPABLE + self._storage.exclude_from_eviction(unwrap_rawref(ref)) + for level in self._storage._levels: + for pg_idx in typed_range(level.storage.num_pool_groups): + assert level.controller.num_evictable_pages(pg_idx) == 0 + + def get_mem_pool_base_address(self, layer_id: LayerId, data_role: DataRole) -> MemAddress: + """ + Get the base address of the memory pool holding pages for the given layer and data role. + It's guaranteed that for one layer, multiple buffers of the same size have the same base address. 
+ """ + return self._storage.get_mem_pool_base_address(layer_id, data_role) + + # Currently always equals to page size. In the future, that will change when kernels support page stride. + def get_page_stride(self, layer_id: LayerId, data_role: DataRole) -> int: + attr = self._storage.get_buffer_attr(layer_id, data_role) + return attr.size + + def get_page_index_upper_bound(self, layer_id: LayerId, data_role: DataRole) -> int: + """ + The upper bound of page indices for the given layer and data role. + Note that this is not the same as the max number of pages available for this layer and data role. + Internally, multiple buffers may share one memory pool. The purpose of this API is just in case + users want to wrap the memory pool as a tensor with known shape. + """ + storage = self._storage + lc_id = storage._layer_to_life_cycle_ids[layer_id] + pg_idx = storage.get_pool_group_index(lc_id) + pool_group = storage._levels[GPU_LEVEL].storage._pool_groups[pg_idx] + num_slots = pool_group.num_slots + attr = storage.get_buffer_attr(layer_id, data_role) + pool_idx = attr.pool_index + slot_size = pool_group.slot_size[pool_idx] + return exact_div(slot_size, attr.size) * num_slots - exact_div(attr.offset, attr.size) + + def create_kv_cache( + self, + lora_task_id: int | None = None, + input_tokens: Sequence[TokenIdExt] | None = None, + id: Any = None, + custom_priority_callback: Callable[[BlockOrdinal, LifeCycle], Priority] = lambda _, + __: PRIORITY_DEFAULT, + ) -> _KVCache: + """ + lora_task_id: match lora_task_id before matching any tokens. + custom_priority_callback: takes block index and layer sliding window size, returns priority. + If priority returned is higher than existing priority for reused blocks, the block priority is updated. + Newly created KV cache is suspended. You need to call resume() with a cuda stream to make it active + & ready in that stream. + Returns None if suspended=False and we don't have enough resource. + This call will attempt to reuse KV cache blocks. + It's user responsibility to remove the last token from prompts if we need to re-compute the token + generated by prefill. + """ + return _KVCache(self, lora_task_id, input_tokens, id, custom_priority_callback) + + def resize(self, cache_level: CacheLevel, quota: int, best_efforts: bool = False) -> bool: + """ + If best_efforts is True, we will try to resize the quota to the largest possible value that is + still <= quota, and returns False only when we cannot resize the quota at all. + If best_efforts is False, we will resize the quota to the exact value of quota, and give up + if not possible. + """ + raise NotImplementedError("Not implemented") + + def get_quota(self, cache_level: CacheLevel) -> int: + return self._storage._levels[cache_level].storage.total_quota + + # sorted by CacheLevel from warm to cold + @property + def cache_tier_list(self) -> HomoTuple[CacheTier]: + return self._storage.cache_tiers + + @property + def tokens_per_block(self) -> int: + return self._radix_tree.tokens_per_block + + @property + def allow_seq_rebasing(self) -> bool: + """ + If True, when we commit a full block, we will try to find a existing reusable block with the + same tokens and reuse that block instead to save some memory. Intra-batch reuse will be enabled + if this is True. 
+ """ + return True + + @property + def enable_partial_match(self) -> bool: + return True + + def get_layer_group_id(self, layer_id: LayerId) -> LayerGroupId: + return self._storage._layer_to_life_cycle_ids[layer_id] + + @property + def layer_grouping(self) -> HomoTuple[HomoTuple[LayerId]]: + layer_to_life_cycle_ids = self._storage._layer_to_life_cycle_ids + num_life_cycles = self._life_cycles.size + grouping = dict[LifeCycleId, list[LayerId]]({i: [] for i in typed_range(num_life_cycles)}) + for layer_id, life_cycle_id in typed_enumerate(layer_to_life_cycle_ids): + grouping[life_cycle_id].append(layer_id) + return tuple(tuple(grouping[i]) for i in typed_range(num_life_cycles)) + + # @TODO: need updating when dynamic resizing is supported. + def clamp_max_seq_len_for_mem(self, batch_size: int, model_max_seq_len: int) -> int: + "Get the max possible sequence length limited by the GPU memory pools." + assert batch_size > 0 + tokens_per_block = self.tokens_per_block + life_cycles = self._life_cycles + storage = self._storage + num_pool_groups = storage.num_pool_groups + remaining_slots = cast( + TypedIndexList[PoolGroupIndex, int], + [storage.num_slots(pg) for pg in typed_range(num_pool_groups)], + ) + lc_to_pg_idx = storage._life_cycle_grouping + + def get_num_slots(seq_len: int) -> TypedIndexList[PoolGroupIndex, int]: + ret = filled_list(0, num_pool_groups) + for lc_id, lc in life_cycles.items(): + stale_range = _KVCache._get_stale_range(tokens_per_block, seq_len, lc) + num_stale_blocks = stale_range[1] - stale_range[0] + num_slots = div_up(seq_len, tokens_per_block) - num_stale_blocks + pg_idx = lc_to_pg_idx[lc_id] + ret[pg_idx] += num_slots + return ret + + for pg in typed_range(num_pool_groups): + remaining_slots[pg] -= get_num_slots(1)[pg] * (batch_size - 1) + assert remaining_slots[pg] >= 0 + + def is_enough(num_blocks: int) -> bool: + return all( + cnt <= rem + for cnt, rem in zip(get_num_slots(num_blocks * tokens_per_block), remaining_slots) + ) + + assert is_enough(1) + lb = 1 + ub = div_up(model_max_seq_len, tokens_per_block) + if is_enough(ub): + return model_max_seq_len + while lb < ub: + mid = (lb + ub) // 2 + if is_enough(mid): + lb = mid + else: + ub = mid - 1 + return min(lb * tokens_per_block, model_max_seq_len) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_cuda_virt_mem.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_cuda_virt_mem.py new file mode 100644 index 0000000000..bfef67ed80 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_cuda_virt_mem.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Type + +import cuda.bindings.driver as drv + +from ._common import MemAddress +from ._utils import ItemHolderWithSharedPool, PooledFactoryBase, _unwrap, div_up + + +# Physical memory +class NativePhysMemAllocator: + __slots__ = ("_device_id", "_size", "_prop", "_outstanding_handles") + + _device_id: int + _size: int + _prop: drv.CUmemAllocationProp + _outstanding_handles: set[int] # allocated but not released + + def __init__(self, size: int) -> None: + self._device_id = int(_unwrap(drv.cuCtxGetDevice())) # pyright: ignore + self._size = size + prop = drv.CUmemAllocationProp() + prop.type = drv.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED + prop.location.type = drv.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + prop.location.id = self._device_id + self._prop = prop + self._outstanding_handles = set() + + def allocate(self) -> drv.CUmemGenericAllocationHandle: + handle: drv.CUmemGenericAllocationHandle = _unwrap( + drv.cuMemCreate(self._size, self._prop, 0) + ) + int_handle = int(handle) # pyright: ignore + assert (int_handle not in self._outstanding_handles) and int_handle != 0 + self._outstanding_handles.add(int_handle) + return handle + + def release(self, handle: drv.CUmemGenericAllocationHandle) -> None: + if handle == drv.CUmemGenericAllocationHandle(0): + return + assert int(handle) in self._outstanding_handles + self._outstanding_handles.remove(int(handle)) + try: + _unwrap(drv.cuMemRelease(handle)) + except: + print( + f"Failed to release handle {handle}. num_oustanding = {len(self._outstanding_handles)}" + ) + raise + + @property + def device_id(self) -> int: + return self._device_id + + @property + def size(self) -> int: + return self._size + + +class PhysMem(ItemHolderWithSharedPool[drv.CUmemGenericAllocationHandle]): + __slots__ = () + + +class PooledPhysMemAllocator(PooledFactoryBase[drv.CUmemGenericAllocationHandle, PhysMem]): + _Holder: Type[PhysMem] = PhysMem + __slots__ = ("device_id", "phys_mem_size") + device_id: int + phys_mem_size: int + + def __init__(self, phys_mem_size: int) -> None: + raw_alloc = NativePhysMemAllocator(phys_mem_size) + self.device_id = raw_alloc.device_id + self.phys_mem_size = phys_mem_size + super().__init__(lambda: raw_alloc.allocate(), lambda handle: raw_alloc.release(handle)) + + +# Virtual memory +class VirtMem: + __slots__ = ("_vm_size", "_allocator", "_address", "_pm_stack", "_access_desc") + _vm_size: int + _allocator: PooledPhysMemAllocator + _address: drv.CUdeviceptr + _pm_stack: list[PhysMem] + _access_desc: drv.CUmemAccessDesc + + def __init__( + self, vm_size: int, phys_mem_allocator: PooledPhysMemAllocator, init_num_phys_mem: int = 0 + ): + assert vm_size % phys_mem_allocator.phys_mem_size == 0 + self._allocator = phys_mem_allocator + device_id = phys_mem_allocator.device_id + self._address = _unwrap(drv.cuMemAddressReserve(vm_size, 0, 0, 0)) + self._vm_size = vm_size + self._pm_stack = [] + self._access_desc = drv.CUmemAccessDesc() + self._access_desc.location.type = drv.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE + self._access_desc.location.id = device_id + self._access_desc.flags = drv.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE + self.extend(init_num_phys_mem) + + @property + def phys_mem_size(self) -> int: + return self._allocator.phys_mem_size + + def destroy(self) -> None: + if self._vm_size == 0: + return + while self._pm_stack: + self._pop().close() + _unwrap(drv.cuMemAddressFree(self._address, self._vm_size)) + self._address = drv.CUdeviceptr(0) + self._vm_size = 0 + + def 
__del__(self) -> None: + self.destroy() + + def extend(self, num_phys_mem: int) -> None: + old_num_phys_mem = self.num_phys_mem + try: + for _ in range(num_phys_mem): + self._push(self._allocator.create()) + except ( + Exception + ): # to make realloc behave like normal realloc, we need to rollback if out of memory + while self.num_phys_mem > old_num_phys_mem: + self._pop().close() + raise + + def shrink(self, num_phys_mem: int) -> None: + for _ in range(num_phys_mem): + self._pop().close() + + # Different from normal realloc, this function never changes the pointer. + def realloc(self, num_bytes: int) -> None: + required_num_phys_mem = div_up(num_bytes, self.phys_mem_size) + if required_num_phys_mem > self.num_phys_mem: + self.extend(required_num_phys_mem - self.num_phys_mem) + elif required_num_phys_mem < self.num_phys_mem: + self.shrink(self.num_phys_mem - required_num_phys_mem) + + def _push(self, phy_mem: PhysMem) -> None: + phys_mem_size = self.phys_mem_size + assert phys_mem_size * (len(self._pm_stack) + 1) <= self._vm_size + vm_ptr = drv.CUdeviceptr(self.address + phys_mem_size * len(self._pm_stack)) + _unwrap(drv.cuMemMap(vm_ptr, phys_mem_size, 0, phy_mem.handle, 0)) + _unwrap(drv.cuMemSetAccess(vm_ptr, phys_mem_size, (self._access_desc,), 1)) + self._pm_stack.append(phy_mem) + + def _pop(self) -> PhysMem: + assert self._pm_stack + phys_mem_size = self.phys_mem_size + vm_ptr = drv.CUdeviceptr(self.address + phys_mem_size * (len(self._pm_stack) - 1)) + _unwrap(drv.cuMemUnmap(vm_ptr, phys_mem_size)) + return self._pm_stack.pop() + + @property + def mapped_bytes(self) -> int: + return self.phys_mem_size * self.num_phys_mem + + @property + def virtual_bytes(self) -> int: + return self._vm_size + + @property + def num_phys_mem(self) -> int: + return len(self._pm_stack) + + @property + def address(self) -> MemAddress: + return MemAddress(int(self._address)) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_eviction_controller/__init__.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_eviction_controller/__init__.py new file mode 100644 index 0000000000..0e22e1e7c6 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_eviction_controller/__init__.py @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ._eviction_controller import ( # noqa: E402 + EvictablePage, + EvictionPolicy, + NodeRef, + PerLevelEvictionController, +) + +__all__ = ["EvictionPolicy", "PerLevelEvictionController", "EvictablePage", "NodeRef"] diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_eviction_controller/_eviction_controller.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_eviction_controller/_eviction_controller.py new file mode 100644 index 0000000000..6abeae94fd --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_eviction_controller/_eviction_controller.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Protocol, cast + +from llist import sllist, sllistnode + +from .._common import NDEBUG, CacheLevel, PageStatus, Priority +from .._exceptions import OutOfPagesError +from .._life_cycle_registry import LifeCycleId +from .._storage._core import PoolGroupIndex +from .._utils import ( + TypedIndexList, + assert_critical, + make_typed, + noexcept, + typed_enumerate, + typed_len, + unwrap_optional, +) + + +# @runtime_checkable +class EvictablePage(Protocol): + @property + def cache_level(self) -> CacheLevel: ... + + @property + def priority(self) -> Priority: ... + + @property + def life_cycle(self) -> LifeCycleId: ... + + @property + def status(self) -> PageStatus: ... + + def is_committed(self) -> bool: ... + + node_ref: "NodeRef | None" + + +# @runtime_checkable +class NodeRef(Protocol): + @property + def value(self) -> EvictablePage: ... + + +# @runtime_checkable +class EvictionPolicy(Protocol): + def push(self, page: EvictablePage, evict_first: bool = False) -> NodeRef: ... + + def pop(self) -> EvictablePage: ... + + # Remove a node so we no longer consider it for eviction. Like pop() but allow removing a node + # that is not the first. + def remove(self, node: NodeRef) -> EvictablePage: ... + + def __len__(self) -> int: ... 
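+
+# A policy satisfying the protocol above is driven roughly as follows (illustrative;
+# `page` is any object implementing EvictablePage):
+#
+#     policy = LRUEvictionPolicy()
+#     page.node_ref = policy.push(page)   # register the page as an eviction candidate
+#     ...
+#     victim = policy.pop()               # page that has been evictable the longest
+#     victim.node_ref = None
+#
+# PerLevelEvictionController below keeps node_ref consistent automatically and holds one
+# prioritized policy per pool group, so the rest of the code goes through the controller.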
+ + +class LRUEvictionPolicy: + __slots__ = ("_queue",) + _queue: sllist + + def __init__(self) -> None: + self._queue = sllist() + + def push(self, page: EvictablePage, evict_first: bool = False) -> sllistnode: + assert page.node_ref is None + return self._queue.appendleft(page) if evict_first else self._queue.append(page) + + def pop(self) -> EvictablePage: + victim = self._queue.first + assert victim is not None + page = victim.value + self.remove(victim) + return page + + def remove(self, node: sllistnode) -> EvictablePage: + # assert isinstance(node, NodeRef) # mypyc does not support runtime_checkable + assert node == node.value.node_ref + return self._queue.remove(node) + + def __len__(self) -> int: + return len(self._queue) + + +# helper class to help add support for priority-based eviction +class PrioritizedEvictionPolicy: + __slots__ = ( + "_policy_creator", + "_policies", + ) + _policy_creator: Callable[[Priority], EvictionPolicy] + _policies: dict[Priority, EvictionPolicy] + + def __init__(self, policy_creator: Callable[[Priority], EvictionPolicy]) -> None: + self._policy_creator = policy_creator + self._policies = {} + + def __len__(self) -> int: + return sum(len(policy) for policy in self._policies.values()) + + def get_policy(self, priority: Priority) -> EvictionPolicy: + if priority not in self._policies: + self._policies[priority] = self._policy_creator(priority) + self._policies = dict(sorted(self._policies.items())) + return self._policies[priority] + + def _front_policy(self) -> EvictionPolicy: + return next(iter(self._policies.values())) + + def push(self, page: EvictablePage, evict_first: bool = False) -> NodeRef: + return self.get_policy(page.priority).push(page, evict_first) + + def pop(self) -> EvictablePage: + return self._front_policy().pop() + + def remove(self, node: NodeRef) -> EvictablePage: + page = node.value + policy = self._policies[page.priority] + policy.remove(node) + if not policy: + self._policies.pop(page.priority) + return page + + +class PrioritizedLRUEvictionPolicy(PrioritizedEvictionPolicy): + __slots__ = () + + def __init__(self) -> None: + super().__init__(lambda priority: LRUEvictionPolicy()) + + +class PerLevelEvictionController: # for one cache level + __slots__ = ("_life_cycle_grouping", "_policies", "_cache_level") + _life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex] + _policies: TypedIndexList[PoolGroupIndex, EvictionPolicy] + _cache_level: CacheLevel + + def __init__( + self, + life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex], + cache_level: CacheLevel, + ): + self._cache_level = cache_level + self._life_cycle_grouping = life_cycle_grouping + num_pool_groups = max(life_cycle_grouping) + 1 + assert num_pool_groups == len(set(life_cycle_grouping)) + self._policies = cast( + TypedIndexList, [PrioritizedLRUEvictionPolicy() for _ in range(num_pool_groups)] + ) + + def __del__(self) -> None: + if not NDEBUG: + assert_critical( + all(len(p) == 0 for p in self._policies), "Eviction controller is not empty" + ) + + def _get_policy(self, life_cycle: LifeCycleId) -> EvictionPolicy: + pg_idx = self._life_cycle_grouping[life_cycle] + return self._policies[pg_idx] + + def schedule_for_eviction(self, page: EvictablePage, evict_first: bool = False): + assert page.node_ref is None and page.cache_level == self._cache_level + page.node_ref = self._get_policy(page.life_cycle).push(page, evict_first) + assert unwrap_optional(page.node_ref).value is page + + # If evicting a node makes some other nodes useless, those nodes will 
be returned as well. + # One example: for SWA, if the number of blocks just makes up one window size, then evicting any of + # them makes the remaining blocks useless. + # Raise if no enough pages to evict. In this case, pages are returned to the eviction queue. + def evict( + self, min_num_pages: TypedIndexList[PoolGroupIndex, int] + ) -> TypedIndexList[PoolGroupIndex, list[EvictablePage]]: + assert NDEBUG or len(min_num_pages) == self.num_pool_groups + ret = make_typed(lambda: list[EvictablePage](), self.num_pool_groups) + try: + for pg_idx, count in typed_enumerate(min_num_pages): + policy = self._policies[pg_idx] + if (len(policy) + len(ret[pg_idx])) < count: + raise OutOfPagesError(f"Not enough pages to evict in group {pg_idx}") + while len(ret[pg_idx]) < count: + page = policy.pop() + page.node_ref = None + ret[pg_idx].append(page) + for a, b in zip(ret, self._evict_dependencies(page)): + a.extend(b) + except Exception: + for p in reversed(sum(ret, [])): + self.schedule_for_eviction(p, evict_first=True) + raise + assert all(p.cache_level == self._cache_level for p in sum(ret, [])), ( + "Corrupted eviction controller" + ) + return ret + + def remove(self, node: NodeRef) -> None: + page = node.value + assert page.node_ref == node + self._get_policy(page.life_cycle).remove(node) + page.node_ref = None + + # @TODO: implement this + @noexcept + def _evict_dependencies( + self, page: EvictablePage + ) -> TypedIndexList[PoolGroupIndex, list[EvictablePage]]: + return make_typed(lambda: list[EvictablePage](), self.num_pool_groups) + + def num_evictable_pages(self, pg_idx: PoolGroupIndex) -> int: + return len(self._policies[pg_idx]) + + @property + def num_pool_groups(self) -> PoolGroupIndex: + return typed_len(self._policies) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_exceptions.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_exceptions.py new file mode 100644 index 0000000000..20c59e9030 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_exceptions.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cuda.bindings.driver as drv + + +class OutOfMemoryError(Exception): + pass + + +class HostOOMError(OutOfMemoryError): + pass + + +class DiskOOMError(OutOfMemoryError): + pass + + +class CuOOMError(OutOfMemoryError): + pass + + +class LogicError(Exception): + """ + This exception indicates a bug in the code. 
+ """ + + def __init__(self, message: str) -> None: + super().__init__(message) + + +class CuError(Exception): + error_code: drv.CUresult + + def __init__(self, error_code: drv.CUresult) -> None: + self.error_code = error_code + err, err_str = drv.cuGetErrorString(error_code) + if err != drv.CUresult.CUDA_SUCCESS: + err_str = "" + super().__init__(f"CUDA driver error: {error_code} ({err_str})") + + +class ResourceBusyError(Exception): + pass + + +class OutOfPagesError(Exception): + pass diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_life_cycle_registry.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_life_cycle_registry.py new file mode 100644 index 0000000000..a4cbe2c388 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_life_cycle_registry.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterator, NamedTuple, NewType, TypeAlias, cast + +from ._common import SlidingWindowSize +from ._config import KVCacheManagerConfig +from ._utils import TypedIndexList, div_up, typed_enumerate + + +class LifeCycle(NamedTuple): + window_size: SlidingWindowSize + num_sink_blocks: int # div_up(num_sink_tokens, tokens_per_block) + + @staticmethod + def make( + window_size: SlidingWindowSize, num_sink_tokens: int | None, tokens_per_block: int + ) -> "LifeCycle": + assert tokens_per_block > 0 + assert window_size is None or window_size > 0 + assert num_sink_tokens is None or num_sink_tokens >= 0 + assert num_sink_tokens in (None, 0) or window_size is not None + num_sink_blocks = div_up(num_sink_tokens or 0, tokens_per_block) + return LifeCycle(window_size, num_sink_blocks) + + +LifeCycleId = NewType("LifeCycleId", int) + +# For public exposure +LayerGroupId: TypeAlias = LifeCycleId + + +class LifeCycleRegistry: + __slots__ = ("_life_cycle_list", "_life_cycle_id_dict") + _life_cycle_list: TypedIndexList[LifeCycleId, LifeCycle] + _life_cycle_id_dict: dict[LifeCycle, LifeCycleId] + + def __init__(self, config: KVCacheManagerConfig) -> None: + self._life_cycle_list = cast(TypedIndexList[LifeCycleId, LifeCycle], []) + self._life_cycle_id_dict = dict[LifeCycle, LifeCycleId]() + for layer in config.layers: + details = LifeCycle.make( + layer.window_size, layer.num_sink_tokens, config.tokens_per_block + ) + if details not in self._life_cycle_id_dict: + assert len(self._life_cycle_id_dict) == len(self._life_cycle_list), ( + "corrupted life cycle registry" + ) + self._life_cycle_list.append(details) + self._life_cycle_id_dict[details] = LifeCycleId(len(self._life_cycle_list) - 1) + + def get_life_cycle(self, id: LifeCycleId) -> LifeCycle: + return self._life_cycle_list[id] + + def get_id(self, life_cycle_details: LifeCycle) -> LifeCycleId: + return self._life_cycle_id_dict[life_cycle_details] + + @property + def size(self) -> LifeCycleId: + assert len(self._life_cycle_list) == len(self._life_cycle_id_dict), ( + "corrupted life 
cycle registry" + ) + return LifeCycleId(len(self._life_cycle_list)) + + def __iter__(self) -> Iterator[LifeCycle]: + return iter(self._life_cycle_list) + + def __getitem__(self, idx: LifeCycleId) -> LifeCycle: + return self._life_cycle_list[idx] + + def items(self) -> Iterator[tuple[LifeCycleId, LifeCycle]]: + return typed_enumerate(self.get()) + + def get(self) -> TypedIndexList[LifeCycleId, LifeCycle]: + return self._life_cycle_list + + def __contains__(self, lc: LifeCycle) -> bool: + return lc in self._life_cycle_id_dict diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_page.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_page.py new file mode 100644 index 0000000000..328d26e116 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_page.py @@ -0,0 +1,463 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, NamedTuple, cast + +from . import rawref +from ._block_radix_tree import Block +from ._common import ( + BAD_PAGE_INDEX, + GPU_LEVEL, + NDEBUG, + BeamIndex, + BlockOrdinal, + CacheLevel, + PageIndex, + PageStatus, + Priority, + TokenIdExt, +) + +if TYPE_CHECKING: + from ._core._kv_cache import _KVCache + from ._storage_manager import StorageManager + + +from ._eviction_controller import NodeRef +from ._exceptions import LogicError +from ._life_cycle_registry import LifeCycleId +from ._storage._core import Slot +from ._utils import ( + CachedCudaEvent, + assert_critical, + filled_list, + get_uniform_attribute, + merge_events, + partition, + stream_wait_events, + unwrap_rawref, +) + +ReferenceType = rawref.ReferenceType + + +# We will have a huge amount of pages for large storage capacity. +# So we prefer inheritance over composition to save some memory. +@dataclass(slots=True) +class Page(Slot): + _manager: ReferenceType["StorageManager"] + life_cycle: LifeCycleId + cache_level: CacheLevel + _priority: Priority + # _holder is either None or a valid rawref. + _holder: ReferenceType["_PageHolder"] | None + node_ref: NodeRef | None + + def __del__(self) -> None: + if not NDEBUG: + assert_critical(self.status == PageStatus.DROPPABLE and not self.scheduled_for_eviction) + if self.has_valid_slot: + self.manager.release_slot(self.life_cycle, self.cache_level, self) + + @property + def manager(self) -> "StorageManager": + return unwrap_rawref(self._manager) + + @property + def priority(self) -> Priority: + return self._priority + + # prevent dropping + def hold(self) -> "_PageHolder": + if self._holder is not None: + return unwrap_rawref(self._holder) + holder = _PageHolder(self) + self._holder = rawref.ref(holder) + controller = self.manager + if self.scheduled_for_eviction and not controller.is_evictable(self): + controller.exclude_from_eviction(self) + assert not self.scheduled_for_eviction + return holder + + # Prevent eviction. 
You need to migrate the page to GPU later. + def lock( + self, + kv_cache: "_KVCache", + beam_index: BeamIndex, + ordinal: BlockOrdinal, + life_cycle: LifeCycleId, + skip_wait: bool = False, + ) -> "_SharedPageLock": + "If skip wait, you are responsible for making the page ready in kv_cache.cuda_stream." + return self.hold().lock(kv_cache, beam_index, ordinal, life_cycle, skip_wait) + + @property + def status(self) -> PageStatus: + if self._holder is None: + return PageStatus.DROPPABLE + lock_ref = unwrap_rawref(self._holder)._lock + if lock_ref is None: + return PageStatus.HELD + assert unwrap_rawref(lock_ref) is not None + return PageStatus.LOCKED + + @property + def scheduled_for_eviction(self) -> bool: + return self.node_ref is not None + + def is_committed(self) -> bool: + raise LogicError("Unexpected call to this implementation.") + + +@dataclass(slots=True) +class UncommittedPage(Page): + # @TODO: consider move this to _PageHolder + kv_cache: rawref.ref["_KVCache"] + ordinal: BlockOrdinal + beam_index: BeamIndex + + tokens: list[TokenIdExt] = field(default_factory=list) + + def is_committed(self) -> bool: + return False + + def __init__( + self, + kv_cache: "_KVCache", + ordinal: BlockOrdinal, + life_cycle: LifeCycleId, + cache_level: CacheLevel, + slot: Slot, + beam_index: BeamIndex = BeamIndex(0), + ): + self.kv_cache = rawref.ref(kv_cache) + self.ordinal = ordinal + self.beam_index = beam_index + manager = kv_cache.manager + priority = kv_cache._get_priority(ordinal, manager._life_cycles.get_life_cycle(life_cycle)) + Page.__init__( + self, + None, + CachedCudaEvent.NULL, + rawref.ref(manager._storage), + life_cycle, + cache_level, + priority, + None, + None, + ) + self.set_slot(slot) + + def convert_to_committed(self, block: Block) -> "CommittedPage": + """ + Moves the slot to a new committed page and add the new page to the block. + The uncommitted page becomes invalid. + """ + assert not self.scheduled_for_eviction + assert block.storage[self.life_cycle] is None + # If you hit this assertion failure, it's likely because you are using debugpy, which delayed GC + # for _KVCache._take_uncommitted_page(). Disable breakpoints on exceptions to avoid this issue. 
+ assert self.status == PageStatus.DROPPABLE, "Release holder/lock first" + committed_page = CommittedPage( + self.manager, block, self.life_cycle, self.cache_level, self, self.priority + ) + self._slot_id = None + self.ready_event = CachedCudaEvent.NULL + assert committed_page.has_valid_slot + block.storage[self.life_cycle] = rawref.ref(committed_page) + return committed_page + + def __del__(self) -> None: + def check_page(p: "BlockPage") -> bool: + return p is None or isinstance(p.page, CommittedPage) + + if not NDEBUG: + assert_critical( + len(unwrap_rawref(self.kv_cache)._blocks) <= self.ordinal + or check_page( + unwrap_rawref(self.kv_cache) + ._blocks[self.ordinal] + .pages[self.beam_index][self.life_cycle] + ) + ) + Page.__del__(self) + + +@dataclass(slots=True) +class CommittedPage(Page): + block: rawref.ref["Block"] + __rawref__: rawref.ref["CommittedPage"] + + def is_committed(self) -> bool: + return True + + def __init__( + self, + storage: "StorageManager", + block: Block, + life_cycle: LifeCycleId, + cache_level: CacheLevel, + slot: Slot, + priority: Priority, + ): + self.block = rawref.ref(block) + self.__rawref__ = rawref.NULL + Page.__init__( + self, + None, + CachedCudaEvent.NULL, + rawref.ref(storage), + life_cycle, + cache_level, + priority, + None, + None, + ) + self.set_slot(slot) + + def __del__(self) -> None: + block = self.block() + # block may be None when rebase happens, i.e. another block with the same key is committed, + # replacing it, but the page is still used by a _KVCache. + if block is not None: + block.unset_page( + self.life_cycle, + self.manager._life_cycles.get_life_cycle(self.life_cycle), + ) + Page.__del__(self) + self.__rawref__.invalidate() + + +@dataclass(slots=True) +class _PageHolder: + "Prevents pages from being dropped." + + page: Page + _lock: rawref.ref["_UniqPageLock"] | None = None + __rawref__: rawref.ref["_PageHolder"] = field(default_factory=lambda: rawref.NULL) + + def __init__(self, page: Page) -> None: + self.page = page + self._lock = None + self.__rawref__ = rawref.NULL + + def __del__(self) -> None: + if not NDEBUG: + assert_critical(self._lock is None) + page = self.page + page._holder = None + # If a held page was in last level cache, it was not scheduled for eviction. + if page.is_committed(): + page = cast(CommittedPage, page) + if not page.scheduled_for_eviction: + page.manager.schedule_for_eviction(page) + block = page.block() + if block is None or block.is_orphan: + page.manager.exclude_from_eviction(page) + elif page.scheduled_for_eviction: + page = cast(UncommittedPage, self.page) + page.manager.exclude_from_eviction(self.page) + self.__rawref__.invalidate() + + # Prevent eviction. You need to migrate the page to GPU later. + def lock( + self, + kv_cache: "_KVCache", + beam_index: BeamIndex, + ordinal: BlockOrdinal, + life_cycle: LifeCycleId, + skip_wait: bool = False, + ) -> "_SharedPageLock": + if self._lock is None: + lock = _UniqPageLock(self) + self._lock = rawref.ref(lock) + else: + lock = unwrap_rawref(self._lock) + if self.page.scheduled_for_eviction: + manager = self.page.manager + manager.exclude_from_eviction(self.page) + assert not self.page.scheduled_for_eviction + return lock.share(kv_cache, beam_index, ordinal, life_cycle, skip_wait) + + +@dataclass(slots=True) +class _UniqPageLock: + "Locks pages to prevent eviction." 
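+ # Ownership chain: a Page is held by at most one _PageHolder (prevents dropping), which
+ # owns at most one _UniqPageLock (prevents eviction; GPU-resident pages only). Every
+ # _KVCache position using the page gets its own _SharedPageLock view of this unique
+ # lock, which also maintains that cache's page-index table. When the last shared lock
+ # is released, __del__ below merges the recorded finish events into page.ready_event and
+ # hands the page back to the eviction controller if it is still evictable.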
+ + holder: _PageHolder | None + finish_events: list[CachedCudaEvent] + __rawref__: rawref.ref["_UniqPageLock"] = field(default_factory=lambda: rawref.NULL) + + def __init__(self, holder: _PageHolder) -> None: + if holder.page.cache_level != CacheLevel(0): + raise ValueError("Lock can be applied only on GPU memory pages.") + self.holder = holder + self.finish_events = [] + self.__rawref__ = rawref.NULL + + def share( + self, + kv_cache: "_KVCache", + beam_index: BeamIndex, + ordinal: BlockOrdinal, + life_cycle: LifeCycleId, + skip_wait: bool, + ) -> "_SharedPageLock": + ret = _SharedPageLock(self, kv_cache, beam_index, ordinal, life_cycle, skip_wait) + return ret + + @property + def page(self) -> Page: + assert self.holder is not None + return self.holder.page + + def __del__(self) -> None: + page = self.page + if not NDEBUG: + assert_critical(page.cache_level == CacheLevel(0) and not page.scheduled_for_eviction) + page.ready_event = merge_events(self.finish_events) + assert self.holder is not None + self.holder._lock = None + if False: + if page.manager.is_evictable(page): + page.manager.schedule_for_eviction(page) + else: + # Optimized code path: + # delete holder first, so if nobody holds the page elsewhere, it becomes droppable immediately, + # before we hand it over to eviction controller. + self.holder = None + # if it's not droppable, then it means self.holder=None had no impact. We need to schedule it + # for eviction as usual. + if page.status != PageStatus.DROPPABLE and page.manager.is_evictable(page): + page.manager.schedule_for_eviction(page) + self.__rawref__.invalidate() + + +class LockOwner(NamedTuple): + kv_cache: rawref.ref["_KVCache"] + beam_index: BeamIndex + ordinal: BlockOrdinal + life_cycle: LifeCycleId + + +@dataclass(slots=True, init=False) +class _SharedPageLock: + _uniq_lock: _UniqPageLock | None + _user: LockOwner + + @property + def page(self) -> Page: + assert self._uniq_lock is not None + return self._uniq_lock.page + + @property + def holder(self) -> _PageHolder: + assert self._uniq_lock is not None + assert self._uniq_lock.holder is not None + return self._uniq_lock.holder + + def __hash__(self) -> int: + return hash(id(self)) + + def __eq__(self, other: object) -> bool: + return self is other + + def __init__( + self, + uniq_lock: _UniqPageLock, + kv_cache: "_KVCache", + beam_index: BeamIndex, + ordinal: BlockOrdinal, + life_cycle: LifeCycleId, + skip_wait: bool, + ) -> None: + self._uniq_lock = uniq_lock + if not skip_wait: + self.page.ready_event.wait_in_stream(kv_cache.cuda_stream) + self._user = LockOwner(rawref.ref(kv_cache), beam_index, ordinal, life_cycle) + new_index = self._get_page_index() + old_index = kv_cache._update_page_index(beam_index, ordinal, life_cycle, new_index) + assert old_index == BAD_PAGE_INDEX + + def __del__(self) -> None: + if self._uniq_lock is not None: + self.unlock() + + def unlock(self) -> Page: + assert self._uniq_lock is not None + page = self.page + self._uniq_lock.finish_events.append(unwrap_rawref(self._user.kv_cache).finish_event) + new_index = BAD_PAGE_INDEX + old_index = unwrap_rawref(self._user.kv_cache)._update_page_index( + self._user.beam_index, self._user.ordinal, self._user.life_cycle, new_index + ) + assert NDEBUG or old_index == self._get_page_index() + self._uniq_lock = None + return page + + def _get_page_index(self) -> PageIndex: + storage = unwrap_rawref(self._user.kv_cache).manager._storage + num_buffers_per_slot = storage._slot_to_page_indices[self._user.life_cycle] + return PageIndex(self.page.slot_id 
* num_buffers_per_slot) + + +BlockPage = _SharedPageLock | _PageHolder | None + + +class BatchedLockTarget(NamedTuple): + page: Page + beam_index: BeamIndex + ordinal: BlockOrdinal + life_cycle: LifeCycleId + + +def batched_lock_to_gpu( + kv_cache: "_KVCache", tasks: Sequence[BatchedLockTarget] +) -> list["_SharedPageLock"]: + "Lock pages after migrating all pages to GPU. If migration fails, no locking happens." + storage = kv_cache.manager._storage + assert not tasks or storage is get_uniform_attribute(tasks, lambda p: p.page.manager) + requirements = filled_list(0, storage.num_pool_groups) + scheduled_for_eviction = [t.page.scheduled_for_eviction for t in tasks] + for t, e in zip(tasks, scheduled_for_eviction): + if e: + storage.exclude_from_eviction(t.page) + if t.page.cache_level == GPU_LEVEL: + continue + requirements[storage.get_pool_group_index(t.life_cycle)] += 1 + + try: + storage.prepare_free_slots(GPU_LEVEL, requirements) + partitioned = partition( + tasks, lambda p: (p.page.cache_level, storage.get_pool_group_index(p.life_cycle)) + ) + for (lvl, pg_idx), part in partitioned.items(): + if lvl == GPU_LEVEL: + continue + storage._batched_migrate( + pg_idx, GPU_LEVEL, lvl, [p.page for p in part], update_src=True + ) + except Exception: + for t, e in zip(tasks, scheduled_for_eviction): + if e: + storage.schedule_for_eviction(t.page) + raise + stream_wait_events(kv_cache.cuda_stream, (p.page.ready_event for p in tasks)) + return [ + page.lock(kv_cache, beam_index, ordinal, life_cycle, skip_wait=True) + for page, beam_index, ordinal, life_cycle in tasks + ] diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/__init__.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/__init__.py new file mode 100644 index 0000000000..53e013fd96 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._core import CacheLevelStorage + +# These are re-exported for external use +__all__ = ["CacheLevelStorage"] diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_config.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_config.py new file mode 100644 index 0000000000..7bdecaf69a --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_config.py @@ -0,0 +1,225 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from dataclasses import dataclass +from typing import NamedTuple, cast + +from .._common import LayerId +from .._config import CacheTierConfig, DataRole, KVCacheManagerConfig +from .._life_cycle_registry import LifeCycle, LifeCycleId, LifeCycleRegistry +from .._storage._core import PoolGroupIndex, PoolIndex +from .._utils import ( + HomoTuple, + TypedIndexList, + exact_div, + filled_list, + get_uniform_attribute, + is_sorted, + typed_range, +) + + +class BufferId(NamedTuple): + layer_id: LayerId + role: DataRole + + +@dataclass(slots=True) +class CoalescedBuffer: + life_cycle_id: LifeCycleId + single_buffer_size: int # identical for all buffers in the same coalesced buffer + buffer_ids: list[BufferId] + + @property + def size(self) -> int: + return self.single_buffer_size * len(self.buffer_ids) + + +@dataclass(slots=True) +class PageConfig: + """ + A page is a group of coalesced buffers. Each coalesced buffer has multiple buffers with the same + size. Multiple coalesced buffers can be in the same page if they share the same life cycle and + coalesced size. + """ + + coalesced_buffers: list[CoalescedBuffer] + + @property + def _coalesced_size(self) -> int: + return get_uniform_attribute(self.coalesced_buffers, lambda b: b.size) + + @property + def slot_size(self) -> int: + return self._coalesced_size * len(self.coalesced_buffers) + + @property + def life_cycle_id(self) -> LifeCycleId: + return get_uniform_attribute(self.coalesced_buffers, lambda b: b.life_cycle_id) + + +@dataclass(slots=True, frozen=True) +class SlotConfig: + "A group of pages for the same life cycle." + + pages: HomoTuple[PageConfig] + + def __post_init__(self) -> None: + assert is_sorted(self.pages, key=lambda s: s.slot_size, reverse=True) + assert all( + len(p.coalesced_buffers) == len(self.pages[0].coalesced_buffers) for p in self.pages + ) + + @property + def life_cycle_id(self) -> LifeCycleId: + return get_uniform_attribute(self.pages, lambda s: s.life_cycle_id) + + @property + def slot_size_list(self) -> HomoTuple[int]: + return tuple(s.slot_size for s in self.pages) + + +@dataclass(slots=True, frozen=True) +class PoolGroupConfig: + """ + A group of pools may contain slots (page groups) with different life cycles. They have identical + slot size list, so we can put them in the same group of memory pools. 
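+    For example (illustrative), two slot configs whose slot_size_list values match, say (S0, S1),
+    can share one pool group even if their life cycles differ; see create_storage_config below.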
+ """ + + slots: HomoTuple[SlotConfig] + + @property + def slot_size_list(self) -> HomoTuple[int]: + return get_uniform_attribute(self.slots, lambda s: s.slot_size_list) + + +@dataclass(slots=True, frozen=True) +class BufferAttr: + life_cycle_id: LifeCycleId + pool_index: PoolIndex + offset: int + size: int + + +@dataclass(slots=True, frozen=True) +class StorageConfig: + cache_tiers: HomoTuple[CacheTierConfig] + pool_groups: HomoTuple[PoolGroupConfig] + + @property + def num_life_cycles(self) -> LifeCycleId: + return LifeCycleId(sum(len(pg.slots) for pg in self.pool_groups)) + + def life_cycle_grouping(self) -> TypedIndexList[LifeCycleId, PoolGroupIndex]: + ret = filled_list(PoolGroupIndex(-1), self.num_life_cycles) + for pg_idx, pg in enumerate(self.pool_groups): + pg_idx = PoolGroupIndex(pg_idx) + for s in pg.slots: + ret[s.life_cycle_id] = pg_idx + return ret + + def buffer_attributes(self) -> dict[BufferId, BufferAttr]: + ret = dict[BufferId, BufferAttr]() + for pg in self.pool_groups: + for slot in pg.slots: + life_cycle_id = slot.life_cycle_id + for pool, page in enumerate(slot.pages): + offset = 0 + for cb in page.coalesced_buffers: + for b in cb.buffer_ids: + ret[b] = BufferAttr( + life_cycle_id, PoolIndex(pool), offset, cb.single_buffer_size + ) + offset += cb.single_buffer_size + return ret + + def slot_to_page_indices(self) -> TypedIndexList[LifeCycleId, int]: + ret = filled_list(0, self.num_life_cycles) + for pg in self.pool_groups: + for slot in pg.slots: + life_cycle = slot.life_cycle_id + assert len(slot.pages) == 1 + page = slot.pages[0] + assert len(page.coalesced_buffers) == 1 + scale = exact_div(page.slot_size, page.coalesced_buffers[0].single_buffer_size) + ret[life_cycle] = scale + return ret + + def layer_to_life_cycle_ids(self) -> TypedIndexList[LayerId, LifeCycleId]: + map = dict[LayerId, LifeCycleId]() + for (layer_id, _), attr in self.buffer_attributes().items(): + lc_id = map.setdefault(layer_id, attr.life_cycle_id) + assert lc_id == attr.life_cycle_id + assert len(map) == max(map.keys()) + 1 + return cast( + TypedIndexList[LayerId, LifeCycleId], + [map[LayerId(layer_id)] for layer_id in typed_range(len(map))], + ) + + def __post_init__(self) -> None: + groups = [tuple(s.life_cycle_id for s in pg.slots) for pg in self.pool_groups] + all_life_cycle_ids = sum((g for g in groups), ()) + assert len(all_life_cycle_ids) == len(set(all_life_cycle_ids)) + + +def create_storage_config(config: KVCacheManagerConfig) -> StorageConfig: + # group buffers first by life cycle, then by single buffer size. + buffer_groups = defaultdict[LifeCycleId, defaultdict[int, list[BufferId]]]( + lambda: defaultdict[int, list[BufferId]](list[BufferId]) + ) + life_cycle_registry = LifeCycleRegistry(config) + for layer in config.layers: + life_cycle = LifeCycle.make( + layer.window_size, layer.num_sink_tokens, config.tokens_per_block + ) + life_cycle_id = life_cycle_registry.get_id(life_cycle) + size_to_buffers = buffer_groups[life_cycle_id] + for buffer in layer.buffers: + size_to_buffers[buffer.size].append(BufferId(layer.layer_id, buffer.role)) + # Create one slot group for each life cycle. + # It's possible that buffers with different sizes form coalesced buffers with the same coalesced size. + # @TODO: add test for this case. + slot_groups: list[SlotConfig] = [] + for life_cycle_id, size_to_buffers in buffer_groups.items(): + assert len(set(len(buffer_ids) for buffer_ids in size_to_buffers.values())) == 1, ( + "Not yet supported. 
While we can support this easily, we need to know whether the kernels " + "need to share page indices or not. We haven't seen such models, yet. So we leave this as a " + "future work." + ) + size_to_coalesced_buffers = defaultdict[int, list[CoalescedBuffer]](list[CoalescedBuffer]) + for size, buffer_ids in size_to_buffers.items(): + coalesced_size = size * len(buffer_ids) + coalesced_buffers = size_to_coalesced_buffers[coalesced_size] + coalesced_buffers.append( + CoalescedBuffer( + life_cycle_id=life_cycle_id, single_buffer_size=size, buffer_ids=buffer_ids + ) + ) + slots = [ + PageConfig(coalesced_buffers) + for coalesced_buffers in size_to_coalesced_buffers.values() + ] + slots.sort(key=lambda p: p.slot_size, reverse=True) + slot_groups.append(SlotConfig(tuple(slots))) + # Merge slot groups with the same slot_size_list + pool_groups_by_slot_size_list = defaultdict[HomoTuple[int], list[SlotConfig]](list[SlotConfig]) + for slot_group in slot_groups: + pool_groups_by_slot_size_list[slot_group.slot_size_list].append(slot_group) + pool_groups = [ + PoolGroupConfig(tuple(slot_groups)) + for slot_groups in pool_groups_by_slot_size_list.values() + ] + return StorageConfig(cache_tiers=tuple(config.cache_tiers), pool_groups=tuple(pool_groups)) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_core.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_core.py new file mode 100644 index 0000000000..4d66c18258 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage/_core.py @@ -0,0 +1,936 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import errno +import os +import sys +import tempfile +import warnings +from collections import deque +from collections.abc import Sequence +from dataclasses import dataclass +from typing import ClassVar, NewType, final + +if sys.version_info[:2] >= (3, 12): + from typing import override +else: + from typing_extensions import override + +from .._common import ( + BAD_FILE_DESCRIPTOR, + NDEBUG, + Address, + CacheTier, + DiskAddress, + FileDescriptor, + MemAddress, +) +from .._cuda_virt_mem import PooledPhysMemAllocator, VirtMem +from .._exceptions import LogicError, OutOfPagesError, ResourceBusyError +from .._utils import ( + CachedCudaEvent, + DynamicBitset, + HomoTuple, + HostMem, + assert_critical, + div_up, + query_total_gpu_memory, + remove_if, + resize_file, + round_down, + round_up, +) + +PoolGroupIndex = NewType("PoolGroupIndex", int) +PoolIndex = NewType("PoolIndex", int) +SlotId = NewType("SlotId", int) + + +class SlotPoolBase(abc.ABC): + _slot_size: int + + @property + def slot_size(self) -> int: + return self._slot_size + + @property + @abc.abstractmethod + def num_slots(self) -> int: ... 
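+    # Expected contract for subclasses (a sketch, not enforced here): slot_address(i), for i in
+    # range(num_slots), points at a contiguous region of slot_size bytes in the pool's backing
+    # storage (GPU virtual memory, host memory, or a file), and resize() may invalidate
+    # previously returned addresses.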
+ + @property + def num_bytes(self) -> int: + return self.slot_size * self.num_slots + + def __init__(self, slot_size: int) -> None: + self._slot_size = slot_size + + @abc.abstractmethod + def destroy(self) -> None: + pass + + @abc.abstractmethod + def resize(self, new_num_slots: int) -> None: + pass + + @abc.abstractmethod + def slot_address(self, slot: int) -> Address: + pass + + def __del__(self) -> None: + self.destroy() + + +@final +class GpuSlotPool(SlotPoolBase): + __slots__ = ("_vm",) + _vm: VirtMem + + def __init__( + self, + slot_size: int, + vm_size: int, + shared_phys_mem_pool: PooledPhysMemAllocator, + num_slots: int, + ): + super().__init__(slot_size) + assert vm_size % shared_phys_mem_pool.phys_mem_size == 0 + self._vm = VirtMem(vm_size, shared_phys_mem_pool) + self.resize(num_slots) + + @override + def destroy(self) -> None: + self._vm.destroy() + + @override + def resize(self, new_num_slots: int) -> None: + new_num_phys_mem = self._compute_num_phys_mem( + self.slot_size, new_num_slots, self._vm.phys_mem_size + ) + self._vm.realloc(self._vm.phys_mem_size * new_num_phys_mem) + + def extend_by_one_phys_mem(self) -> int: + self._vm.extend(1) + return self.num_slots + + @override + def slot_address(self, slot: int) -> MemAddress: + return MemAddress(int(self._vm.address) + self.slot_size * slot) + + @property + @override + def num_slots(self) -> int: + return self._compute_num_slots( + self.slot_size, self._vm.num_phys_mem, self._vm.phys_mem_size + ) + + @staticmethod + def _compute_num_phys_mem(slot_size: int, num_slots: int, phys_mem_size: int) -> int: + return div_up(num_slots * slot_size, phys_mem_size) + + @staticmethod + def _compute_num_slots(slot_size: int, num_phys_mem: int, phys_mem_size: int) -> int: + return num_phys_mem * phys_mem_size // slot_size + + +class HostSlotPool(SlotPoolBase): + __slots__ = ("_host_mem",) + _host_mem: HostMem + + def __init__(self, slot_size: int, num_slots: int) -> None: + super().__init__(slot_size) + self._host_mem = HostMem(self.aligned_size(num_slots)) + + @override + def destroy(self) -> None: + self._host_mem.destroy() + + @override + def resize(self, new_num_slots: int) -> None: + self._host_mem.resize(self.aligned_size(new_num_slots)) + + @override + def slot_address(self, slot: int) -> MemAddress: + return MemAddress(self._host_mem._address + self.slot_size * slot) + + @property + @override + def num_slots(self) -> int: + return self._host_mem.size // self.slot_size + + def aligned_size(self, num_slots: int) -> int: + return round_up(num_slots * self.slot_size, HostMem.ALIGNMENT) + + +class DiskSlotPool(SlotPoolBase): + __slots__ = ("_filename", "_fd") + # Currently only used to get the parent folder where we create temporary files. + # You won't find file with this name. 
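+    # O_TMPFILE creates an unnamed file on the folder's filesystem that is reclaimed
+    # automatically once the descriptor is closed; the mkstemp-and-unlink fallback in
+    # __init__ below achieves the same effect on filesystems without O_TMPFILE support.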
+ filename: str + _fd: FileDescriptor + + def __init__(self, filename: str, slot_size: int, num_slots: int) -> None: + super().__init__(slot_size) + self.filename = filename + folder = os.path.dirname(filename) + assert os.path.isdir(folder), f"Folder {folder} does not exist" + try: + fd = os.open(folder, os.O_TMPFILE | os.O_RDWR | os.O_EXCL, 0o664) + except OSError as e: + if e.errno != errno.EOPNOTSUPP: + raise + # Fallback for filesystems/architectures not supporting O_TMPFILE + fd, path = tempfile.mkstemp(dir=folder) + try: + os.unlink(path) + except OSError: + os.close(fd) + raise + self._fd = FileDescriptor(fd) + self.resize(num_slots) + + @override + def destroy(self) -> None: + if self.fd == BAD_FILE_DESCRIPTOR: + return + os.close(self.fd) + self._fd = BAD_FILE_DESCRIPTOR + + @property + def fd(self) -> FileDescriptor: + return self._fd + + @property + def file_size(self) -> int: + return os.lseek(self.fd, 0, os.SEEK_END) + + @override + def resize(self, new_num_slots: int) -> None: + file_size = new_num_slots * self.slot_size + resize_file(self.fd, file_size) + + @override + def slot_address(self, slot: int) -> DiskAddress: + assert slot < self.num_slots + return DiskAddress(self.fd, slot * self.slot_size) + + @property + @override + def num_slots(self) -> int: + return self.file_size // self.slot_size + + +@dataclass(slots=True) +class Slot: + # ready_event indicates whether the slot is ready for use. + # For newly allocated BlockData, it indicates finish of the last usage by the previous owners of the + # slot (who returned the slot to the pool). + # After data migration, it indicates finish of data migration. + # When passed to release(), it indicates finish of usage by the current owners of the slot. + _slot_id: SlotId | None + ready_event: CachedCudaEvent + + @property + def slot_id(self) -> SlotId: + assert self._slot_id is not None + return self._slot_id + + def query_ready(self) -> bool: + ev = self.ready_event + if ev is CachedCudaEvent.NULL: + return True + ret = ev.query_complete() + if ret: + self.ready_event = CachedCudaEvent.NULL + return ret + + @property + def has_valid_slot(self) -> bool: + return self._slot_id is not None + + def move_to_new_slot(self) -> "Slot": + ret = Slot(None, CachedCudaEvent.NULL) + ret.set_slot(self) + return ret + + def set_slot(self, slot: "Slot") -> None: + if self.has_valid_slot: + raise LogicError("Slot is already set.") + self._slot_id = slot.slot_id + self.ready_event = slot.ready_event + slot._slot_id = None + slot.ready_event = CachedCudaEvent.NULL + + def __del__(self) -> None: + if self.has_valid_slot: + warnings.warn("[KVCacheManager] slot is not freed before deletion") + + +class SlotAllocator: + __slots__ = ( + "_capacity", + "_num_active_slots", + "_recycled_slots", + "_num_ready_recycled_slots", + "_occupied_mask", + "_target_capacity", + "_overflow_slots", + "_num_ready_overflow_slots", + ) + _capacity: int + _num_active_slots: int # active slots are either in use or recycled. + _recycled_slots: deque[ + Slot + ] # only store recycled slots to avoid excessive memory usage on program start + _num_ready_recycled_slots: int # number of recycled slots that are ready to be used immediately + # (no need for sync or wait in stream), i.e. their ready events are triggered. + _occupied_mask: DynamicBitset + + # for scheduled shrinking resize + _target_capacity: ( + int # _target_capacity <= _capacity. Inequal if a shrinking resize is in progress. 
+ ) + _overflow_slots: list[ + Slot + ] # slots that will be out-of-range after a in-progress resize. scheduled for removal. + _num_ready_overflow_slots: int # similar to _num_ready_recycled_slots, but for _overflow_slots. + + def __init__(self, capacity: int) -> None: + self._capacity = capacity + self._num_active_slots = 0 + self._recycled_slots = deque[Slot]() + self._num_ready_recycled_slots = 0 + self._occupied_mask = DynamicBitset(capacity) + self._target_capacity = capacity + self._overflow_slots = [] + self._num_ready_overflow_slots = 0 + + def __del__(self) -> None: + assert_critical( + self._num_ready_recycled_slots == len(self._recycled_slots) + and self._num_ready_overflow_slots == len(self._overflow_slots), + "did you call synchronize()?", + ) + assert_critical( + self._target_capacity == self._capacity and not self._overflow_slots, + "resize is in progress", + ) + assert_critical(self._occupied_mask.num_set_bits == 0, "some slots are still in use") + assert_critical( + len(self._recycled_slots) == self._num_active_slots, "some slots are not free" + ) + + @property + def num_free_slots(self) -> int: + return len(self._recycled_slots) + max(self._target_capacity - self._num_active_slots, 0) + + @property + def num_occupied_slots(self) -> int: + return self._occupied_mask.num_set_bits + + def allocate(self) -> Slot: + if self.num_free_slots == 0: + raise OutOfPagesError("No free slots") + self._scrub_events() + # prefererence: ready recycled slots > new slots > recycled slots that are not ready + if self._num_ready_recycled_slots > 0: + assert self._recycled_slots + slot = self._recycled_slots.popleft() + assert slot.has_valid_slot + self._num_ready_recycled_slots -= 1 + assert slot.ready_event is CachedCudaEvent.NULL + elif self._num_active_slots < self.num_slots: + slot = Slot(SlotId(self._num_active_slots), CachedCudaEvent.NULL) + self._num_active_slots += 1 + else: + slot = self._recycled_slots.popleft() + assert slot.has_valid_slot + self._occupied_mask.set(slot.slot_id) + return slot + + # The reason why we don't use allocate() multiple times is that if what user need is all or none, + # and when we don't have enough free slots, we will free these newly allocated slots by appending + # them to the back of the recycled slot queue, which may impact perf. 
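+    # Illustrative behaviour: allocate_multiple(n) either returns exactly n slots or raises
+    # OutOfPagesError before any slot is taken, leaving the allocator state unchanged.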
+ def allocate_multiple(self, num_slots: int) -> list[Slot]: + if self.num_free_slots < num_slots: + raise OutOfPagesError("Not enough free slots") + return [self.allocate() for _ in range(num_slots)] + + def release(self, slot: Slot) -> None: + assert slot.has_valid_slot + slot = slot.move_to_new_slot() + if slot.slot_id >= self._capacity or not self._occupied_mask.get(slot.slot_id): + raise LogicError(f"Slot {slot.slot_id} is not occupied") + assert type(slot) is Slot and slot.has_valid_slot + if slot.slot_id < self._target_capacity: + self._recycled_slots.append(slot) + else: + self._overflow_slots.append(slot) + self._try_trigger_shrink() + self._occupied_mask.clear(slot.slot_id) + self._scrub_events() + assert NDEBUG or self._check() + + @property + def num_slots(self) -> int: + return self._capacity + + def resize(self, new_num_slots: int) -> None: + if self._target_capacity != self._capacity: + self.cancel_scheduled_resize() + assert NDEBUG or self._check() + old_num_slots = self.num_slots + if new_num_slots < self.num_slots and self._occupied_mask.any_set( + new_num_slots, self.num_slots + ): + raise ResourceBusyError("resize cannot remove occupied slots") + self._occupied_mask.resize(new_num_slots) + self._capacity = new_num_slots + self._num_active_slots = min(self._num_active_slots, new_num_slots) + if new_num_slots < old_num_slots: + new_recycled_slots = deque[Slot]() + new_num_ready_recycled_slots = 0 + for idx_recycled, slot in enumerate(self._recycled_slots): + assert type(slot) is Slot and slot.has_valid_slot + if slot.slot_id >= new_num_slots: + slot.ready_event.synchronize() + slot._slot_id = None + slot.ready_event = CachedCudaEvent.NULL + else: + new_recycled_slots.append(slot) + if idx_recycled < self._num_ready_recycled_slots: + new_num_ready_recycled_slots += 1 + self._recycled_slots = new_recycled_slots + self._num_ready_recycled_slots = new_num_ready_recycled_slots + self._scrub_events() + self._target_capacity = self._capacity + assert NDEBUG or self._check() + + def schedule_resize(self, new_num_slots: int) -> None: + assert NDEBUG or self._check() + if new_num_slots >= self.num_slots: + self.cancel_scheduled_resize() + self.resize(new_num_slots) + return + old_target_capacity = self._target_capacity + if new_num_slots > old_target_capacity: + self._recycled_slots.extend( + remove_if( + self._overflow_slots, + lambda slot: old_target_capacity <= slot.slot_id < new_num_slots, + ) + ) + self._num_ready_overflow_slots = 0 + if new_num_slots < old_target_capacity: + self._overflow_slots.extend( + remove_if(self._recycled_slots, lambda slot: slot.slot_id >= new_num_slots) + ) + self._num_ready_recycled_slots = 0 + self._target_capacity = new_num_slots + self._try_trigger_shrink() + self._scrub_events() + assert NDEBUG or self._check() + + def cancel_scheduled_resize(self) -> None: + assert NDEBUG or self._check() + self._target_capacity = self._capacity + self._recycled_slots.extend(remove_if(self._overflow_slots, lambda slot: True)) + self._num_ready_overflow_slots = 0 + + def shrink_in_progress(self) -> bool: + "Indicates if a scheduled shrink is in progress." 
+ assert self._target_capacity <= self._capacity + return self._target_capacity < self._capacity + + def get_slots_blocking_shrink(self) -> HomoTuple[SlotId]: + return tuple( + SlotId(id) + for id in range(self._target_capacity, self._capacity) + if self._occupied_mask.get(id) + ) + + def _try_trigger_shrink(self) -> bool: + assert NDEBUG or self._check() + if ( + self.shrink_in_progress() + and self._target_capacity + len(self._overflow_slots) == self._capacity + ): + assert len(set(s.slot_id for s in self._overflow_slots)) == len(self._overflow_slots) + for slot in self._overflow_slots: + slot.ready_event.synchronize() + slot.ready_event = CachedCudaEvent.NULL + self._overflow_slots.clear() + self._num_ready_overflow_slots = 0 + self._capacity = self._target_capacity + self._num_active_slots = min(self._num_active_slots, self._capacity) + self._scrub_events() + assert NDEBUG or self._check() + return True + return False + + def _scrub_events(self) -> None: + self._num_ready_recycled_slots = self._scrub_events_impl( + self._recycled_slots, self._num_ready_recycled_slots + ) + self._num_ready_overflow_slots = self._scrub_events_impl( + self._overflow_slots, self._num_ready_overflow_slots + ) + + def _check(self) -> bool: + return ( + self._num_active_slots <= self._capacity + and self._target_capacity <= self._capacity + and (self.shrink_in_progress() or len(self._overflow_slots) == 0) + and all( + self._target_capacity <= slot.slot_id < self._capacity + for slot in self._overflow_slots + ) + and len(self._recycled_slots) + len(self._overflow_slots) + self.num_occupied_slots + == self._num_active_slots + ) + + @staticmethod + def _scrub_events_impl(slots: Sequence[Slot], num_ready: int) -> int: + assert num_ready <= len(slots) + for i in range(num_ready, len(slots)): + slot = slots[i] + if slot.ready_event.query_complete(): + slot.ready_event = CachedCudaEvent.NULL + num_ready += 1 + else: + break + return num_ready + + def _synchronize(self) -> None: + "synchronize the events of all unused slots" + while self._num_ready_recycled_slots != len( + self._recycled_slots + ) or self._num_ready_overflow_slots != len(self._overflow_slots): + self._scrub_events() + + +class PoolGroupBase: + __slots__ = ("_slot_allocator", "_pools") + + _slot_allocator: SlotAllocator + _pools: HomoTuple[SlotPoolBase] + + def __init__(self, num_slots: int) -> None: + self._slot_allocator = SlotAllocator(num_slots) + + def __del__(self) -> None: + self.destroy() + + def destroy(self) -> None: + if self._slot_allocator._capacity == 0: + return + self._slot_allocator._synchronize() + for pool in self._pools: + pool.destroy() + self._slot_allocator.resize(0) + + @property + def num_pools(self) -> PoolIndex: + return PoolIndex(len(self._pools)) + + @property + def num_slots(self) -> int: + num_slots = self._slot_allocator._capacity + assert num_slots <= self._get_num_slots_from_pools() + return num_slots + + @property + def num_free_slots(self) -> int: + return self._slot_allocator.num_free_slots + + @property + def num_bytes(self) -> int: + return sum(pool.num_bytes for pool in self._pools) + + def resize_slot_allocator(self, new_num_slots: int | None) -> None: + """ + Resize the slot allocator, but not pools. If new_num_slots is None, make slot allocator match the pool sizes. 
+ """ + if new_num_slots is None: + new_num_slots = self._get_num_slots_from_pools() + self._slot_allocator.resize(new_num_slots) + assert NDEBUG or self._check(True) + + def resize_pools(self, new_num_slots: int | None) -> None: + """ + Resize the pools, but not the slot allocator. If new_num_slots is None, make pool sizes match + the slot allocator. + If exception is raised, size of pools may be imbalanced. Call resize_pools() again with None or + self._get_num_slots_from_pools() to fix it. + """ + if new_num_slots is None: + new_num_slots = self._slot_allocator.num_slots + for pool in self._pools: + pool.resize(new_num_slots) + assert NDEBUG or self._check(True) + + def allocate(self) -> Slot: + return self._slot_allocator.allocate() + + def allocate_multiple(self, num_slots: int) -> list[Slot]: + return self._slot_allocator.allocate_multiple(num_slots) + + def release(self, slot: Slot) -> None: + self._slot_allocator.release(slot) + + def slot_address(self, slot_id: SlotId) -> HomoTuple[Address]: + return tuple(pool.slot_address(slot_id) for pool in self._pools) + + @property + def slot_size(self) -> HomoTuple[int]: + return tuple(pool.slot_size for pool in self._pools) + + def _check(self, allow_mismatch: bool = False) -> bool: + pool_num_slots = self._get_num_slots_from_pools() + return ( + self._slot_allocator.num_slots <= pool_num_slots + if allow_mismatch + else self._slot_allocator.num_slots == pool_num_slots + ) + + def _get_num_slots_from_pools(self) -> int: + return min(p.num_slots for p in self._pools) + + @staticmethod + def _compute_num_phys_mem( + slot_size_list: Sequence[int], num_slots: int, phys_mem_size: int + ) -> HomoTuple[int]: + return tuple( + GpuSlotPool._compute_num_phys_mem(slot_size, num_slots, phys_mem_size) + for slot_size in slot_size_list + ) + + +class GpuPoolGroup(PoolGroupBase): + __slots__ = () + + def __init__( + self, + num_slots: int, + slot_size_list: Sequence[int], + shared_phys_mem_pool: PooledPhysMemAllocator, + ): + super().__init__(num_slots) + total_gpu_memory = query_total_gpu_memory() + max_slot_size = max(slot_size_list) + phys_mem_size = shared_phys_mem_pool.phys_mem_size + self._pools = tuple( + GpuSlotPool( + slot_size, + round_down(int(total_gpu_memory * slot_size / max_slot_size), phys_mem_size), + shared_phys_mem_pool, + num_slots, + ) + for slot_size in slot_size_list + ) + + +class HostPoolGroup(PoolGroupBase): + __slots__ = () + + def __init__(self, num_slots: int, slot_size_list: Sequence[int]): + super().__init__(num_slots) + self._pools = tuple(HostSlotPool(slot_size, num_slots) for slot_size in slot_size_list) + + +class DiskPoolGroup(PoolGroupBase): + __slots__ = () + + def __init__(self, num_slots: int, slot_size_list: Sequence[int], filename_template: str): + super().__init__(num_slots) + self._pools = tuple( + DiskSlotPool(filename_template.format(i), slot_size, num_slots) + for i, slot_size in enumerate(slot_size_list) + ) + + +class CacheLevelStorage: + TIER: ClassVar[CacheTier] + __slots__ = ("_total_quota", "_ratio_list", "_pool_groups") + _total_quota: int # fixme: remove _total_quota and _ratio_list and compute from _pool_groups + _ratio_list: HomoTuple[float] + _pool_groups: HomoTuple[PoolGroupBase] + + def __init__(self, total_quota: int, ratio_list: Sequence[float]) -> None: + if not hasattr(self.__class__, "TIER"): + raise ValueError(f"{self.__class__.__name__} must define 'TIER' as a class variable") + self._total_quota = total_quota + self._ratio_list = tuple(ratio_list) + + def __del__(self) -> None: + 
self.destroy() + + @property + def cache_tier(self) -> CacheTier: + return self.TIER + + def destroy(self) -> None: + if self._total_quota == 0: + return + for pg in self._pool_groups: + pg.destroy() + self._total_quota = 0 + self._ratio_list = () + + def allocate(self, pool_group_index: PoolGroupIndex) -> Slot: + return self._pool_groups[pool_group_index].allocate() + + def allocate_multiple(self, pool_group_index: PoolGroupIndex, num_slots: int) -> list[Slot]: + return self._pool_groups[pool_group_index].allocate_multiple(num_slots) + + def release(self, pool_group_index: PoolGroupIndex, slot: Slot) -> None: + self._pool_groups[pool_group_index].release(slot) + + @property + def total_quota(self) -> int: + return self._total_quota + + @property + def ratio_list(self) -> HomoTuple[float]: + return self._ratio_list + + def num_slots(self, pool_group_index: PoolGroupIndex) -> int: + return self._pool_groups[pool_group_index].num_slots + + def get_num_free_slots(self, pool_group_index: PoolGroupIndex) -> int: + return self._pool_groups[pool_group_index].num_free_slots + + @property + def slot_count_list(self) -> HomoTuple[int]: + """ + The number of slots in each pool group. + """ + return tuple(pg.num_slots for pg in self._pool_groups) + + def slot_size(self, pool_group_index: PoolGroupIndex) -> HomoTuple[int]: + """ + The slot sizes of each pool in the pool group. + """ + return self._pool_groups[pool_group_index].slot_size + + @property + def slot_size_lists(self) -> HomoTuple[HomoTuple[int]]: + """ + A tuple of tuples, each containing the slot sizes for a pool group. + """ + return tuple(tuple(p.slot_size for p in pg._pools) for pg in self._pool_groups) + + @property + def num_pool_groups(self) -> PoolGroupIndex: + return PoolGroupIndex(len(self._pool_groups)) + + def slot_address( + self, pool_group_index: PoolGroupIndex, pool_index: PoolIndex, slot_id: SlotId + ) -> Address: + return self._pool(pool_group_index, pool_index).slot_address(slot_id) + + def resize( + self, new_total_quota: int | None = None, new_ratio_list: Sequence[float] | None = None + ) -> None: + new_slot_count_list = self._compute_slot_count_list(new_total_quota, new_ratio_list) + self._resize_impl(new_slot_count_list) + if new_total_quota is not None: + self._total_quota = new_total_quota + if new_ratio_list is not None: + self._ratio_list = tuple(new_ratio_list) + + def _resize_impl(self, new_slot_count_list: Sequence[int]) -> None: + old_slot_count_list = self.slot_count_list + assert old_slot_count_list == self._compute_slot_count_list( + self.total_quota, self.ratio_list + ) + try: + # shrink first to avoid intermediate state with excessive memory usage + for pg, new_slot_count, old_slot_count in zip( + self._pool_groups, new_slot_count_list, old_slot_count_list + ): + if new_slot_count < old_slot_count: + pg.resize_slot_allocator( + new_slot_count + ) # shrink slot allocators first as it can fail for shrinking + pg.resize_pools(new_slot_count) + for pg, new_slot_count, old_slot_count in zip( + self._pool_groups, new_slot_count_list, old_slot_count_list + ): + if new_slot_count > old_slot_count: + pg.resize_pools( + new_slot_count + ) # expand pools first as it can fail for expanding + pg.resize_slot_allocator(new_slot_count) + except Exception: + self._resize_impl(old_slot_count_list) + raise + + def _pool(self, pool_group_index: PoolGroupIndex, pool_index: PoolIndex) -> SlotPoolBase: + return self._pool_groups[pool_group_index]._pools[pool_index] + + # Calculate how many slots will there be in each pool group 
with the given total_quota and + # ratio_list. Use _ratio_to_slot_count_list for initialization. + def _compute_slot_count_list( + self, total_quota: int | None = None, ratio_list: Sequence[float] | None = None + ) -> HomoTuple[int]: + if total_quota is None: + total_quota = self.total_quota + if ratio_list is None: + ratio_list = self.ratio_list + assert len(ratio_list) == len(self._pool_groups), ( + f"Wrong ratio_list length. Expected {len(self._pool_groups)}, got {len(ratio_list)}" + ) + return self._ratio_to_slot_count_list( + total_quota, self.slot_size_lists, ratio_list, self.pool_size_granularity + ) + + @staticmethod + def _ratio_to_slot_count_list( + total_quota: int, + slot_size_lists: Sequence[Sequence[int]], + ratio_list: Sequence[float], + pool_size_granularity: int, + ) -> HomoTuple[int]: + num_pool_groups = len(ratio_list) + assert num_pool_groups == len(slot_size_lists) + assert total_quota % pool_size_granularity == 0 + total_grains = total_quota // pool_size_granularity + assert total_grains >= sum(len(sizes) for sizes in slot_size_lists) + remaining_grains = total_grains + granularity = pool_size_granularity + slot_cnt_list = [0] * num_pool_groups + # divide total_quota into pool groups based on init_ratio, then divide quote for each pool_group + # into pools based on slot_size. + pg_idx_lst = sorted(range(len(ratio_list)), key=lambda i: ratio_list[i]) + for i, pg in enumerate(pg_idx_lst): + slot_size_list = slot_size_lists[pg] + min_pool_grains = [div_up(s, granularity) for s in slot_size_list] + pct: float = ratio_list[pg] / sum(ratio_list[j] for j in pg_idx_lst[i:]) + pg_grains = max(round(remaining_grains * pct), sum(min_pool_grains)) + num_slots: int = 1 << 63 + remaining_pg_grains = pg_grains + pool_idx_lst = sorted(range(len(slot_size_list)), key=lambda i: slot_size_list[i]) + for j, pool in enumerate(pool_idx_lst): + slot_size = slot_size_list[pool] + pool_grains = max( + min_pool_grains[pool], + round( + remaining_pg_grains + * (slot_size / sum(slot_size_list[k] for k in pool_idx_lst[j:])) + ), + ) + num_slots = min(num_slots, pool_grains * granularity // slot_size) + remaining_pg_grains -= pool_grains + assert remaining_pg_grains == 0 + assert num_slots > 0 + slot_cnt_list[pg] = num_slots + remaining_grains -= pg_grains + assert remaining_grains == 0 + return tuple(slot_cnt_list) + + @property + def pool_size_granularity(self) -> int: + return 2 << 20 + + +class GpuCacheLevelStorage(CacheLevelStorage): + TIER: ClassVar[CacheTier] = CacheTier.GPU_MEM + __slots__ = ("shared_phys_mem_pool",) + shared_phys_mem_pool: PooledPhysMemAllocator + + def __init__( + self, + total_quota: int, + slot_size_lists: Sequence[Sequence[int]], + init_ratio: Sequence[float], + phys_mem_size: int, + ): + assert len(slot_size_lists) == len(init_ratio), ( + "slot_size_lists and init_ratio must have the same length" + ) + super().__init__(total_quota, init_ratio) + slot_count_list = self._ratio_to_slot_count_list( + total_quota, slot_size_lists, init_ratio, phys_mem_size + ) + self.shared_phys_mem_pool = PooledPhysMemAllocator(phys_mem_size) + self._pool_groups = tuple( + GpuPoolGroup(num_slots, slot_size_list, self.shared_phys_mem_pool) + for slot_size_list, num_slots in zip(slot_size_lists, slot_count_list) + ) + + @override + def resize( + self, new_total_quota: int | None = None, new_ratio_list: Sequence[float] | None = None + ): + super().resize(new_total_quota, new_ratio_list) + self.shared_phys_mem_pool.clear() # clear cached unused phys mem + + @property + def 
pool_size_granularity(self) -> int: + return self.shared_phys_mem_pool.phys_mem_size + + +class HostCacheLevelStorage(CacheLevelStorage): + TIER: ClassVar[CacheTier] = CacheTier.HOST_MEM + POOL_SIZE_GRANULARITY: ClassVar[int] = HostMem.ALIGNMENT + __slots__ = () + + def __init__( + self, + total_quota: int, + slot_size_lists: Sequence[Sequence[int]], + init_ratio: Sequence[float], + ): + super().__init__(total_quota, init_ratio) + slot_count_list = self._ratio_to_slot_count_list( + total_quota, slot_size_lists, init_ratio, self.pool_size_granularity + ) + self._pool_groups = tuple( + HostPoolGroup(num_slots, slot_size_list) + for slot_size_list, num_slots in zip(slot_size_lists, slot_count_list) + ) + + @property + def pool_size_granularity(self) -> int: + return self.POOL_SIZE_GRANULARITY + + +class DiskCacheLevelStorage(CacheLevelStorage): + __slots__ = () + TIER: ClassVar[CacheTier] = CacheTier.DISK + POOL_SIZE_GRANULARITY: ClassVar[int] = 2 << 20 + + def __init__( + self, + total_quota: int, + slot_size_lists: Sequence[Sequence[int]], + init_ratio: Sequence[float], + filename_template: str, + ): + super().__init__(total_quota, init_ratio) + slot_count_list = self._ratio_to_slot_count_list( + total_quota, slot_size_lists, init_ratio, self.pool_size_granularity + ) + self._pool_groups = tuple( + DiskPoolGroup(num_slots, slot_size_list, filename_template.format(pg_idx, "{}")) + for pg_idx, (slot_size_list, num_slots) in enumerate( + zip(slot_size_lists, slot_count_list) + ) + ) + + @property + def pool_size_granularity(self) -> int: + return self.POOL_SIZE_GRANULARITY diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_storage_manager.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage_manager.py new file mode 100644 index 0000000000..644f2bcbd2 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_storage_manager.py @@ -0,0 +1,518 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Iterator, Sequence, cast + +from . 
import rawref +from ._common import ( + GPU_LEVEL, + NDEBUG, + Address, + CacheLevel, + CacheTier, + LayerId, + MemAddress, + PageIndex, + PageStatus, +) +from ._config import CacheTierConfig, DataRole, DiskCacheTierConfig +from ._copy_engine import CopyTask, batched_copy +from ._eviction_controller import EvictablePage, PerLevelEvictionController +from ._exceptions import OutOfPagesError +from ._life_cycle_registry import LifeCycleId, LifeCycleRegistry +from ._page import Page +from ._storage import CacheLevelStorage +from ._storage._config import BufferAttr, BufferId, StorageConfig +from ._storage._core import ( + DiskCacheLevelStorage, + GpuCacheLevelStorage, + HostCacheLevelStorage, + PoolGroupBase, + PoolGroupIndex, + PoolIndex, + Slot, + SlotId, +) +from ._utils import ( + Array2D, + CachedCudaEvent, + HomoTuple, + TemporaryCudaStream, + TypedIndexList, + filled_array2d, + filled_list, + get_uniform_attribute, + make_typed, + map_optional, + partition, + remove_if, + round_up, + typed_enumerate, + typed_range, +) + + +class CacheLevelManager: + __slots__ = ("cache_level", "storage", "controller") + cache_level: CacheLevel + storage: CacheLevelStorage + controller: PerLevelEvictionController + + @property + def cache_tier(self) -> CacheTier: + return self.storage.cache_tier + + def __init__( + self, + life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex], + cache_level: CacheLevel, + config: CacheTierConfig, + slot_size_lists: Sequence[Sequence[int]], + init_ratio: Sequence[float], + ): + self.cache_level = cache_level + self.storage = self._create_cache_level_storage(config, slot_size_lists, init_ratio) + self.controller = PerLevelEvictionController(life_cycle_grouping, cache_level) + + @property + def num_pool_groups(self) -> PoolGroupIndex: + assert self.storage.num_pool_groups == self.controller.num_pool_groups + return self.storage.num_pool_groups + + @staticmethod + def _create_cache_level_storage( + config: CacheTierConfig, + slot_size_lists: Sequence[Sequence[int]], + init_ratio: Sequence[float], + ) -> CacheLevelStorage: + quota = config.quota + num_pools = sum(len(sizes) for sizes in slot_size_lists) + + def adjust_quota(quota: int, granularity: int) -> int: + return max(granularity * num_pools, round_up(quota, granularity)) + + if config.tier == CacheTier.GPU_MEM: + page_size = 2 << 20 + phys_mem_size = page_size << min(4, max(0, int(math.log(quota / (page_size * 512), 2)))) + quota = adjust_quota(quota, phys_mem_size) + return GpuCacheLevelStorage(quota, slot_size_lists, init_ratio, phys_mem_size) + elif config.tier == CacheTier.HOST_MEM: + quota = adjust_quota(quota, HostCacheLevelStorage.POOL_SIZE_GRANULARITY) + return HostCacheLevelStorage(quota, slot_size_lists, init_ratio) + elif config.tier == CacheTier.DISK: + assert isinstance(config, DiskCacheTierConfig) + assert os.path.isdir(config.path), ( + f"Disk path {config.path} does not exist or is not a directory" + ) + quota = adjust_quota(quota, DiskCacheLevelStorage.POOL_SIZE_GRANULARITY) + filename_template = os.path.join(config.path, "g{}p{}.bin") + return DiskCacheLevelStorage(quota, slot_size_lists, init_ratio, filename_template) + else: + raise ValueError(f"Invalid cache tier: {config.tier}") + + +@dataclass(slots=True, frozen=True) +class StorageStatistics: + "All in number of slots, for one pool group" + + slot_size: HomoTuple[int] + total: int + free: int + evictable: int + + @property + def available(self) -> int: + return self.free + self.evictable + + @property + def unavailable(self) -> int: + 
return self.total - self.available + + +class StorageManager: + __slots__ = ( + "_life_cycles", + "_layer_to_life_cycle_ids", + "_slot_to_page_indices", + "_buffer_attr", + "_life_cycle_grouping", + "_levels", + "_cached_num_pool_groups", + "__rawref__", + ) + _life_cycles: LifeCycleRegistry + _layer_to_life_cycle_ids: TypedIndexList[LayerId, LifeCycleId] + _slot_to_page_indices: TypedIndexList[LifeCycleId, int] + _buffer_attr: dict[BufferId, BufferAttr] + _life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex] + _levels: TypedIndexList[CacheLevel, CacheLevelManager] + _cached_num_pool_groups: PoolGroupIndex + __rawref__: rawref.ref["StorageManager"] + + def __init__(self, life_cycles: LifeCycleRegistry, config: StorageConfig) -> None: + self.__rawref__ = rawref.NULL + assert config.cache_tiers[GPU_LEVEL].tier == CacheTier.GPU_MEM, ( + "The first cache tier must be GPU memory" + ) + self._life_cycles = life_cycles + self._layer_to_life_cycle_ids = config.layer_to_life_cycle_ids() + self._slot_to_page_indices = config.slot_to_page_indices() + self._buffer_attr = config.buffer_attributes() + self._life_cycle_grouping = config.life_cycle_grouping() + slot_size_lists = [pg.slot_size_list for pg in config.pool_groups] + # @TODO: accept an optional avg_seq_len param and consider sliding window. + init_ratio = [float(sum(pg.slot_size_list) * len(pg.slots)) for pg in config.pool_groups] + total = sum(init_ratio) + init_ratio = [x / total for x in init_ratio] + num_levels = CacheLevel(len(config.cache_tiers)) + self._levels = cast( + TypedIndexList, + [ + CacheLevelManager( + self._life_cycle_grouping, i, config.cache_tiers[i], slot_size_lists, init_ratio + ) + for i in typed_range(num_levels) + ], + ) + self._cached_num_pool_groups = get_uniform_attribute( + self._levels, lambda level: level.storage.num_pool_groups + ) + + def __del__(self) -> None: + self.__rawref__.invalidate() + + def get_pool_group_index(self, life_cycle: LifeCycleId) -> PoolGroupIndex: + return self._life_cycle_grouping[life_cycle] + + def new_gpu_slots( + self, num_slots: TypedIndexList[LifeCycleId, int] + ) -> TypedIndexList[LifeCycleId, list[Slot]]: + return self.new_slots(GPU_LEVEL, num_slots) + + def new_slots( + self, level: CacheLevel, num_slots: TypedIndexList[LifeCycleId, int] + ) -> TypedIndexList[LifeCycleId, list[Slot]]: + pg_num_slots = filled_list(0, self.num_pool_groups) + for lc in typed_range(self.num_life_cycles): + pg_num_slots[self.get_pool_group_index(lc)] += num_slots[lc] + storage = self._levels[level].storage + if any( + pg_num_slots[pg] > storage.get_num_free_slots(pg) + for pg in typed_range(self.num_pool_groups) + ): + self.prepare_free_slots(level, pg_num_slots) + assert all( + pg_num_slots[pg] <= storage.get_num_free_slots(pg) + for pg in typed_range(self.num_pool_groups) + ) + ret = filled_list(list[Slot](), self.num_life_cycles) + try: + for life_cycle in typed_range(self.num_life_cycles): + pg_idx = self.get_pool_group_index(life_cycle) + ret[life_cycle] = storage.allocate_multiple(pg_idx, num_slots[life_cycle]) + except Exception: + warnings.warn("Exception not expected here. 
Please report a bug.") + for lc, slots in typed_enumerate(ret): + pg_idx = self.get_pool_group_index(lc) + for s in slots: + storage.release(pg_idx, s) + raise + return ret + + @property + def life_cycles(self) -> LifeCycleRegistry: + return self._life_cycles + + @property + def num_life_cycles(self) -> LifeCycleId: + return LifeCycleId(len(self._life_cycle_grouping)) + + @property + def num_pool_groups(self) -> PoolGroupIndex: + return self._cached_num_pool_groups + + @property + def num_cache_levels(self) -> CacheLevel: + return CacheLevel(len(self._levels)) + + def is_last_level(self, level: CacheLevel) -> bool: + return level == self.num_cache_levels - 1 + + @property + def cache_tiers(self) -> HomoTuple[CacheTier]: + return tuple(cache_level.cache_tier for cache_level in self._levels) + + def is_evictable(self, page: EvictablePage, level: CacheLevel | None = None) -> bool: + """ + Check if a page is evictable. If level is specified, check if the page will be evictable after + migrating to the given level. + """ + status = page.status + level = page.cache_level if level is None else level + # droppable pages that are not committed should be dropped immediately. + # held pages in last level cache can't be evicted. + return (status == PageStatus.DROPPABLE and page.is_committed()) or ( + status == PageStatus.HELD and level < self.num_cache_levels - 1 + ) + + def prepare_free_slots( + self, level: CacheLevel, requirements: TypedIndexList[PoolGroupIndex, int] + ) -> None: + goals = filled_array2d(self.num_cache_levels, self.num_pool_groups, 0) + for pg in typed_range(self.num_pool_groups): + goals[level, pg] = requirements[pg] + fallen_pages = make_typed(lambda: list[Page](), self.num_pool_groups) + self._prepare_free_slots(goals, level, fallen_pages) + + def _prepare_free_slots( + self, + goals: Array2D[CacheLevel, PoolGroupIndex, int], + lvl_id: CacheLevel, + fallen_pages: TypedIndexList[PoolGroupIndex, list[Page]], + ) -> None: + assert NDEBUG or goals.rows == self.num_cache_levels and goals.cols == self.num_pool_groups + assert NDEBUG or all( + all(p.cache_level < lvl_id for p in pages) for pages in fallen_pages + ), "Fallen pages must come from upper cache levels" + storage = self._levels[lvl_id].storage + ctrl = self._levels[lvl_id].controller + num_to_evict = filled_list(0, self.num_pool_groups) + held_pages = make_typed(lambda: list[Page](), self.num_pool_groups) + for pg_idx in typed_range(self.num_pool_groups): + goal = goals[lvl_id, pg_idx] + fallen = len(fallen_pages[pg_idx]) + old_free_cnt = storage.get_num_free_slots(pg_idx) + evictable_cnt = ctrl.num_evictable_pages(pg_idx) + num_to_evict[pg_idx] = max(0, min(goal + fallen - old_free_cnt, evictable_cnt)) + fallen_held_cnt = 0 # fallen held pages we must accept in the current level. + if self.is_last_level(lvl_id): + held_pages[pg_idx] = remove_if( + fallen_pages[pg_idx], lambda p: p.status == PageStatus.HELD + ) + fallen_held_cnt = len(held_pages[pg_idx]) + if fallen_held_cnt > old_free_cnt + evictable_cnt: + # Do we need to revert the eviction we did before? Maybe not. 
+ raise OutOfPagesError( + "Too many held pages are being evicted to the last-level cache for group {pg_idx}" + ) + if old_free_cnt + evictable_cnt - fallen_held_cnt < goal: + raise OutOfPagesError( + "Impossible to meet the goal ({goal} free slots) for group {pg_idx}" + ) + evicted = ctrl.evict(num_to_evict) + accepted_pages = make_typed(lambda: list[Page](), self.num_pool_groups) + is_last_level = self.is_last_level(lvl_id) + if is_last_level: + for pg_idx in typed_range(self.num_pool_groups): + old_free_cnt = storage.get_num_free_slots(pg_idx) + num_evicted = len(evicted[pg_idx]) + assert NDEBUG or all(p.status == PageStatus.DROPPABLE for p in evicted[pg_idx]) + if not NDEBUG: + dbg_rawrefs = [rawref.ref(p) for p in evicted[pg_idx]] + evicted[pg_idx].clear() + if not NDEBUG: + assert all(p() is None for p in dbg_rawrefs) # pyright: ignore + new_free_cnt = storage.get_num_free_slots(pg_idx) + # GC of some pages may trigger removal of radix tree blocks and some other pages. + assert new_free_cnt >= num_evicted + old_free_cnt + assert len(held_pages[pg_idx]) <= new_free_cnt + fallen_pages[pg_idx].extend(held_pages[pg_idx]) + held_pages[pg_idx].clear() + goal = goals[lvl_id, pg_idx] + num_accepted = min(new_free_cnt - goal, len(fallen_pages[pg_idx])) + assert num_accepted >= 0 + accepted_pages[pg_idx] = ( + fallen_pages[pg_idx][-num_accepted:] if num_accepted > 0 else [] + ) + fallen_pages[pg_idx].clear() + else: + assert all(len(g) == 0 for g in held_pages) + for pg_idx in typed_range(self.num_pool_groups): + old_free_cnt = storage.get_num_free_slots(pg_idx) + e = evicted[pg_idx] + num_evicted = len(e) + fallen_pages[pg_idx][:0] = cast(list[Page], e) + e.clear() + num_accepted = min( + old_free_cnt + num_evicted - goals[lvl_id, pg_idx], len(fallen_pages[pg_idx]) + ) + assert num_accepted >= 0 + if num_accepted > 0: + accepted_pages[pg_idx] = fallen_pages[pg_idx][-num_accepted:] + del fallen_pages[pg_idx][-num_accepted:] + self._prepare_free_slots(goals, CacheLevel(lvl_id + 1), fallen_pages) + assert all(len(f) == 0 for f in fallen_pages) + # migrate pages + for pg_idx in typed_range(self.num_pool_groups): + partitioned = partition( + accepted_pages[pg_idx], + lambda p: (p.cache_level, self.get_pool_group_index(p.life_cycle)), + ) + accepted_pages[pg_idx].clear() + for (src_lvl, pg_idx), pages in partitioned.items(): + dst_lvl = lvl_id + self._batched_migrate(pg_idx, dst_lvl, src_lvl, pages, update_src=True) + for p in pages: + if is_last_level and p.status == PageStatus.HELD: + continue + self._levels[dst_lvl].controller.schedule_for_eviction(p) + return + + def _batched_migrate( + self, + pool_group_index: PoolGroupIndex, + dst_level: CacheLevel, + src_level: CacheLevel, + src_pages: Sequence[Page], + update_src: bool, + ) -> Sequence[Slot] | None: + "Free slots must be prepared before calling this function." 
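+        # Rough flow (descriptive): allocate destination slots, enqueue one batched copy per pool
+        # on a temporary CUDA stream, then record the stream's finish event on every slot
+        # involved; with update_src=True each source page is retargeted to its new slot and the
+        # old slot is released. On failure, the freshly allocated destination slots are released.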
+ assert dst_level != src_level, "dst_level and src_level must be different" + num_slots = len(src_pages) + num_pools = self.num_pools(pool_group_index) + src_pool_group = self._pool_group(src_level, pool_group_index) + dst_pool_group = self._pool_group(dst_level, pool_group_index) + if dst_pool_group.num_free_slots < num_slots: + raise OutOfPagesError("Not enough free slots") + dst_slots = dst_pool_group.allocate_multiple(num_slots) + try: + assert len(dst_slots) == num_slots + prior_events: set[CachedCudaEvent] = set() + tasks_per_pool: list[list[CopyTask]] = [[]] * num_pools + for src, dst in zip(src_pages, dst_slots): + assert src.node_ref is None + prior_events.update((dst.ready_event, src.ready_event)) + dst_addresses = dst_pool_group.slot_address(dst.slot_id) + src_addresses = src_pool_group.slot_address(src.slot_id) + for pool_idx in range(num_pools): + tasks_per_pool[pool_idx].append( + CopyTask(dst_addresses[pool_idx], src_addresses[pool_idx]) + ) + dst_tier = self._levels[dst_level].cache_tier + src_tier = self._levels[src_level].cache_tier + with TemporaryCudaStream(prior_events) as stream: + slot_sizes = self.slot_size(pool_group_index) + for pool_idx, tasks in enumerate(tasks_per_pool): + batched_copy(dst_tier, src_tier, slot_sizes[pool_idx], tasks, stream.get()) + finish_event = stream.take_finish_event() + for src, dst in zip(src_pages, dst_slots): + dst.ready_event = finish_event + src.ready_event = ( + finish_event # compulsory for the next owner getting this slot from the pool. + ) + if update_src: + scheduled_for_eviction = src.scheduled_for_eviction + if scheduled_for_eviction: + self.exclude_from_eviction(src) + src_pool_group.release(src) + src.set_slot(dst) + src.cache_level = dst_level + if scheduled_for_eviction: + self.schedule_for_eviction(src) + return None if update_src else dst_slots + except Exception: + for s in dst_slots: + dst_pool_group.release(s) + raise + + def _pool_group( + self, cache_level: CacheLevel, pool_group_index: PoolGroupIndex + ) -> PoolGroupBase: + return self._levels[cache_level].storage._pool_groups[pool_group_index] + + def num_pools(self, pool_group_index: PoolGroupIndex) -> PoolIndex: + return get_uniform_attribute( + self._levels, lambda level: level.storage._pool_groups[pool_group_index].num_pools + ) + + def slot_size(self, pool_group_index: PoolGroupIndex) -> HomoTuple[int]: + return get_uniform_attribute( + self._levels, lambda level: level.storage.slot_size(pool_group_index) + ) + + def num_slots( + self, pool_group_index: PoolGroupIndex, cache_level: CacheLevel = GPU_LEVEL + ) -> int: + return self._levels[cache_level].storage.num_slots(pool_group_index) + + def release_slot(self, life_cycle: LifeCycleId, cache_level: CacheLevel, slot: Slot) -> None: + pg_idx = self.get_pool_group_index(life_cycle) + self._levels[cache_level].storage.release(pg_idx, slot) + + def schedule_for_eviction(self, page: EvictablePage) -> None: + if self.is_evictable(page): + self._levels[page.cache_level].controller.schedule_for_eviction(page) + + def exclude_from_eviction(self, page: EvictablePage) -> None: + assert page.node_ref is not None + self._levels[page.cache_level].controller.remove(page.node_ref) + + def get_mem_pool_base_address(self, layer_id: LayerId, data_role: DataRole) -> MemAddress: + storage = self._levels[GPU_LEVEL].storage + attr = self.get_buffer_attr(layer_id, data_role) + pg_idx = self.get_pool_group_index(attr.life_cycle_id) + return MemAddress( + cast(int, storage.slot_address(pg_idx, attr.pool_index, SlotId(0))) + 
attr.offset + ) + + def get_page_indices_ref( + self, lc_id: LifeCycleId, pages: Iterator[Page | None] + ) -> Iterator[int | None]: + "Reference implementation. Not fast enough for production." + scale = self._slot_to_page_indices[lc_id] + return (map_optional(page, lambda p: scale * int(p.slot_id)) for page in pages) + + def get_buffer_attr(self, layer_id: LayerId, data_role: DataRole) -> BufferAttr: + return self._buffer_attr[BufferId(layer_id, data_role)] + + def slot_address( + self, level: CacheLevel, pg_idx: PoolGroupIndex, slot_id: SlotId, pool_idx: PoolIndex + ) -> Address: + return self._levels[level].storage.slot_address(pg_idx, pool_idx, slot_id) + + def get_page_indices_for_slot(self, life_cycle: LifeCycleId, slot_id: SlotId) -> PageIndex: + scale = self._slot_to_page_indices[life_cycle] + return PageIndex(scale * slot_id) + + def get_statistics( + self, level: CacheLevel = GPU_LEVEL + ) -> TypedIndexList[PoolGroupIndex, StorageStatistics]: + ret = make_typed(lambda: StorageStatistics((), 0, 0, 0), self.num_pool_groups) + for pg_idx in typed_range(self.num_pool_groups): + pg = self._pool_group(level, pg_idx) + evictable_cnt = self._levels[level].controller.num_evictable_pages(pg_idx) + ret[pg_idx] = StorageStatistics( + pg.slot_size, pg.num_slots, pg.num_free_slots, evictable_cnt + ) + return ret + + def get_utilization( + self, level: CacheLevel = GPU_LEVEL + ) -> TypedIndexList[PoolGroupIndex, float]: + ret = make_typed(lambda: 0.0, self.num_pool_groups) + stats = self.get_statistics(level) + for pg_idx in typed_range(self.num_pool_groups): + ret[pg_idx] = stats[pg_idx].unavailable / stats[pg_idx].total + return ret + + def get_overall_utilization(self, level: CacheLevel = GPU_LEVEL) -> float: + stats = self.get_statistics(level) + return sum(sum(s.slot_size) * s.unavailable for s in stats) / sum( + sum(s.slot_size) * s.total for s in stats + ) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/_utils.py b/tensorrt_llm/runtime/kv_cache_manager_v2/_utils.py new file mode 100644 index 0000000000..c7ead1d715 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/_utils.py @@ -0,0 +1,945 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import array +import ctypes +import errno +import functools +import itertools +import operator +import os +import platform +import traceback +import warnings +import weakref +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from collections.abc import Set +from ctypes.util import find_library +from itertools import pairwise +from typing import ( + Any, + Callable, + ClassVar, + Final, + Generic, + Iterable, + Iterator, + MutableSequence, + Protocol, + Reversible, + Sequence, + Type, + TypeVar, + cast, +) + +import cuda.bindings.driver as drv +import cuda.bindings.runtime as cudart + +from . 
import rawref +from ._common import NDEBUG, CudaStream +from ._exceptions import CuError, CuOOMError, DiskOOMError, HostOOMError + +T = TypeVar("T") +U = TypeVar("U") +Index = TypeVar("Index", bound=int, contravariant=True) +IndexO = TypeVar("IndexO", bound=int, covariant=True) +Row = TypeVar("Row", bound=int) +Col = TypeVar("Col", bound=int) + + +def _unwrap( + ret: drv.CUresult + | tuple[ + drv.CUresult, + T, + ] + | tuple[drv.CUresult, T, U], +): + if isinstance(ret, drv.CUresult): + if int(ret) != int(drv.CUresult.CUDA_SUCCESS): # pyright: ignore + if int(ret) == int(drv.CUresult.CUDA_ERROR_OUT_OF_MEMORY): # pyright: ignore + raise CuOOMError() + raise CuError(ret) + else: + _unwrap(ret[0]) + return ret[1] if len(ret) == 2 else ret[1:] + + +def div_up(x: int, y: int) -> int: + return (x + y - 1) // y + + +def round_up(x: int, y: int) -> int: + return div_up(x, y) * y + + +def round_down(x: int, y: int) -> int: + return x // y * y + + +def in_range(x: int, lower: int, upper: int) -> bool: + return lower <= x < upper + + +def exact_div(x: int, y: int) -> int: + assert x % y == 0 + return x // y + + +def overlap(a: tuple[Index, Index], b: tuple[Index, Index]) -> tuple[Index, Index] | tuple[()]: + "Returns the overlap of two ranges, or an empty tuple if they do not overlap." + return (max(a[0], b[0]), min(a[1], b[1])) if a[0] < b[1] and b[0] < a[1] else () + + +def value_or(opt: T | None, default: T) -> T: + return default if opt is None else opt + + +def unwrap_optional(value: T | None) -> T: + if value is not None: + return value + raise ValueError("Expected non-None value") + + +def unwrap_weakref(ref: weakref.ref[T]) -> T: + obj = ref() + if obj is not None: + return obj + raise ValueError("Dereferencing a dangling weakref") + + +def unwrap_rawref(ref: rawref.ref[T]) -> T: + obj = ref() + if obj is not None: + return obj + raise ValueError("Dereferencing a dangling rawref") + + +def map_optional(value: T | None, func: Callable[[T], U]) -> U | None: + return func(value) if value is not None else None + + +def remove_if(original: MutableSequence[T], predicate: Callable[[T], bool]) -> list[T]: + "Remove items from original that satisfy the predicate and return the removed items." + removed = [] + for idx, item in enumerate(original): + if predicate(item): + removed.append(item) + else: + original[idx - len(removed)] = item + del original[len(original) - len(removed) :] + return removed + + +def chunked(iterable: Iterable[T], size: int) -> Iterator[list[T]]: + iterator = iter(iterable) + while True: + chunk = list(itertools.islice(iterator, size)) + if not chunk: + break + yield chunk + + +def partition(original: Iterable[T], classifier: Callable[[T], U]) -> defaultdict[U, list[T]]: + ret = defaultdict(list) + for item in original: + ret[classifier(item)].append(item) + return ret + + +def get_uniform_attribute(iterable: Iterable[T], attribute_func: Callable[[T], U]) -> U: + ret = attribute_func(next(iter(iterable))) + assert NDEBUG or all(attribute_func(item) == ret for item in iterable) + return ret + + +def assert_critical(condition: bool, message: str | None = None) -> None: + "Similar to assert, but instead of raising an exception, it terminates the process, even if inside __del__()." 
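+ # CPython ignores exceptions raised inside __del__ (it only prints them to stderr), so raising here could not enforce the invariant; os._exit(1) terminates unconditionally.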
+ if not condition: + warnings.warn(value_or(message, "Critical assertion failed")) + traceback.print_stack() + os._exit(1) + + +def noexcept(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) + def wrapper(*args, **kwargs) -> T: + try: + return func(*args, **kwargs) + except Exception as e: + raise AssertionError(f"Function {func.__name__} raised an exception: {e}") from e + + return wrapper + + +def not_implemented(func: Callable[..., T]) -> Callable[..., T]: + @functools.wraps(func) + def wrapper(*args, **kwargs) -> T: + raise NotImplementedError(f"The function '{func.__name__}' is not implemented yet.") + + return wrapper + + +def expect_type(ExpectedType: Type[T], value: Any) -> T: + "Similar to typing.cast, but does runtime checking with assert." + assert isinstance(value, ExpectedType), f"Expected {ExpectedType}, got {type(value)}" + return value + + +def is_sorted( + iterable: Iterable[T], key: Callable[[T], Any] = lambda x: x, reverse: bool = False +) -> bool: + comp = operator.ge if reverse else operator.le + return all(comp(key(a), key(b)) for a, b in pairwise(iterable)) + + +HomoTuple = tuple[T, ...] + + +class TypedIndexList(Protocol[Index, T]): + """ + A protocol representing a list-like container with a strongly typed integer index. + Useful for enforcing index types like NewType wrappers around int. + """ + + def __getitem__(self, index: Index) -> T: ... + + def __setitem__(self, index: Index, value: T) -> None: ... + + def __delitem__(self, index: Index | slice) -> None: ... + + def __iter__(self) -> Iterator[T]: ... + + def __len__(self) -> int: ... + + def __reversed__(self) -> Iterator[T]: ... + + def clear(self) -> None: ... + + def pop(self) -> T: ... + + def append(self, value: T) -> None: ... + + +# @TODO: use this where applicable. +def to_typed(index_type: Type[Index], lst: list[T]) -> TypedIndexList[Index, T]: + """ + Casts a standard list to a TypedIndexList with a strongly typed integer index. + + Parameters: + index_type: A type alias for the NewType index, e.g. type(BlockOrdinal(0)) or a concrete class derived from int. + lst: The list to cast + + Returns: + A TypedIndexList[Index, T] with the specified index type + """ + return cast(TypedIndexList[Index, T], lst) + + +def filled_list(value: T, count: Index) -> TypedIndexList[Index, T]: + "Note that all elements will be the same value. Do not use mutable values." 
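+ # e.g. filled_list(0, count) yields [0, 0, ..., 0]; with a mutable value such as [] every element would alias the same object.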
+ return cast(TypedIndexList[Index, T], [value] * int(count)) + + +def make_typed(generator: Callable[[], T], count: Index) -> TypedIndexList[Index, T]: + return cast(TypedIndexList[Index, T], [generator() for _ in range(int(count))]) + + +def typed_len(iterable: TypedIndexList[IndexO, T]) -> IndexO: + return cast(IndexO, len(iterable)) + + +def typed_enumerate(iterable: TypedIndexList[Index, T]) -> Iterator[tuple[Index, T]]: + return cast(Iterator[tuple[Index, T]], enumerate(iterable)) + + +def typed_map( + iterable: TypedIndexList[Index, T], func: Callable[[T], U] +) -> TypedIndexList[Index, U]: + return cast(TypedIndexList[Index, U], [func(item) for item in iterable]) + + +class Array2D(Generic[Row, Col, T]): + __slots__ = ("_data", "_cols") + _data: list[T] + _cols: int + + def __init__(self, rows: Row, cols: Col, init_val: Iterable[T]) -> None: + self._data = list(init_val) + self._cols = cols + + def __getitem__(self, index: tuple[Row, Col]) -> T: + return self._data[index[0] * self._cols + index[1]] + + def __setitem__(self, index: tuple[Row, Col], value: T) -> None: + self._data[index[0] * self._cols + index[1]] = value + + @property + def rows(self) -> int: + assert len(self._data) % self._cols == 0 + return len(self._data) // self._cols + + def row(self, row: Row) -> TypedIndexList[Col, T]: + return cast(TypedIndexList[Col, T], self._data[row * self._cols : (row + 1) * self._cols]) + + def col(self, col: Col) -> TypedIndexList[Row, T]: + return cast(TypedIndexList[Row, T], self._data[col :: self._cols]) + + @property + def cols(self) -> int: + return self._cols + + def __len__(self) -> int: + return len(self._data) + + def __iter__(self) -> Iterator[T]: + return iter(self._data) + + def __reversed__(self) -> Iterator[T]: + return reversed(self._data) + + +def filled_array2d(rows: Row, cols: Col, val: T) -> Array2D[Row, Col, T]: + return Array2D(rows, cols, [val] * rows * cols) + + +def typed_range(*args: Index) -> Reversible[Index]: + return cast(Reversible[Index], range(*args)) + + +def find(seq: Sequence[T], predicate: Callable[[T], bool], default: U) -> T | U: + return next((item for item in seq if predicate(item)), default) + + +def find_index(seq: Iterable[T], predicate: Callable[[T], bool]) -> int: + i = 0 + for i, item in enumerate(seq): + if predicate(item): + return i + return i + 1 + + +mem_alignment: Final[int] = 2 << 20 # 2MB + +_libc = ctypes.CDLL(find_library("c")) +_libc.aligned_alloc.restype = ctypes.c_void_p +_libc.aligned_alloc.argtypes = [ctypes.c_size_t, ctypes.c_size_t] +_libc.madvise.restype = ctypes.c_int +_libc.madvise.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int] +_libc.realloc.restype = ctypes.c_void_p +_libc.realloc.argtypes = [ctypes.c_void_p, ctypes.c_size_t] +_libc.free.restype = None +_libc.free.argtypes = [ctypes.c_void_p] +_libc.posix_fallocate.restype = ctypes.c_int +_libc.posix_fallocate.argtypes = [ctypes.c_int, ctypes.c_longlong, ctypes.c_longlong] + + +def _aligned_alloc(alignment: int, size: int) -> int: + """ + Allocates size bytes of uninitialized storage whose alignment is specified by alignment. + Returns the address as an integer. + Raises HostOOMError on failure. 
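+ As with C11 aligned_alloc(), size must be an integral multiple of alignment (asserted below).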
+ """ + assert size % alignment == 0 + memptr: ctypes.c_void_p = _libc.aligned_alloc(ctypes.c_size_t(alignment), ctypes.c_size_t(size)) + if memptr == ctypes.c_void_p(0): + raise HostOOMError("aligned_alloc failed") + return int(memptr) + + +def _madvise(ptr: int, size: int, advice: int) -> None: + if os.name == "nt": + return + ret = _libc.madvise(ctypes.c_void_p(ptr), ctypes.c_size_t(size), ctypes.c_int(advice)) + if ret != 0: + error_code = ctypes.get_errno() + error_msg = f"madvise failed with errno {error_code}: {errno.errorcode.get(error_code, 'Unknown error')}" + raise HostOOMError(error_msg) + + +MADV_HUGEPAGE: Final[int] = 14 + + +def _realloc(ptr: int, size: int) -> int: + """ + Reallocates size bytes of storage whose alignment is specified by alignment. + Returns the address as an integer. + Raises OSError on failure. + """ + ret = _libc.realloc(ctypes.c_void_p(ptr), ctypes.c_size_t(size)) + if ret == ctypes.c_void_p(0): + raise HostOOMError("realloc failed.") + return int(ret) + + +def _free(ptr: int) -> None: + _libc.free(ctypes.c_void_p(ptr)) + + +def _posix_fallocate(fd: int, offset: int, length: int) -> None: + ret = _libc.posix_fallocate( + ctypes.c_int(fd), ctypes.c_longlong(offset), ctypes.c_longlong(length) + ) + if ret != 0: + raise DiskOOMError(ret, "posix_fallocate failed") + + +class HostMem: + ALIGNMENT: ClassVar[int] = 2 << 20 + """ + Host memory aligned to 2MB, reallocable for low-cost resizing and registered to CUDA as page-locked memory. + Resizing will keep the original memory content, like `realloc` in C. + """ + __slots__ = ("_address", "_size", "_num_registered_chunks") + _address: int + _size: int + # If True and _size > 2GB, use multiple chunks to register pinned memory due to a Linux kernel + # 6.11/6.12/6.13 bug preventing pinning more than 2GB of host memory in one operation. 
+ _CHUNKED_REGISTRATION: ClassVar[bool] = platform.system() == "Linux" and platform.release()[ + :4 + ] in ["6.11", "6.12", "6.13"] + _CHUNK_SIZE: ClassVar[int] = 2 << 30 + _num_registered_chunks: int + + @property + def address(self) -> int: + return self._address + + @property + def size(self) -> int: + return self._size + + def __init__(self, size: int) -> None: + self._num_registered_chunks = 0 + if size == 0: + self._address = 0 + self._size = 0 + return + self._address = _aligned_alloc(mem_alignment, size) + self._size = size + _madvise(self._address, size, MADV_HUGEPAGE) + self._register_to_cuda() + + def resize(self, new_size: int) -> None: + self._unregister_from_cuda() + try: + self._address = _realloc(self._address, new_size) + self._size = new_size + _madvise(self._address, new_size, MADV_HUGEPAGE) + finally: + self._register_to_cuda() + + def destroy(self) -> None: + if self._address == 0: + return + self._unregister_from_cuda() + _free(self._address) + self._address = 0 + self._size = 0 + + def __del__(self) -> None: + self.destroy() + + def _register_to_cuda(self) -> None: + assert self._num_registered_chunks == 0 + for addr, size in self._iterate_chunks(): + _unwrap( + drv.cuMemHostRegister( + addr, size, drv.CU_MEMHOSTREGISTER_PORTABLE | drv.CU_MEMHOSTREGISTER_DEVICEMAP + ) + ) + self._num_registered_chunks += 1 + + def _unregister_from_cuda(self) -> None: + for addr, _ in self._iterate_chunks(): + if self._num_registered_chunks == 0: + break + _unwrap(drv.cuMemHostUnregister(addr)) + self._num_registered_chunks -= 1 + assert self._num_registered_chunks == 0 + + def _iterate_chunks(self) -> Iterator[tuple[int, int]]: + start = self._address + end = start + self._size + chunk_size = self._CHUNK_SIZE if self._CHUNKED_REGISTRATION else self._size + for addr in range(start, end, chunk_size): + yield addr, min(end - addr, chunk_size) + + +def resize_file(fd: int, new_size: int) -> None: + old_size = os.lseek(fd, 0, os.SEEK_END) + if new_size > old_size: + _posix_fallocate(fd, old_size, new_size - old_size) + elif new_size < old_size: + os.truncate(fd, new_size) + + +class DynamicBitset: + """ + A memory efficient bitset that can be resized. 
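+ Bits are packed into 64-bit words (array typecode "Q"); set/get/clear are O(1) and the number of set bits is tracked incrementally.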
+ """ + + __slots__ = ("_bits", "_num_set_bits") + _bits: array.array + _num_set_bits: int + + TYPE_CODE: ClassVar[str] = "Q" + ALL_SET_MASK: ClassVar[int] = (1 << 64) - 1 + + def __init__(self, capacity: int) -> None: + self._bits = array.array(self.TYPE_CODE, [0] * (div_up(capacity, 64))) + self._num_set_bits = 0 + + def set(self, index: int) -> None: + if not self.get(index): + self._bits[index // 64] |= 1 << (index % 64) + self._num_set_bits += 1 + + def get(self, index: int) -> bool: + return self._bits[index // 64] & (1 << (index % 64)) != 0 + + def clear(self, index: int) -> None: + if self.get(index): + self._bits[index // 64] &= ~(1 << (index % 64)) + self._num_set_bits -= 1 + + @property + def num_set_bits(self) -> int: + return self._num_set_bits + + def resize(self, new_capacity: int) -> None: + extra_elems = div_up(new_capacity, 64) - len(self._bits) + if extra_elems > 0: + self._bits.extend(array.array(self.TYPE_CODE, [0] * extra_elems)) + elif extra_elems < 0: + self._bits = self._bits[:extra_elems] + if new_capacity % 64 != 0: + self._bits[-1] &= self.ALL_SET_MASK >> (64 - (new_capacity % 64)) + + # check if any bit in the range [start, end) is set + def any_set(self, start: int, end: int) -> bool: + if start >= end: + return False + start_word_mask = self.ALL_SET_MASK << (start % 64) + end_word_mask = self.ALL_SET_MASK >> (64 - (end % 64)) + if start // 64 == end // 64: + if (start_word_mask & end_word_mask & self._bits[start // 64]) != 0: + return True + else: + if (start_word_mask & self._bits[start // 64]) != 0 or ( + end % 64 != 0 and end_word_mask & self._bits[end // 64] + ) != 0: + return True + return any(self._bits[i] != 0 for i in range(start // 64 + 1, end // 64)) + + +@functools.cache +def init_cuda_once() -> None: + (err,) = cudart.cudaFree(0) + assert int(err) == int(cudart.cudaError_t.cudaSuccess) + + +class SimplePool(Generic[T]): + __slots__ = ( + "_create_func", + "_destroy_func", + "_init_size", + "_max_size", + "_outstanding_count", + "_items", + ) + _create_func: Callable[[], T] + _destroy_func: Callable[[T], None] + _init_size: int + _max_size: int | None + _items: deque[T] | None + _outstanding_count: ( + int # number of items currently we gave out but not returned, i.e. 
get() but not put() + ) + + def __init__( + self, + create_func: Callable[[], T], + destroy_func: Callable[[T], None], + init_size: int = 0, + max_size: int | None = None, + ): + self._create_func = create_func + self._destroy_func = destroy_func + self._init_size = init_size + self._max_size = max_size + self._items = None + self._outstanding_count = 0 + + def clear(self) -> None: + while self.items: + self._destroy_func(self.items.popleft()) + + def __del__(self) -> None: + self.clear() + + @property + def items(self) -> deque[T]: + if self._items is None: + self._items = deque[T]( + (self._create_func() for _ in range(self._init_size)), maxlen=self._max_size + ) + return self._items + + def get(self) -> T: + ret = self.items.popleft() if self.items else self._create_func() + self._outstanding_count += 1 + return ret + + def put(self, item: T) -> None: + self._outstanding_count -= 1 + if self._max_size is not None and len(self.items) >= self._max_size: + self._destroy_func(item) + else: + self.items.appendleft(item) + + @property + def outstanding_count(self) -> int: + "number of items currently we get() but not put()" + return self._outstanding_count + + @property + def cached_count(self) -> int: + "number of items currently in the pool" + return len(self.items) + + @property + def total_count(self) -> int: + "total number of items created, including both outstanding and cached" + return self.outstanding_count + self.cached_count + + +class ItemHolderBase(Generic[T], ABC): + __slots__ = ("_item",) + _item: T | None + + def __init__(self) -> None: + self._item = self.pool.get() + + def close(self) -> None: + # Manually inlined for better performance. + item = self._item + if item is not None: + self.pool.put(item) + self._item = None + + def __del__(self) -> None: + self.close() + + def is_closed(self) -> bool: + return self._item is None + + def get(self) -> T: + # Manually inlined for better performance. + item = self._item + assert item is not None + return item + + @property + def handle(self) -> T: + # Manually inlined for better performance. + item = self._item + assert item is not None + return item + + @property + @abstractmethod + def pool(self) -> SimplePool[T]: ... + + +class CachedCudaEvent(ItemHolderBase[drv.CUevent]): + """ + A cached CUDA event without support for timing. Recorded to a stream when created. + """ + + __slots__ = () + _pool: ClassVar[SimplePool[drv.CUevent] | None] = None + NULL: ClassVar["_NullCudaEvent"] + + def __init__(self, stream: CudaStream) -> None: + super().__init__() + self._record(stream) + + def query_complete(self) -> bool: + """ + Query the event. If complete, also close the event. Closed events are always considered complete. + """ + # Manually inlined for better performance. + ev = self._item + if ev is None: + return True + (err,) = drv.cuEventQuery(ev) + if int(err) == int(drv.CUresult.CUDA_SUCCESS): + self.close() + return True + elif int(err) == int(drv.CUresult.CUDA_ERROR_NOT_READY): + return False + else: + raise CuError(err) + + def synchronize(self) -> None: + # Manually inlined for better performance. + ev = self._item + if ev is None: + return + _unwrap(drv.cuEventSynchronize(ev)) + self.close() + + def wait_in_stream(self, stream: CudaStream) -> None: + # Manually inlined for better performance. + ev = self._item + if ev is None: + return + _unwrap(drv.cuStreamWaitEvent(stream, ev, 0)) + + def _record(self, stream: CudaStream) -> None: + """ + Prefer new event instead of recording an existing event. 
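+ Re-recording an event that other code may still hold or wait on would silently move the point in the stream that it signals.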
+ """ + # Manually inlined for better performance. + ev = self._item + assert ev is not None + _unwrap(drv.cuEventRecord(ev, stream)) + + @property + def pool(self) -> SimplePool[drv.CUevent]: + if CachedCudaEvent._pool is None: + CachedCudaEvent._pool = SimplePool[drv.CUevent]( + lambda: _unwrap(drv.cuEventCreate(drv.CUevent_flags.CU_EVENT_DISABLE_TIMING)), + lambda ev: _unwrap(drv.cuEventDestroy(ev)), # pyright: ignore + init_size=1024, + ) + return CachedCudaEvent._pool + + +class _NullCudaEvent(CachedCudaEvent): + """ + A null CUDA event that is closed (and always complete). + """ + + __slots__ = () + + def __init__(self) -> None: + # do not call super().__init__(). We don't need an event here. + self._item = None + + +CachedCudaEvent.NULL = _NullCudaEvent() + + +# @TODO: consider do this in a single batch call to C++. +def stream_wait_events(stream: CudaStream, events: Iterable[CachedCudaEvent]) -> None: + "Batched wait for multiple events with deduplication first." + if not isinstance(events, Set): + events = set(events) + for ev in events: + ev.wait_in_stream(stream) + + +class CachedCudaStream(ItemHolderBase[CudaStream]): + """ + A cached non-blocking CUDA stream. + """ + + __slots__ = () + _pool: ClassVar[SimplePool[CudaStream] | None] = None + + def __init__(self) -> None: + super().__init__() + + def wait_event(self, event: drv.CUevent) -> None: + _unwrap(drv.cuStreamWaitEvent(self.get(), event, drv.CU_STREAM_WAIT_VALUE_COMPLETED)) + + def wait_events(self, events: Sequence[CachedCudaEvent] | set[CachedCudaEvent]) -> None: + """ + Wait for events with deduplication first. + """ + stream_wait_events(self.get(), events) + + def record_event(self) -> CachedCudaEvent: + return CachedCudaEvent(self.get()) + + def __cuda_stream__(self) -> tuple[int, int]: + return 0, int(self.get()) + + @property + def pool(self) -> SimplePool[CudaStream]: + if CachedCudaStream._pool is None: + CachedCudaStream._pool = SimplePool[CudaStream]( + lambda: CudaStream( + int(_unwrap(drv.cuStreamCreate(drv.CUstream_flags.CU_STREAM_NON_BLOCKING))) # pyright: ignore + ), + lambda stream: _unwrap(drv.cuStreamDestroy(stream)), # pyright: ignore + init_size=128, + ) + return CachedCudaStream._pool + + +class TemporaryCudaStream(CachedCudaStream): + """ + A cached non-blocking CUDA stream. Mainly used as temporary worker streams. + Requires a list of prior events to wait for dependencies. A finish event is recorded when exiting + normally. Call take_finish_event() to get the finish event. 
+ """ + + __slots__ = "_finish_event" + _finish_event: CachedCudaEvent | None + + def __init__(self, prior_events: Sequence[CachedCudaEvent] | set[CachedCudaEvent]): + super().__init__() + self.wait_events(prior_events) + self._finish_event = None + + def __del__(self) -> None: + if self._finish_event is not None: + warnings.warn("[KVCacheManager] finish event recorded but not taken") + super().__del__() + + def take_finish_event(self) -> CachedCudaEvent: + ret = unwrap_optional(self._finish_event) + self._finish_event = None + return ret + + def __enter__(self) -> "TemporaryCudaStream": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + if not exc_type: + self._finish_event = self.record_event() + + +def merge_events(events: Sequence[CachedCudaEvent] | set[CachedCudaEvent]) -> CachedCudaEvent: + if len(events) == 0: + return CachedCudaEvent.NULL + if len(events) == 1: + ev = next(iter(events)) + return ev if not ev.is_closed() else CachedCudaEvent.NULL + with TemporaryCudaStream(events) as stream: + pass + return stream.take_finish_event() + + +class MultiStreamExecutor: + __slots__ = ("_prior_event", "_streams", "_finish_event") + _prior_event: CachedCudaEvent + _streams: list[TemporaryCudaStream] + _finish_event: CachedCudaEvent | None + + def __init__(self, prior_events: Sequence[CachedCudaEvent] | set[CachedCudaEvent]): + self._prior_event = merge_events(prior_events) + self._streams = [] + self._finish_event = None + + def __enter__(self) -> "MultiStreamExecutor": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + events = [s.take_finish_event() for s in self._streams] + self._streams.clear() + self._finish_event = merge_events(events) + + def __del__(self) -> None: + assert_critical(self._finish_event is None, "finish event not taken") + + def new_stream(self) -> TemporaryCudaStream: + stream = TemporaryCudaStream((self._prior_event,)) + self._streams.append(stream) + return stream + + def take_finish_event(self) -> CachedCudaEvent: + ret = unwrap_optional(self._finish_event) + self._finish_event = None + return ret + + +class SharedPoolProvider(Generic[T]): + _pool: SimplePool[T] + + def __init__(self, pool: SimplePool[T]): + self._pool = pool + + def pool(self) -> SimplePool[T]: + return self._pool + + +class ItemHolderWithSharedPool(ItemHolderBase[T]): + __slots__ = ("_pool",) + _pool: SimplePool[T] + + def __init__(self, pool: SimplePool[T]) -> None: + self._pool = pool + super().__init__() + + def __del__(self) -> None: + self.close() + + @property + def pool(self) -> SimplePool[T]: + return self._pool + + +HolderT = TypeVar("HolderT", bound=ItemHolderWithSharedPool) + + +# For subclassing if holder needs to be customized +class PooledFactoryBase(Generic[T, HolderT]): + _Holder: Type[HolderT] # subclasses must initialize this static attribute + __slots__ = ("_pool",) + _pool: SimplePool[T] + + def __init__( + self, + create_func: Callable[[], T], + destroy_func: Callable[[T], None], + init_size: int = 0, + max_cache_size: int | None = None, + ): + self._pool = SimplePool[T](create_func, destroy_func, init_size, max_cache_size) + + def create(self) -> HolderT: + return self._Holder(self._pool) + + def clear(self) -> None: + self._pool.clear() + + +def query_total_gpu_memory() -> int: + _, total = _unwrap(drv.cuMemGetInfo()) # pyright: ignore + return total + + +def query_free_gpu_memory() -> int: + free, _ = _unwrap(drv.cuMemGetInfo()) # pyright: ignore + return free + + +class CudaStreamWrapper: + "Just a wrapper to make 
it compatible with IsStreamT protocol. Does not own the stream." + + __slots__ = ("_stream",) + _stream: CudaStream + + def __init__(self, stream: CudaStream) -> None: + self._stream = stream + + def __cuda_stream__(self) -> tuple[int, int]: + return 0, int(self._stream) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/mypy_mypyc.ini b/tensorrt_llm/runtime/kv_cache_manager_v2/mypy_mypyc.ini new file mode 100644 index 0000000000..1b798b09ee --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/mypy_mypyc.ini @@ -0,0 +1,51 @@ +; SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +; SPDX-License-Identifier: Apache-2.0 +; +; Licensed under the Apache License, Version 2.0 (the "License"); +; you may not use this file except in compliance with the License. +; You may obtain a copy of the License at +; +; http://www.apache.org/licenses/LICENSE-2.0 +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, +; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +; See the License for the specific language governing permissions and +; limitations under the License. + +[mypy] +# Only check files explicitly listed - don't follow any imports +follow_imports = skip +follow_imports_for_stubs = False + +# Ignore all missing imports +ignore_missing_imports = True + +# Allow untyped code in dependencies +allow_untyped_calls = True +allow_untyped_defs = True +check_untyped_defs = False + +# Disable various warnings to reduce noise +warn_return_any = False +warn_unused_ignores = False +warn_unreachable = False +no_implicit_optional = False + +# Don't check .pyi files outside our target +exclude = (?x)( + ^(?!tensorrt_llm/runtime/kv_cache_manager_v2/) +) + +# Ignore errors in any imported modules +[mypy-tensorrt_llm.executor.*] +ignore_errors = True +follow_imports = skip + +[mypy-tensorrt_llm.bindings.*] +ignore_errors = True +follow_imports = skip + +[mypy-torch.*] +ignore_errors = True +follow_imports = skip diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/README.md b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/README.md new file mode 100644 index 0000000000..6128f1a7b6 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/README.md @@ -0,0 +1,140 @@ + + +# rawref - Mutable Reference C Extension + +A C extension that provides a mutable reference class similar to `weakref.ref` for holding weak-like references to Python objects. + +## Features + +- **`ref[T]`**: A generic reference class (like `weakref.ref`) that stores an object's ID +- **Singleton pattern**: `ref(obj)` returns the same reference if `obj.__rawref__` is valid +- **Dereferencing**: Call `r()` to get the object, or `None` if invalid +- **Invalidation**: Call `r.invalidate()` to mark the reference as invalid +- **NULL constant**: Use `NULL` to initialize `__rawref__` attributes +- **Type-safe**: Comes with `.pyi` stub file for proper type checking +- **API compatible with weakref**: Use `ref` for both object creation and type hints + +## Building + +From the `rawref` directory: + +```bash +python setup.py build_ext --inplace +``` + +Or install it: + +```bash +pip install -e . 
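+# (editable install; building the extension requires a C compiler and the Python development headers)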
+``` + +## Usage + +```python +from rawref import ref, NULL + +class MyClass: + # Class attribute: default value for __rawref__ + # Each instance will get its own __rawref__ instance attribute when ref() is called + __rawref__ = NULL + + def __init__(self, value): + self.value = value + + def __del__(self): + # self.__rawref__ is an instance attribute (set by ref()) + # Invalidate the canonical reference when object is destroyed + if self.__rawref__.is_valid: + self.__rawref__.invalidate() + +# Create an object and a reference to it (just like weakref.ref) +obj = MyClass(42) +r1 = ref(obj) + +# The reference is automatically stored as an instance attribute obj.__rawref__ +print(obj.__rawref__ is r1) # True + +# Singleton pattern: creating another ref returns the same one +r2 = ref(obj) +print(r1 is r2) # True + +# Dereference to get the object back +print(r1()) # +print(r1().value) # 42 + +# Check validity +print(r1.is_valid) # True + +# After invalidation +r1.invalidate() +print(r1()) # None +print(r1.is_valid) # False + +# Creating a new ref after invalidation creates a new reference +r3 = ref(obj) +print(r1 is r3) # False +print(r3.is_valid) # True +``` + +## Type Hints + +Like `weakref.ref`, you can use `ref` for both object creation and type hints: + +```python +from rawref import ref, NULL + +class MyClass: + __rawref__ = NULL + +# Create and type a reference +r: ref[MyClass] = ref(MyClass()) + +# Alternative: use ReferenceType directly +from rawref import ReferenceType +r: ReferenceType[MyClass] = ReferenceType(MyClass()) +``` + +## Warning + +This implementation uses raw object IDs (memory addresses) and attempts to dereference them. This is inherently unsafe and should be used with caution. The reference does not keep the object alive (unlike a strong reference), so care must be taken to ensure the object is not garbage collected while references exist. + +## API + +### Classes and Constants +- `ReferenceType`: The main reference class +- `ref`: Alias for `ReferenceType` (like `weakref.ref`) +- `NULL`: An invalid reference constant for initialization + +### Creation +- `ref(obj)`: Create a reference to `obj`, or return existing valid reference from `obj.__rawref__` + +### Properties +- `r.is_valid`: Check if the reference is still valid (read-only) + +### Methods +- `r()`: Dereference to get the object, or `None` if invalid +- `r.invalidate()`: Mark the reference as invalid + +## Singleton Pattern + +The `ref()` function implements a singleton pattern: +1. When `ref(obj)` is called, it checks if `obj.__rawref__` (instance attribute) exists and is valid +2. If yes, it returns the existing reference +3. If no, it creates a new reference and sets `obj.__rawref__` as an instance attribute + +**Note**: The class attribute `__rawref__ = NULL` is just a default value. When `ref(obj)` is called, it creates an **instance attribute** `obj.__rawref__` that shadows the class attribute. Each instance gets its own `__rawref__` instance attribute, ensuring each object has at most one canonical reference at a time. diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/__init__.py b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/__init__.py new file mode 100644 index 0000000000..672e53603a --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/__init__.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""rawref - Mutable reference with singleton pattern. + +This module provides a C extension for creating mutable references to Python +objects, similar to weakref.ref but with manual invalidation control and a +singleton pattern via __rawref__. + +The main purpose is to work around the issue that mypyc does not support +weakref. + +Main exports: +- ReferenceType: The reference class +- ref: Alias for ReferenceType (recommended, like weakref.ref) +- NULL: Invalid reference constant for initialization +""" + +from ._rawref import NULL, ReferenceType, ref + +__all__ = ["ReferenceType", "ref", "NULL"] + +__version__ = "2.0.0" diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/__init__.pyi b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/__init__.pyi new file mode 100644 index 0000000000..14d85a94c1 --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/__init__.pyi @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Generic, Optional, TypeVar + +T = TypeVar("T") + +class ReferenceType(Generic[T]): + """A mutable reference holder that stores an object ID. + + This class holds a reference to an object via its ID and allows + dereferencing it. The reference can be invalidated. + + Like weakref.ref, but stores raw object IDs instead of proper weak references. + + Implements a singleton pattern: calling ref(obj) multiple times returns the + same reference if obj.__rawref__ exists and is valid. + """ + + @property + def is_valid(self) -> bool: + """Check if the reference is still valid (read-only).""" + ... + + def __init__(self, obj: T) -> None: + """Initialize a ReferenceType with an object. + + If obj.__rawref__ exists and is valid, returns that instead. + Otherwise creates a new reference and sets obj.__rawref__ to it. + + Args: + obj: The object to reference. + """ + ... + + def __call__(self) -> Optional[T]: + """Dereference the object. + + Returns: + The referenced object, or None if the reference is invalid. + """ + ... + + def invalidate(self) -> None: + """Invalidate the reference. + + After calling this method, __call__() will return None. + This should be called from T.__del__ to invalidate the reference. + """ + ... 
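+# Intended pattern (see README.md): the referenced class sets __rawref__ = NULL as a class-level
+# default and calls self.__rawref__.invalidate() from its __del__, so dangling references
+# dereference to None.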
+ +# Alias 'ref' to 'ReferenceType' (like weakref.ref is an alias to weakref.ReferenceType) +ref = ReferenceType + +# NULL is an invalid reference constant that can be used to initialize __rawref__ +NULL: ReferenceType + +# For type hints, you can use either: +# r: ref[MyClass] = ref(obj) +# or: +# r: ReferenceType[MyClass] = ReferenceType(obj) +# Both are equivalent. diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/rawrefmodule.c b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/rawrefmodule.c new file mode 100644 index 0000000000..a6592a7a4c --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/rawrefmodule.c @@ -0,0 +1,238 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define PY_SSIZE_T_CLEAN +#include +#include + +/* ReferenceType object structure */ +typedef struct +{ + PyObject_HEAD Py_ssize_t object_id; /* ID of the referenced object */ + int valid; /* 1 if valid, 0 if invalidated */ +} ReferenceTypeObject; + +/* Forward declarations */ +static PyTypeObject ReferenceTypeType; + +/* Cached attribute name for faster lookups */ +static PyObject* rawref_attr_name = NULL; + +/* ReferenceType.__new__ - implements singleton pattern via __rawref__ */ +static PyObject* ReferenceType_new(PyTypeObject* type, PyObject* args, PyObject* kwds) +{ + PyObject* obj = NULL; + static char* kwlist[] = {"obj", NULL}; + + /* Parse arguments to get the object */ + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &obj)) + { + return NULL; + } + + /* Try to get existing __rawref__ using cached attribute name (faster) */ + PyObject* existing_ref = PyObject_GetAttr(obj, rawref_attr_name); + if (existing_ref != NULL) + { + /* Check if it's a ReferenceType instance and is valid */ + if (PyObject_TypeCheck(existing_ref, &ReferenceTypeType)) + { + ReferenceTypeObject* ref_obj = (ReferenceTypeObject*) existing_ref; + if (ref_obj->valid) + { + /* Return existing valid reference */ + return existing_ref; + } + } + Py_DECREF(existing_ref); + } + else + { + /* Clear the AttributeError if __rawref__ doesn't exist */ + PyErr_Clear(); + } + + /* Create new reference */ + ReferenceTypeObject* self; + self = (ReferenceTypeObject*) type->tp_alloc(type, 0); + if (self != NULL) + { + self->object_id = (Py_ssize_t) obj; + self->valid = 1; + + /* Set obj.__rawref__ to this new reference using cached attr name */ + if (PyObject_SetAttr(obj, rawref_attr_name, (PyObject*) self) < 0) + { + /* If we can't set __rawref__, just clear the error and continue */ + PyErr_Clear(); + } + } + return (PyObject*) self; +} + +/* ReferenceType.__init__ */ +static int ReferenceType_init(ReferenceTypeObject* self, PyObject* args, PyObject* kwds) +{ + /* __new__ already did all the work, including setting object_id and valid */ + /* Skip argument parsing since __new__ already validated them */ + /* This saves ~5-10% overhead on object creation */ + 
return 0; +} + +/* ReferenceType.__call__() - dereference the object */ +static PyObject* ReferenceType_call(ReferenceTypeObject* self, PyObject* args, PyObject* kwds) +{ + PyObject* obj; + + if (!self->valid) + { + Py_RETURN_NONE; + } + + /* Use _PyObject_FromStackRefSteal or ctypes approach */ + /* We need to find the object by its id */ + /* This is the tricky part - we need to convert id back to object */ + + /* Use ctypes.cast to convert id to PyObject* */ + obj = (PyObject*) self->object_id; + + /* Check if the object is still alive by verifying ref count > 0 */ + /* This is somewhat unsafe but matches the intended behavior */ + if (Py_REFCNT(obj) > 0) + { + Py_INCREF(obj); + return obj; + } + + /* Object no longer valid */ + self->valid = 0; + Py_RETURN_NONE; +} + +/* ReferenceType.invalidate() */ +static PyObject* ReferenceType_invalidate(ReferenceTypeObject* self, PyObject* Py_UNUSED(ignored)) +{ + self->valid = 0; + Py_RETURN_NONE; +} + +/* ReferenceType.is_valid property getter */ +static PyObject* ReferenceType_is_valid(ReferenceTypeObject* self, void* closure) +{ + return PyBool_FromLong(self->valid); +} + +/* ReferenceType.__class_getitem__() - support for generic type subscripting */ +static PyObject* ReferenceType_class_getitem(PyObject* cls, PyObject* item) +{ + /* Just return the class itself, ignore the type parameter */ + /* This allows rawref.ref[T] to work at runtime like weakref.ref[T] */ + Py_INCREF(cls); + return cls; +} + +/* Method definitions */ +static PyMethodDef ReferenceType_methods[] = { + {"invalidate", (PyCFunction) ReferenceType_invalidate, METH_NOARGS, + "Invalidate the reference, making it return None on dereference."}, + {"__class_getitem__", (PyCFunction) ReferenceType_class_getitem, METH_O | METH_CLASS, + "Support for generic type subscripting (e.g., ref[T])."}, + {NULL} /* Sentinel */ +}; + +/* Property definitions */ +static PyGetSetDef ReferenceType_getsetters[] = { + {"is_valid", (getter) ReferenceType_is_valid, NULL, "Check if the reference is still valid (read-only).", NULL}, + {NULL} /* Sentinel */ +}; + +/* Type definition */ +static PyTypeObject ReferenceTypeType = { + PyVarObject_HEAD_INIT(NULL, 0).tp_name = "_rawref.ReferenceType", + .tp_doc = "A mutable reference holder that stores an object ID.", + .tp_basicsize = sizeof(ReferenceTypeObject), + .tp_itemsize = 0, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_new = ReferenceType_new, + .tp_init = (initproc) ReferenceType_init, + .tp_call = (ternaryfunc) ReferenceType_call, + .tp_methods = ReferenceType_methods, + .tp_getset = ReferenceType_getsetters, +}; + +/* Module definition */ +static PyModuleDef rawrefmodule = { + PyModuleDef_HEAD_INIT, + .m_name = "_rawref", + .m_doc = "C extension providing mutable reference class ReferenceType (internal module).", + .m_size = -1, +}; + +/* Module initialization */ +PyMODINIT_FUNC PyInit__rawref(void) +{ + PyObject* m; + ReferenceTypeObject* null_ref; + + if (PyType_Ready(&ReferenceTypeType) < 0) + return NULL; + + m = PyModule_Create(&rawrefmodule); + if (m == NULL) + return NULL; + + /* Cache the __rawref__ attribute name for faster lookups */ + rawref_attr_name = PyUnicode_InternFromString("__rawref__"); + if (rawref_attr_name == NULL) + { + Py_DECREF(m); + return NULL; + } + + Py_INCREF(&ReferenceTypeType); + if (PyModule_AddObject(m, "ReferenceType", (PyObject*) &ReferenceTypeType) < 0) + { + Py_DECREF(&ReferenceTypeType); + Py_DECREF(m); + return NULL; + } + + /* Add 'ref' as an alias for 'ReferenceType' (like weakref.ref) */ 
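+ /* PyModule_AddObject steals a reference to the value on success, so take an extra reference first; on failure it is released explicitly below. */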
+ Py_INCREF(&ReferenceTypeType); + if (PyModule_AddObject(m, "ref", (PyObject*) &ReferenceTypeType) < 0) + { + Py_DECREF(&ReferenceTypeType); + Py_DECREF(m); + return NULL; + } + + /* Create NULL constant - an invalid reference */ + null_ref = (ReferenceTypeObject*) ReferenceTypeType.tp_alloc(&ReferenceTypeType, 0); + if (null_ref != NULL) + { + null_ref->object_id = 0; + null_ref->valid = 0; + if (PyModule_AddObject(m, "NULL", (PyObject*) null_ref) < 0) + { + Py_DECREF(null_ref); + Py_DECREF(m); + return NULL; + } + } + + return m; +} diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/setup.py b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/setup.py new file mode 100644 index 0000000000..b6a8c58c6d --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/setup.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from setuptools import setup, Extension + +rawref_module = Extension( + '_rawref', + sources=['rawrefmodule.c'], +) + +setup( + name='rawref', + version='1.0', + description='C extension providing mutable reference class Ref[T]', + ext_modules=[rawref_module], +) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/test_rawref.py b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/test_rawref.py new file mode 100644 index 0000000000..fab925869f --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/rawref/test_rawref.py @@ -0,0 +1,269 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test script for the rawref module.""" + +import os +import sys + +# Add parent directory to path to import the rawref package +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +try: + from rawref import NULL, ReferenceType, ref +except ImportError as e: + print(f"Error importing rawref: {e}") + print("Make sure to build the extension first with: python setup.py build_ext --inplace") + sys.exit(1) + + +class TestObject: + """Test class with __del__ that invalidates references.""" + + def __init__(self, value): + self.value = value + self.refs = [] # Store references to invalidate + + def __del__(self): + print(f"TestObject({self.value}).__del__ called, invalidating {len(self.refs)} references") + for ref in self.refs: + ref.invalidate() + + +def test_basic_reference(): + """Test basic reference creation and dereferencing.""" + print("\n=== Test 1: Basic Reference ===") + + obj = TestObject(42) + r = ref(obj) + obj.refs.append(r) + + print(f"Created object with value: {obj.value}") + print(f"Reference is_valid: {r.is_valid}") + + # Dereference + dereferenced = r() + print(f"Dereferenced object: {dereferenced}") + if dereferenced: + print(f"Dereferenced value: {dereferenced.value}") + + assert r.is_valid, "Reference should be valid" + assert dereferenced is obj, "Dereferenced object should be the same as original" + + +def test_invalidation(): + """Test manual invalidation.""" + print("\n=== Test 2: Manual Invalidation ===") + + obj = TestObject(123) + r = ref(obj) + + print(f"Before invalidation - is_valid: {r.is_valid}") + print(f"Dereferenced: {r()}") + + r.invalidate() + + print(f"After invalidation - is_valid: {r.is_valid}") + print(f"Dereferenced: {r()}") + + assert not r.is_valid, "Reference should be invalid after invalidate()" + assert r() is None, "Dereferencing invalid reference should return None" + + +def test_del_invalidation(): + """Test invalidation via __del__.""" + print("\n=== Test 3: Invalidation via __del__ ===") + + r = None + + # Create object in a scope + def create_and_ref(): + obj = TestObject(999) + nonlocal r + r = ref(obj) + obj.refs.append(r) + + print(f"Inside scope - is_valid: {r.is_valid}") + print(f"Inside scope - dereferenced: {r()}") + + create_and_ref() + + # Object should be deleted and reference invalidated + print(f"After scope - is_valid: {r.is_valid}") + print(f"After scope - dereferenced: {r()}") + + assert not r.is_valid, "Reference should be invalid after object deletion" + assert r() is None, "Dereferencing should return None after object deletion" + + +def test_multiple_references(): + """Test that singleton pattern returns same reference.""" + print("\n=== Test 4: Singleton Pattern - Same Reference ===") + + obj = TestObject(555) + r1 = ref(obj) + r2 = ref(obj) + obj.refs.extend([r1, r2]) + + print(f"r1 is r2: {r1 is r2}") + + assert r1 is r2, "With singleton pattern, should return the same reference" + + # Invalidate r1 (which is the same as r2) + r1.invalidate() + + print("After invalidating r1:") + print(f" r1 is_valid: {r1.is_valid}, dereferenced: {r1()}") + print(f" r2 is_valid: {r2.is_valid}, dereferenced: {r2()}") + + # Since they're the same object, both are invalid + assert not r1.is_valid, "r1 should be invalid" + assert not r2.is_valid, "r2 should also be invalid (same object)" + + +def test_alias_equivalence(): + """Test that ref and ReferenceType are the same.""" + print("\n=== Test 5: ref and ReferenceType are equivalent ===") + + obj = TestObject(777) + r1 = ref(obj) + r2 = ReferenceType(obj) 
+ + print(f"ref is ReferenceType: {ref is ReferenceType}") + print(f"type(r1): {type(r1)}") + print(f"type(r2): {type(r2)}") + print(f"r1 is r2: {r1 is r2}") + + assert ref is ReferenceType, "ref should be an alias for ReferenceType" + assert r1 is r2, "Should return the same reference object (singleton pattern)" + + +def test_singleton_pattern(): + """Test that ref(obj) returns the same reference if __rawref__ is valid.""" + print("\n=== Test 6: Singleton Pattern ===") + + obj = TestObject(888) + + # First call creates a new reference + r1 = ref(obj) + print(f"First ref: {r1}, valid: {r1.is_valid}") + + # Second call returns the same reference + r2 = ref(obj) + print(f"Second ref: {r2}, valid: {r2.is_valid}") + print(f"r1 is r2: {r1 is r2}") + + assert r1 is r2, "Should return the same reference object" + + # After invalidation, a new call creates a new reference + r1.invalidate() + print(f"After invalidation: r1.valid={r1.is_valid}, r2.valid={r2.is_valid}") + + r3 = ref(obj) + print(f"Third ref (after invalidation): {r3}, valid: {r3.is_valid}") + print(f"r1 is r3: {r1 is r3}") + + assert r1 is not r3, "Should create a new reference after invalidation" + assert r3.is_valid, "New reference should be valid" + + +def test_null_constant(): + """Test the NULL constant.""" + print("\n=== Test 7: NULL Constant ===") + + print(f"NULL: {NULL}") + print(f"NULL.is_valid: {NULL.is_valid}") + print(f"NULL(): {NULL()}") + print(f"type(NULL): {type(NULL)}") + + assert not NULL.is_valid, "NULL should be invalid" + assert NULL() is None, "NULL() should return None" + + # Test using NULL to initialize __rawref__ + class MyClass: + __rawref__ = NULL + + obj = MyClass() + r = ref(obj) + print("After creating ref for obj with __rawref__=NULL:") + print(f" obj.__rawref__ is r: {obj.__rawref__ is r}") + print(f" r.is_valid: {r.is_valid}") + + assert obj.__rawref__ is r, "Should update __rawref__ to new reference" + assert r.is_valid, "New reference should be valid" + + +def test_hidden_object_id(): + """Test that object_id is hidden.""" + print("\n=== Test 8: Hidden object_id ===") + + obj = TestObject(123) + r = ref(obj) + + # Try to access object_id - should raise AttributeError + try: + _ = r.object_id + assert False, "Should not be able to access object_id" + except AttributeError: + print("āœ“ object_id is hidden (AttributeError raised)") + + print() + + +def test_is_valid_readonly(): + """Test that is_valid is read-only.""" + print("\n=== Test 9: is_valid is read-only ===") + + obj = TestObject(456) + r = ref(obj) + + print(f"r.is_valid: {r.is_valid}") + + # Try to set is_valid - should raise AttributeError + try: + r.is_valid = False + assert False, "Should not be able to set is_valid" + except AttributeError: + print("āœ“ is_valid is read-only (AttributeError raised)") + + print() + + +if __name__ == "__main__": + print("Testing rawref module...") + + try: + test_basic_reference() + test_invalidation() + test_del_invalidation() + test_multiple_references() + test_alias_equivalence() + test_singleton_pattern() + test_null_constant() + test_hidden_object_id() + test_is_valid_readonly() + + print("\n" + "=" * 50) + print("All tests passed!") + print("=" * 50) + except AssertionError as e: + print(f"\nāŒ Test failed: {e}") + sys.exit(1) + except Exception as e: + print(f"\nāŒ Unexpected error: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) diff --git a/tensorrt_llm/runtime/kv_cache_manager_v2/setup_mypyc.py b/tensorrt_llm/runtime/kv_cache_manager_v2/setup_mypyc.py new file mode 100644 
index 0000000000..cd80e4197b --- /dev/null +++ b/tensorrt_llm/runtime/kv_cache_manager_v2/setup_mypyc.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Setup script for compiling kv_cache_manager_v2 with mypyc. + +Usage (from project root): + python tensorrt_llm/runtime/kv_cache_manager_v2/setup_mypyc.py build_ext --inplace + +Or use the build script: + ./tensorrt_llm/runtime/kv_cache_manager_v2/build_mypyc.sh +""" + +import os +import sys + +from mypyc.build import mypycify +from setuptools import setup + +# Set environment variables BEFORE importing mypyc +os.environ["MYPY_FORCE_COLOR"] = "0" + +# Write a strict mypy config that won't check external files +mypy_config_path = os.path.abspath("mypy_mypyc_build.ini") +with open(mypy_config_path, "w") as f: + f.write("""[mypy] +# Critical: Don't follow any imports outside the specified files +follow_imports = skip +follow_imports_for_stubs = False + +# Ignore missing imports completely +ignore_missing_imports = True + +# Allow all untyped code +allow_untyped_calls = True +allow_untyped_defs = True +allow_incomplete_defs = True +allow_untyped_globals = True +check_untyped_defs = False + +# Disable all warnings that might cause errors +disallow_untyped_calls = False +disallow_untyped_defs = False +disallow_incomplete_defs = False +warn_return_any = False +warn_unused_ignores = False + +# Disable type validation errors (for external types like drv.CUstream) +disable_error_code = valid-type +""") + +# Point mypy to this config by adding to sys.argv before mypyc runs +sys.argv.extend(["--config-file", mypy_config_path]) + +# List all Python modules in kv_cache_manager_v2 to compile +# +# EXCLUDED FILES: +# - _exceptions.py: inherits from builtin Exception classes (mypyc limitation) +# +modules = [ + # Main module files + "kv_cache_manager_v2/__init__.py", + "kv_cache_manager_v2/_block_radix_tree.py", + "kv_cache_manager_v2/_common.py", + "kv_cache_manager_v2/_config.py", + "kv_cache_manager_v2/_copy_engine.py", + "kv_cache_manager_v2/_cuda_virt_mem.py", + "kv_cache_manager_v2/_exceptions.py", + "kv_cache_manager_v2/_life_cycle_registry.py", + "kv_cache_manager_v2/_page.py", + "kv_cache_manager_v2/_storage_manager.py", + "kv_cache_manager_v2/_utils.py", + # _core submodule + "kv_cache_manager_v2/_core/__init__.py", + "kv_cache_manager_v2/_core/_kv_cache_manager.py", + "kv_cache_manager_v2/_core/_kv_cache.py", + # _eviction_controller submodule + "kv_cache_manager_v2/_eviction_controller/__init__.py", + "kv_cache_manager_v2/_eviction_controller/_eviction_controller.py", + # _storage submodule + "kv_cache_manager_v2/_storage/__init__.py", + "kv_cache_manager_v2/_storage/_config.py", + "kv_cache_manager_v2/_storage/_core.py", +] + +print(f"Compiling {len(modules)} modules with mypyc...") +print("Excluded: None") +print("") + +try: + ext_modules = mypycify( + modules, + opt_level="3", # Maximum 
optimization + multi_file=True, # Allow cross-module references (needed for inheritance) + verbose=True, # Show what's being compiled + separate=False, # Compile into single .so (required for cross-module inheritance) + strip_asserts=False, # Keep assertions for debugging + ) + +except Exception as e: + print(f"Error during mypyc compilation: {e}") + sys.exit(1) +finally: + # Cleanup temp config + if os.path.exists(mypy_config_path): + try: + os.remove(mypy_config_path) + except OSError: + pass + + # Remove --config-file arguments from sys.argv before calling setup() + # This prevents setuptools from seeing arguments it doesn't understand + while "--config-file" in sys.argv: + idx = sys.argv.index("--config-file") + sys.argv.pop(idx) # Remove '--config-file' + if idx < len(sys.argv): # Remove the path that follows it + sys.argv.pop(idx) + +setup( + name="kv_cache_manager_v2_compiled", + ext_modules=ext_modules, + packages=["kv_cache_manager_v2.rawref"], + package_data={ + "kv_cache_manager_v2": ["*.pyi", "**/*.pyi"], + }, + python_requires=">=3.8", +) diff --git a/tests/integration/defs/agg_unit_mem_df.csv b/tests/integration/defs/agg_unit_mem_df.csv index d5e54357db..5f9cf79953 100644 --- a/tests/integration/defs/agg_unit_mem_df.csv +++ b/tests/integration/defs/agg_unit_mem_df.csv @@ -12,6 +12,7 @@ unittest/trt/attention/test_gpt_attention_IFB.py,NVIDIA A10,17, unittest/trt/attention/test_gpt_attention_no_cache.py,NVIDIA A10,23, unittest/trt/model/test_mamba.py,NVIDIA A10,12, unittest/trt/model/test_llama.py,NVIDIA A10,3, +unittest/kv_cache_manager_v2_tests/,NVIDIA A10,8, "unittest/trt/attention/test_gpt_attention.py -k ""partition0""",NVIDIA A10,14, "unittest/trt/attention/test_gpt_attention.py -k ""partition1""",NVIDIA A10,10, "unittest/trt/attention/test_gpt_attention.py -k ""partition2""",NVIDIA A10,3, @@ -42,6 +43,7 @@ unittest/trt/quantization/test_weight_only_groupwise_quant_matmul.py,NVIDIA H100 unittest/trt/attention/test_gpt_attention_IFB.py,NVIDIA H100 80GB HBM3,11, unittest/trt/attention/test_gpt_attention_no_cache.py,NVIDIA H100 80GB HBM3,13, unittest/trt/model/test_mamba.py,NVIDIA H100 80GB HBM3,10, +unittest/kv_cache_manager_v2_tests/,NVIDIA H100 80GB HBM3,8, "unittest/trt/attention/test_gpt_attention.py -k ""partition0""",NVIDIA L40S,14, "unittest/trt/attention/test_gpt_attention.py -k ""partition1""",NVIDIA L40S,10, "unittest/trt/attention/test_gpt_attention.py -k ""partition2""",NVIDIA L40S,6, @@ -64,6 +66,7 @@ unittest/trt/quantization/test_weight_only_groupwise_quant_matmul.py,NVIDIA H100 unittest/trt/attention/test_gpt_attention_IFB.py,NVIDIA H100 PCIe,11, unittest/trt/attention/test_gpt_attention_no_cache.py,NVIDIA H100 PCIe,13, unittest/trt/model/test_mamba.py,NVIDIA H100 PCIe,10, +unittest/kv_cache_manager_v2_tests/,NVIDIA H100 PCIe,8, llmapi-tp-2gpu,NVIDIA H100 NVL,1, unittest/llmapi/test_llm_models_multi_gpu.py,NVIDIA H100 NVL,1, unittest/trt/model/test_gptneox.py,NVIDIA H100 NVL,7, @@ -80,6 +83,7 @@ unittest/trt/quantization/test_weight_only_groupwise_quant_matmul.py,NVIDIA H100 unittest/trt/attention/test_gpt_attention_IFB.py,NVIDIA H100 NVL,11, unittest/trt/attention/test_gpt_attention_no_cache.py,NVIDIA H100 NVL,13, unittest/trt/model/test_mamba.py,NVIDIA H100 NVL,10, +unittest/kv_cache_manager_v2_tests/,NVIDIA H100 NVL,8, llmapi-tp-2gpu,NVIDIA H100,1, unittest/llmapi/test_llm_models_multi_gpu.py,NVIDIA H100,1, unittest/trt/model/test_gptneox.py,NVIDIA H100,7, @@ -96,6 +100,7 @@ unittest/trt/quantization/test_weight_only_groupwise_quant_matmul.py,NVIDIA H100 
unittest/trt/attention/test_gpt_attention_IFB.py,NVIDIA H100,11, unittest/trt/attention/test_gpt_attention_no_cache.py,NVIDIA H100,13, unittest/trt/model/test_mamba.py,NVIDIA H100,10, +unittest/kv_cache_manager_v2_tests/,NVIDIA H100,8, "unittest/trt/attention/test_gpt_attention.py -k ""partition0""",NVIDIA L40,14, "unittest/trt/attention/test_gpt_attention.py -k ""partition1""",NVIDIA L40,10, "unittest/trt/attention/test_gpt_attention.py -k ""partition2""",NVIDIA L40,6, @@ -111,6 +116,7 @@ unittest/_torch/misc,NVIDIA B200,4, unittest/_torch/speculative,NVIDIA B200,4, unittest/_torch/thop/parallel,NVIDIA B200,16, "unittest/_torch/auto_deploy/unit/singlegpu -k ""not test_trtllm_bench_backend_comparison""",NVIDIA B200,4, +unittest/kv_cache_manager_v2_tests/,NVIDIA B200,8, unittest/_torch/attention,NVIDIA H100,4, unittest/_torch/misc,NVIDIA H100,4, unittest/_torch/thop/parallel,NVIDIA H100,16, diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index efc5f92021..c8dc811a37 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -185,6 +185,7 @@ l0_a10: - unittest/trt/attention/test_gpt_attention_IFB.py - unittest/trt/attention/test_gpt_attention_no_cache.py - examples/test_whisper.py::test_llm_whisper_general[large-v3-disable_gemm_plugin-enable_attention_plugin-disable_weight_only-float16-nb:1-use_cpp_runtime] + - unittest/kv_cache_manager_v2_tests/ # 4 min - condition: ranges: system_gpu_count: diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 99da33194f..b0f6d9c356 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ b/tests/integration/test_lists/test-db/l0_b200.yml @@ -87,6 +87,7 @@ l0_b200: - unittest/tools/test_layer_wise_benchmarks.py::test_nemotron_gen_dep[1] - unittest/tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1] - unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8 + - unittest/kv_cache_manager_v2_tests/ - condition: ranges: system_gpu_count: diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 9d98ff7910..605943fefe 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -401,6 +401,7 @@ l0_h100: - unittest/trt/model/test_gpt_e2e.py # 3 mins / 6 mins on H100 - unittest/trt/attention/test_gpt_attention_no_cache.py - examples/test_gpt.py::test_gpt_oss_20b_lora_torch[gpt-oss-20b-lora-adapter_NIM_r8-gpt-oss-20b] + - unittest/kv_cache_manager_v2_tests/ # 4 min - condition: ranges: system_gpu_count: diff --git a/tests/unittest/kv_cache_manager_v2_tests/fake_engine.py b/tests/unittest/kv_cache_manager_v2_tests/fake_engine.py new file mode 100644 index 0000000000..4f13aecb21 --- /dev/null +++ b/tests/unittest/kv_cache_manager_v2_tests/fake_engine.py @@ -0,0 +1,206 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from collections.abc import Sequence +from functools import cached_property +from importlib.util import find_spec +from typing import TYPE_CHECKING, NamedTuple + +if not TYPE_CHECKING and find_spec("kv_cache_manager_v2") is not None: + from kv_cache_manager_v2 import ( + AttentionLayerConfig, + BeamIndex, + CudaStream, + DataRole, + KVCacheManagerConfig, + LayerId, + TokenIdExt, + _KVCache, + ) + from kv_cache_manager_v2._common import BAD_PAGE_INDEX, NDEBUG, MemAddress + from kv_cache_manager_v2._utils import ( + div_up, + exact_div, + get_uniform_attribute, + overlap, + typed_range, + value_or, + ) +else: + from tensorrt_llm.runtime.kv_cache_manager_v2 import ( + AttentionLayerConfig, + BeamIndex, + CudaStream, + DataRole, + KVCacheManagerConfig, + LayerId, + TokenIdExt, + _KVCache, + ) + from tensorrt_llm.runtime.kv_cache_manager_v2._common import BAD_PAGE_INDEX, NDEBUG, MemAddress + from tensorrt_llm.runtime.kv_cache_manager_v2._utils import ( + div_up, + exact_div, + get_uniform_attribute, + overlap, + typed_range, + value_or, + ) + +import os + +from dynamic_path_manager import DynamicPathManager + +with DynamicPathManager(os.path.dirname(os.path.abspath(__file__)), clear_cache=False): + from kernels import check_values, fill_values + + +class Step(NamedTuple): + kv_cache: _KVCache + input: list[TokenIdExt] # when empty, just check history + history: list[TokenIdExt] + + +class Role: + """Constants for data roles in KV cache management.""" + + KEY = DataRole("key") + VALUE = DataRole("value") + KEY_BLOCK_QUANT = DataRole("key_block_quant") + VALUE_BLOCK_QUANT = DataRole("value_block_quant") + + +roles = (Role.KEY, Role.VALUE, Role.KEY_BLOCK_QUANT, Role.VALUE_BLOCK_QUANT) + + +class FakeEngine: + cfg: KVCacheManagerConfig + + def __init__(self, config: KVCacheManagerConfig) -> None: + super().__init__() + self.cfg = config + + @property + def tokens_per_block(self) -> int: + return self.cfg.tokens_per_block + + @cached_property + def layers(self) -> dict[LayerId, AttentionLayerConfig]: + return { + layer.layer_id: layer + for layer in sorted(self.cfg.layers, key=lambda layer: layer.layer_id) + } + + def execute(self, batch: Sequence[Step], stream: CudaStream) -> None: + assert batch + manager = get_uniform_attribute(batch, lambda step: step.kv_cache.manager) + for kv_cache, input, history in batch: + for layer_id, layer_cfg in self.layers.items(): + for buf_id, buf in enumerate(layer_cfg.buffers): + role = buf.role + assert NDEBUG or buf.size == manager.get_page_stride(layer_id, role) + for beam in typed_range(kv_cache.beam_width): + # check history + self._check_pages(kv_cache, layer_id, buf_id, beam, history, stream) + # write new token + if input: + self._write_new_tokens( + kv_cache, len(history), layer_id, buf_id, beam, input, stream + ) + + def _check_pages( + self, + kv_cache: _KVCache, + layer_id: LayerId, + buf_id: int, + beam: BeamIndex, + history: Sequence[TokenIdExt], + stream: CudaStream, + ): + manager = kv_cache.manager + tokens_per_block = self.tokens_per_block + layer_cfg = self.layers[layer_id] + buf = layer_cfg.buffers[buf_id] + role = buf.role + token_bytes = exact_div(buf.size, tokens_per_block) + pool = manager.get_mem_pool_base_address(layer_id, role) + stride = manager.get_page_stride(layer_id, role) + lc_id = manager._storage._layer_to_life_cycle_ids[layer_id] + pages = kv_cache.get_page_indices(lc_id, beam) + capacity = 
kv_cache.capacity + history_len = len(history) + assert len(history) == history_len + window = ( + (0, capacity) + if layer_cfg.window_size is None + else (max(0, history_len + 1 - layer_cfg.window_size), capacity) + ) + sink = value_or(layer_cfg.num_sink_tokens, 0) + # check history + for ordinal, page in enumerate(pages): + if page == BAD_PAGE_INDEX: + continue + page_range = (tokens_per_block * ordinal, tokens_per_block * (ordinal + 1)) + need_page = overlap(page_range, (0, sink)) or overlap(page_range, window) + if need_page: + assert page != BAD_PAGE_INDEX + else: + assert kv_cache.history_length != history_len or page == BAD_PAGE_INDEX + addr = MemAddress(pool + stride * page) + tokens = history[tokens_per_block * ordinal : tokens_per_block * (ordinal + 1)] + check_values(addr, token_bytes, layer_id, buf_id, beam, tokens, stream) + + def _write_new_tokens( + self, + kv_cache: _KVCache, + history_len: int, + layer_id: LayerId, + buf_id: int, + beam: BeamIndex, + input: Sequence[TokenIdExt], + stream: CudaStream, + ): + manager = kv_cache.manager + tokens_per_block = self.tokens_per_block + layer_cfg = self.layers[layer_id] + buf = layer_cfg.buffers[buf_id] + role = buf.role + token_bytes = exact_div(buf.size, self.tokens_per_block) + pool = manager.get_mem_pool_base_address(layer_id, role) + stride = manager.get_page_stride(layer_id, role) + lc_id = manager._storage._layer_to_life_cycle_ids[layer_id] + pages = kv_cache.get_page_indices(lc_id, beam)[ + : div_up(history_len + len(input), tokens_per_block) + ] + capacity = kv_cache.capacity + input_range = (history_len, history_len + len(input)) + assert input_range[1] <= capacity + ordinal_beg = input_range[0] // tokens_per_block + pages = itertools.islice(pages, ordinal_beg, None) + ordinal = None + for i, page in enumerate(pages): + ordinal = ordinal_beg + i + assert page != BAD_PAGE_INDEX + page_range = (tokens_per_block * ordinal, tokens_per_block * (ordinal + 1)) + batch_range = tuple(i for i in overlap(input_range, page_range)) + assert batch_range + tokens = input[(batch_range[0] - history_len) : (batch_range[1] - history_len)] + addr = MemAddress( + pool + stride * page + token_bytes * (batch_range[0] % tokens_per_block) + ) + # print('layer_id={}, buf_id={}, beam={}, i={}, addr={}, tokens={}'.format( + # layer_id, buf_id, beam, i, addr, tokens)) + fill_values(addr, token_bytes, layer_id, buf_id, beam, tokens, stream) + assert ordinal is None or ordinal + 1 == div_up(input_range[1], tokens_per_block) diff --git a/tests/unittest/kv_cache_manager_v2_tests/kernels.py b/tests/unittest/kv_cache_manager_v2_tests/kernels.py new file mode 100644 index 0000000000..8105b3cdcd --- /dev/null +++ b/tests/unittest/kv_cache_manager_v2_tests/kernels.py @@ -0,0 +1,336 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import contextlib +import ctypes +from collections.abc import Sequence +from functools import lru_cache +from importlib.util import find_spec +from typing import TYPE_CHECKING, Iterator + +import cuda.bindings.driver as drv + +try: +    from cuda.core import Kernel, ObjectCode, Program, ProgramOptions +except ImportError: +    from cuda.core.experimental import Kernel, Program, ProgramOptions +    from cuda.core.experimental._module import ObjectCode + +if not TYPE_CHECKING and find_spec("kv_cache_manager_v2") is not None: +    from kv_cache_manager_v2._common import CudaStream, LayerId, MemAddress, TokenIdExt +    from kv_cache_manager_v2._utils import _unwrap, div_up, exact_div +else: +    from tensorrt_llm.runtime.kv_cache_manager_v2._common import ( +        CudaStream, +        LayerId, +        MemAddress, +        TokenIdExt, +    ) +    from tensorrt_llm.runtime.kv_cache_manager_v2._utils import _unwrap, div_up, exact_div + +_SLEEP_TIME_NS: int = 0 + + +@contextlib.contextmanager +def enable_kernel_delay() -> Iterator[None]: +    global _SLEEP_TIME_NS +    _SLEEP_TIME_NS = 30_000 +    yield +    _SLEEP_TIME_NS = 0 + + +@lru_cache(maxsize=None) +def get_program(debug: bool, max_tokens: int, sleep_time: int) -> ObjectCode: +    assert max_tokens > 0 and (max_tokens & (max_tokens - 1)) == 0, ( +        "max_tokens must be a power of 2" +    ) +    code = r""" +#if !defined(__CUDACC_RTC__) +#include <cassert> +#include <cstdio> +#endif + +#ifdef NDEBUG +__device__ inline void check(bool condition) { +    if (!condition) { +        asm volatile("trap;" ::: "memory"); +    } +} +#else +#define check assert +#endif + +using uint32_t = unsigned int; +using uint16_t = unsigned short; + +constexpr uint32_t sleepTime = SLEEP_TIME_NS; + +struct alignas(16) Value { +    uint32_t token; +    uint32_t layer; +    uint32_t role; +    uint32_t beam; + +    __device__ inline bool operator==(Value const& other) const { +        return token == other.token && layer == other.layer && role == other.role && beam == other.beam; +    } +    __device__ inline bool operator!=(Value const& other) const { +        return !(*this == other); +    } +}; + +constexpr uint32_t kMAX_TOKENS = MAX_TOKENS; + +struct Tokens { +    uint32_t tokens[kMAX_TOKENS]; +}; + +extern "C" __global__ void fillValues(Value* data, uint32_t valuesPerToken, uint32_t layer, +    uint32_t buf_id, uint32_t beam, __grid_constant__ const Tokens tokens, uint32_t numTokens) { +    if (sleepTime > 0) { +        __nanosleep(sleepTime); +    } +    check(numTokens <= kMAX_TOKENS); +    auto const tid = (static_cast<uint32_t>(blockIdx.x) * blockDim.x) + threadIdx.x; +    auto const idxToken = tid / valuesPerToken; +    if (idxToken >= numTokens) { +        return; +    } +    auto const token = tokens.tokens[idxToken]; +    auto const value = Value{token, layer, buf_id, beam}; +    data[tid] = value; +} + +__device__ inline void assertEq(Value const& a, Value const& b) { +#ifndef NDEBUG +    if (a != b) { +        printf("(%d, %d, %d, %d) != (%d, %d, %d, %d)\n", +            a.token, a.layer, a.role, a.beam, +            b.token, b.layer, b.role, b.beam); +    } +#endif +    check(a == b); +} + +extern "C" __global__ void checkValues(Value const* data, uint32_t valuesPerToken, uint32_t layer, +    uint32_t buf_id, uint32_t beam, __grid_constant__ const Tokens tokens, uint32_t numTokens) { +    if (sleepTime > 0) { +        __nanosleep(sleepTime); +    } +    check(numTokens <= kMAX_TOKENS); +    auto const tid = (static_cast<uint32_t>(blockIdx.x) * blockDim.x) + threadIdx.x; +    auto const idxToken = tid / valuesPerToken; +    if (idxToken >= numTokens) { +        return; +    } +    auto const token = tokens.tokens[idxToken]; +    auto const value = Value{token, layer, buf_id, beam}; +    assertEq(data[tid], value); +} +    """ +    macros = 
[("MAX_TOKENS", str(max_tokens)), ("SLEEP_TIME_NS", str(sleep_time))] + program_options = ProgramOptions(std="c++17", lineinfo=True, debug=debug, define_macro=macros) # type: ignore[arg-type] + if not debug: + program_options.use_fast_math = True + prog = Program(code, code_type="c++", options=program_options) + mod = prog.compile("cubin", name_expressions=("fillValues", "checkValues")) + return mod + + +def get_kernel(name: str, num_tokens: int, sleep_time: int) -> tuple[Kernel, int]: + assert num_tokens > 0 + + @lru_cache(maxsize=None) + def impl(name: str, max_tokens: int, sleep_time: int) -> Kernel: + assert name in ("fillValues", "checkValues") + assert max_tokens != 0 and (max_tokens & (max_tokens - 1)) == 0, ( + "max_tokens must be a power of 2" + ) + debug = False + # debug = not NDEBUG + return get_program(debug, max_tokens, sleep_time).get_kernel(name) + + # Round up to the next power of two + max_tokens = 2 ** ((num_tokens - 1).bit_length()) + return impl(name, max_tokens, sleep_time), max_tokens + + +class Value(ctypes.Structure): + _fields_ = [ + ("token", ctypes.c_uint32), + ("layer", ctypes.c_uint32), + ("buf_id", ctypes.c_uint32), + ("beam", ctypes.c_uint32), + ] + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Value): + return NotImplemented + return ( + self.token == other.token + and self.layer == other.layer + and self.buf_id == other.buf_id + and self.beam == other.beam + ) + + def __str__(self) -> str: + return ( + f"Value(token={self.token}, layer={self.layer}, buf_id={self.buf_id}, beam={self.beam})" + ) + + +@lru_cache(maxsize=None) +def _get_ctypes_struct(max_tokens: int) -> type[ctypes.Structure]: + class Tokens(ctypes.Structure): + _fields_ = [ + ("tokens", ctypes.c_uint32 * max_tokens), + ] + + Tokens.__name__ = f"Tokens_{max_tokens}" + return Tokens + + +def _make_tokens(tokens: Sequence[TokenIdExt], max_tokens: int) -> ctypes.Structure: + assert len(tokens) <= max_tokens + padded = list(tokens) + [0] * (max_tokens - len(tokens)) + Tokens = _get_ctypes_struct(max_tokens) + return Tokens( + tokens=(ctypes.c_uint32 * max_tokens)( + *[ + t if isinstance(t, int) else int.from_bytes(t[:4], "little", signed=False) + for t in padded + ] + ) + ) + + +def fill_values( + address: MemAddress, + bytes_per_token: int, + layer: LayerId, + buf_id: int, + beam: int, + tokens: Sequence[TokenIdExt], + stream: CudaStream, +): + values_per_token = exact_div(bytes_per_token, ctypes.sizeof(Value)) + num_tokens = len(tokens) + if num_tokens == 0: + return + kernel, max_tokens = get_kernel("fillValues", len(tokens), _SLEEP_TIME_NS) + args = ( + address, + values_per_token, + layer, + buf_id, + beam, + _make_tokens(tokens, max_tokens), + num_tokens, + ) + arg_types = ( + ctypes.c_void_p, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + None, + ctypes.c_uint32, + ) + num_threads = values_per_token * num_tokens + cta_size = 256 + _unwrap( + drv.cuLaunchKernel( + kernel._handle, + div_up(num_threads, cta_size), + 1, + 1, + cta_size, + 1, + 1, + 0, + stream, + (args, arg_types), + 0, + ) + ) + + +def check_values( + address: MemAddress, + bytes_per_token: int, + layer: LayerId, + buf_id: int, + beam: int, + tokens: Sequence[TokenIdExt], + stream: CudaStream, +): + values_per_token = exact_div(bytes_per_token, ctypes.sizeof(Value)) + num_tokens = len(tokens) + if num_tokens == 0: + return + kernel, max_tokens = get_kernel("checkValues", len(tokens), _SLEEP_TIME_NS) + args = ( + address, + values_per_token, + layer, + buf_id, + beam, + 
_make_tokens(tokens, max_tokens), + num_tokens, + ) + arg_types = ( + ctypes.c_void_p, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + ctypes.c_uint32, + None, + ctypes.c_uint32, + ) + num_threads = values_per_token * num_tokens + cta_size = 256 + _unwrap( + drv.cuLaunchKernel( + kernel._handle, + div_up(num_threads, cta_size), + 1, + 1, + cta_size, + 1, + 1, + 0, + stream, + (args, arg_types), + 0, + ) + ) + + +def debug_dump_tokens( + addr: MemAddress, token_bytes: int, num_tokens: int, stream: CudaStream +) -> Iterator[Value]: + if num_tokens == 0: + return + val_size = ctypes.sizeof(Value) + values_per_token = exact_div(token_bytes, val_size) + host_buf = (Value * values_per_token * num_tokens)() + ptr = ctypes.addressof(host_buf) + _unwrap(drv.cuMemcpyDtoHAsync(ptr, addr, num_tokens * token_bytes, stream)) + _unwrap(drv.cuStreamSynchronize(stream)) + for i in range(num_tokens): + token = host_buf[i] + value = Value.from_buffer_copy(token[0]) + for j in range(1, values_per_token): + assert token[j] == token[0] + yield value diff --git a/tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py b/tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py new file mode 100755 index 0000000000..765be68498 --- /dev/null +++ b/tests/unittest/kv_cache_manager_v2_tests/test_kv_cache_manager_v2.py @@ -0,0 +1,711 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import array +import functools +import gc +import itertools +import os +import random +import time +import unittest +from contextlib import contextmanager +from importlib.util import find_spec +from random import randbytes +from statistics import median +from typing import TYPE_CHECKING, Iterator, NamedTuple, cast + +if not TYPE_CHECKING and find_spec("kv_cache_manager_v2") is not None: + from kv_cache_manager_v2 import ( + AttentionLayerConfig, + BeamIndex, + BufferConfig, + CacheLevel, + CudaStream, + DiskCacheTierConfig, + GpuCacheTierConfig, + HostCacheTierConfig, + KVCacheManager, + KVCacheManagerConfig, + LayerGroupId, + LayerId, + TokenId, + TokenIdExt, + _KVCache, + ) + from kv_cache_manager_v2._block_radix_tree import traverse_post_order + from kv_cache_manager_v2._common import GPU_LEVEL, PageStatus, SlidingWindowSize + from kv_cache_manager_v2._exceptions import OutOfPagesError + from kv_cache_manager_v2._utils import ( + TemporaryCudaStream, + div_up, + init_cuda_once, + remove_if, + round_up, + typed_range, + unwrap_rawref, + ) +else: + from tensorrt_llm.runtime.kv_cache_manager_v2 import ( + AttentionLayerConfig, + BeamIndex, + BufferConfig, + CacheLevel, + CudaStream, + DiskCacheTierConfig, + GpuCacheTierConfig, + HostCacheTierConfig, + KVCacheManager, + KVCacheManagerConfig, + LayerGroupId, + LayerId, + TokenId, + TokenIdExt, + _KVCache, + ) + from tensorrt_llm.runtime.kv_cache_manager_v2._block_radix_tree import traverse_post_order + from tensorrt_llm.runtime.kv_cache_manager_v2._common import ( + GPU_LEVEL, + PageStatus, + SlidingWindowSize, + ) + from tensorrt_llm.runtime.kv_cache_manager_v2._exceptions import OutOfPagesError + from tensorrt_llm.runtime.kv_cache_manager_v2._utils import ( + TemporaryCudaStream, + div_up, + init_cuda_once, + remove_if, + round_up, + typed_range, + unwrap_rawref, + ) + +from dynamic_path_manager import DynamicPathManager +from parameterized import parameterized + +with DynamicPathManager(os.path.dirname(os.path.abspath(__file__)), clear_cache=False): + from fake_engine import FakeEngine, Role, Step + from kernels import enable_kernel_delay + +seed = int.from_bytes(os.urandom(8), "little") +print(f"seed: {seed}") +random.seed(seed) +DBG_PRINT = int(os.environ.get("DBG_PRINT", "0")) != 0 +PRINT_TIME = int(os.environ.get("PRINT_TIME", "0")) != 0 + + +@contextmanager +def ref_cycle_check_context(): + """Context manager for reference cycle check.""" + import gc + + gc.collect() + gc.garbage.clear() + gc.set_debug(gc.DEBUG_SAVEALL | gc.DEBUG_COLLECTABLE) + + def on_gc_event(phase, info): + # phase is "start" or "stop" + # info contains keys like: "generation", "collected", "uncollectable", "duration" + if phase == "stop": + collected = info.get("collected", 0) + uncollectable = info.get("uncollectable", 0) + if collected != 0 or uncollectable != 0: + import pdb + + pdb.set_trace() + assert collected == 0 and uncollectable == 0 + + gc.callbacks.append(on_gc_event) + try: + yield + finally: + gc.collect() + gc.callbacks.pop() + gc.set_debug(0) + + +def assert_no_ref_cycle(func): + """Decorator to wrap test methods with GC debugging context.""" + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with ref_cycle_check_context(): + result = func(self, *args, **kwargs) + return result + + return wrapper + + +class TestKVCacheManagerV2(unittest.TestCase): + engine: FakeEngine + cfg: KVCacheManagerConfig + manager: KVCacheManager + _token_id_gen: Iterator[int] + + def setUp(self) -> None: + init_cuda_once() + self._token_id_gen = 
itertools.count() + gc.collect() + gc.disable() + + def tearDown(self) -> None: + gc.enable() + if hasattr(self, "manager"): + del self.manager + + def next_token(self) -> TokenIdExt: + token_id = next(self._token_id_gen) + if token_id % 100 == 99: + return randbytes(32) + else: + return TokenId(token_id) + + def prepare( + self, + gpu_quota: int, + host_quota: int, + disk_quota: int, + num_layers: int, + window_size: SlidingWindowSize, + sink_tokens: int, + tokens_per_block: int = 32, + kv_buf_size: int = 8192, + ): + self._init_cfg( + tokens_per_block, + gpu_quota, + host_quota, + disk_quota, + num_layers, + window_size, + sink_tokens, + kv_buf_size, + ) + self.engine = FakeEngine(self.cfg) + self.manager = KVCacheManager(self.cfg) + + def _init_cfg( + self, + tokens_per_block: int, + gpu_quota: int, + host_quota: int, + disk_quota: int, + num_layers: int, + window_size: SlidingWindowSize, + sink_tokens: int, + kv_buf_size: int = 8192, + block_quant_buf_size: int | None = None, + ): + layer_buffers = [ + BufferConfig(role=Role.KEY, size=kv_buf_size), + BufferConfig(role=Role.VALUE, size=kv_buf_size), + ] + if block_quant_buf_size is not None: + layer_buffers.extend( + [ + BufferConfig(role=Role.KEY_BLOCK_QUANT, size=block_quant_buf_size), + BufferConfig(role=Role.VALUE_BLOCK_QUANT, size=block_quant_buf_size), + ] + ) + disk_path_candidates = ["/workspace/", "/tmp/nvidia-mps/", "/tmp"] + disk_path = next(p for p in disk_path_candidates if os.path.exists(p)) + cache_tiers = [ + GpuCacheTierConfig(quota=gpu_quota), + HostCacheTierConfig(quota=host_quota), + DiskCacheTierConfig(quota=disk_quota, path=disk_path), + ] + self.cfg = KVCacheManagerConfig( + tokens_per_block=tokens_per_block, + vocab_size=4096, + cache_tiers=[t for t in cache_tiers if t.quota > 0], + layers=[ + AttentionLayerConfig( + layer_id=layer_id, + buffers=layer_buffers, + sliding_window_size=window_size if layer_id % 2 == 0 else None, + num_sink_tokens=sink_tokens if layer_id % 2 == 0 else None, + ) + for layer_id in typed_range(LayerId(num_layers)) + ], + ) + + +class TestNoBatching(TestKVCacheManagerV2): + class Request(NamedTuple): + id: int + kv_cache: _KVCache + prompt: list[TokenIdExt] + decode_len: int + + def new_request( + self, req_id: int, lora_task_id: int | None, prompt_len: int, decode_len: int + ) -> Request: + prompt = [self.next_token() for _ in range(prompt_len)] + return self.Request( + req_id, self.manager.create_kv_cache(lora_task_id, prompt), prompt, decode_len + ) + + def run_request(self, req: Request, interval: int, refcheck: bool) -> float: + req_id, kv_cache, prompt, decode_len = req + assert kv_cache.status == _KVCache.Status.ACTIVE + stream = kv_cache.cuda_stream + tic = time.perf_counter() + # prefill + num_reused = kv_cache.num_committed_tokens + # workaround a mypyc bug: exception in property setter is not propagated + # kv_cache.capacity = round_up(len(prompt), interval) + if not kv_cache.resize(round_up(len(prompt), interval)): + raise OutOfPagesError("Not enough pages in GPU memory") + capacity = kv_cache.capacity + history = prompt[:num_reused] + input = prompt[num_reused:] + if refcheck: + self.engine.execute([Step(kv_cache, input, history)], stream) + if input: + kv_cache.commit(input) + history.extend(input) + # decode + for _ in range(decode_len): + required_capacity = len(history) + 1 + if required_capacity > capacity: + kv_cache.commit(history[kv_cache.history_length :]) + # workaround a mypyc bug: exception in property setter is not propagated + # kv_cache.capacity = 
round_up(required_capacity, interval) + if not kv_cache.resize(round_up(required_capacity, interval)): + raise OutOfPagesError("Not enough pages in GPU memory") + capacity = kv_cache.capacity + input_token = self.next_token() + if refcheck: + self.engine.execute([Step(kv_cache, [input_token], history)], stream) + history.append(input_token) + kv_cache.commit(history[kv_cache.history_length :]) + # last check + if refcheck: + self.engine.execute([Step(kv_cache, [], history)], stream) + toc = time.perf_counter() + time_taken = toc - tic + # print(f"Time taken: {time_taken} seconds") + return time_taken + + def run_naive( + self, + seq_len: int, + interval: int = 1, + refcheck: bool = True, + use_external_page_index_buf: bool = False, + ) -> float: + prompt_len = 1 + decode_len = seq_len - prompt_len + + req_id = 0 + lora_task_id = None + req0 = self.new_request(req_id, lora_task_id, prompt_len, decode_len) + if use_external_page_index_buf: + max_num_blocks = div_up(seq_len, self.cfg.tokens_per_block) + num_layer_groups = len(self.manager.layer_grouping) + page_indices = [ + array.array("i", [-1]) * max_num_blocks for _ in range(num_layer_groups) + ] + for id in range(num_layer_groups): + req0.kv_cache.set_page_index_buf( + BeamIndex(0), LayerGroupId(id), memoryview(page_indices[id]) + ) + with TemporaryCudaStream([]) as s: + stream = cast(CudaStream, s.handle) + kv_cache = req0.kv_cache + success = kv_cache.resume(stream) + assert success + time_taken = self.run_request(req0, interval, refcheck) + + s.take_finish_event().synchronize() + kv_cache.close() + self.manager.clear_reusable_blocks() + return time_taken + + @parameterized.expand([(False,), (True,)]) + def test_shrink_capacity(self, use_external_page_index_buf: bool) -> None: + self.prepare(32 << 20, 32 << 20, 1 << 30, 36, 128, 1, kv_buf_size=32768) + seq_len = 32 * 10 + req0 = self.new_request(0, None, 32, seq_len - 32) + if use_external_page_index_buf: + max_num_blocks = div_up(seq_len, self.cfg.tokens_per_block) + num_layer_groups = len(self.manager.layer_grouping) + page_indices = [ + array.array("i", [-1]) * max_num_blocks for _ in range(num_layer_groups) + ] + for id in range(num_layer_groups): + req0.kv_cache.set_page_index_buf( + BeamIndex(0), LayerGroupId(id), memoryview(page_indices[id]) + ) + with TemporaryCudaStream([]) as s: + stream = cast(CudaStream, s.handle) + kv_cache = req0.kv_cache + success = kv_cache.resume(stream) + assert success + success = kv_cache.resize(seq_len) + assert success + for capacity in range(seq_len, len(req0.prompt), -1): + success = kv_cache.resize(capacity) + assert success + s.take_finish_event() + kv_cache.close() + + def test_small_quota(self) -> None: + self.prepare(5619712, 0, 0, 8, None, 0) + assert self.manager.get_quota(cast(CacheLevel, GPU_LEVEL)) >= 5619712 + + # @assert_no_ref_cycle + def test_sol_mem_utilization(self) -> None: + self.prepare(32 << 20, 32 << 20, 1 << 30, 36, 128, 1, kv_buf_size=32768) + # if we have n blocks, we need 8192*2*18*(1+5+n) bytes of memory. For the (1+5+n), 1 is for sink + # blocks, 5 is for SWA (window=128), n is for full attention. + max_seq_len = 32 * 22 # 23 blocks will require more than 32MB memory + seq_len = max_seq_len + + # create a request and suspend it. It shall not consume any GPU memory after suspend. 
+ req0 = self.new_request(0, None, 256, seq_len - 256) + with TemporaryCudaStream([]) as s: + stream = cast(CudaStream, s.handle) + success = req0.kv_cache.resume(stream) + assert success + self.run_request(req0, 32, False) + s.take_finish_event() + req0.kv_cache.suspend() + + # run another request that will take all the GPU memory + req1 = self.new_request(0, None, 256, seq_len - 256) + with TemporaryCudaStream([]) as s: + stream = cast(CudaStream, s.handle) + success = req1.kv_cache.resume(stream) + assert success + self.run_request(req1, 1, True) + s.take_finish_event() + + req1.kv_cache.close() + req0.kv_cache.close() + + # run another longer request and expect OutOfPagesError + # This also tests eviction to disk. + self.assertRaises(OutOfPagesError, lambda: self.run_naive(seq_len + 1, 1, False)) + + @parameterized.expand([(1,), (2,), (4,)]) + # @assert_no_ref_cycle + def test_cache_reuse(self, num_reusable_requests: int) -> None: + self.prepare(32 << 20, 32 << 20, 1 << 30, 36, 128, 1, kv_buf_size=32768) + # if we have n blocks, we need 8192*2*18*(1+5+n) bytes of memory. For the (1+5+n), 1 is for sink + # blocks, 5 is for SWA (window=128), n is for full attention. + max_seq_len = 32 * 22 # 23 blocks will require more than 32MB memory + seq_len = max_seq_len + + req_id_gen = itertools.count() + reusable_requests = [] + with TemporaryCudaStream([]) as s: + stream = cast(CudaStream, s.handle) + for _ in range(num_reusable_requests): + req = self.new_request(next(req_id_gen), None, 256, seq_len - 256) + reusable_requests.append(req) + success = req.kv_cache.resume(stream) + assert success + self.run_request(req, 32, True) + req.kv_cache.close() + s.take_finish_event() + + for root_block in self.manager._radix_tree.next.values(): + for block0 in root_block.next.values(): + for block in traverse_post_order(block0): + for page in block.storage: + if page is not None: + assert unwrap_rawref(page).status == PageStatus.DROPPABLE + + req0 = reusable_requests[0] + prompt1 = req0.kv_cache._committed_tokens[: (seq_len // 2 - 7)] + # request id must be same as req0 because we wrote it into the kv cache. 
+ req1 = self.Request( + next(req_id_gen), + self.manager.create_kv_cache(None, prompt1), + prompt1, + seq_len - len(prompt1), + ) + assert req1.kv_cache.num_committed_tokens == len(prompt1) + with TemporaryCudaStream([]) as s: + stream = cast(CudaStream, s.handle) + success = req1.kv_cache.resume(stream) + assert success + self.run_request(req1, 32, True) + s.take_finish_event() + req1.kv_cache.close() + + self.manager.clear_reusable_blocks() + + @parameterized.expand([(False,), (True,)]) + # @assert_no_ref_cycle + def test_naive(self, use_external_page_index_buf: bool) -> None: + self.prepare(256 << 20, 256 << 20, 1 << 30, 36, 128, 48) + self.run_naive(512, 1, True, use_external_page_index_buf) + + @parameterized.expand([(2**i, False) for i in range(12)]) + # @parameterized.expand([(32, True)]) + # @assert_no_ref_cycle + def test_naive_perf(self, interval, profile: bool) -> None: + if not PRINT_TIME: + self.skipTest("Skipping perf test") + self.prepare(256 << 20, 256 << 20, 1 << 30, 36, 128, 48) + seq_len = 10240 + self.run_naive(seq_len, interval, False) # warm up for numba jit + profiler = None + if profile: + import cProfile + + profiler = cProfile.Profile() + profiler.enable() + time_taken = [ + self.run_naive(seq_len, interval, False) for _ in range(11 if profiler is None else 1) + ] + median_time_taken = median(time_taken) + if PRINT_TIME: + print( + f"Throughput: {round(seq_len / median_time_taken)} tokens/sec for interval {interval}" + ) + if profiler is not None: + profiler.disable() + profiler.print_stats(sort="cumtime") + profiler.dump_stats("profiler.prof") + + +class TestBatching(TestKVCacheManagerV2): + num_requests: int + avg_length: int + past_sequences: list[list[TokenIdExt]] + seq_len_dict: dict[_KVCache, int] + batch: list[Step] + suspended: list[Step] + num_created: int + num_finished: int + req_id_gen: Iterator[int] + acc_num_prompt_tokens: int + acc_num_decode_tokens: int + interval: int + enable_reuse: bool + + def setUp(self) -> None: + super().setUp() + self.past_sequences = list[list[TokenIdExt]]() + self.seq_len_dict = dict[_KVCache, int]() + self.batch = list[Step]() + self.suspended = list[Step]() + self.num_finished = 0 + self.num_created = 0 + self.req_id_gen = itertools.count() + self.acc_num_prompt_tokens = 0 + self.acc_num_decode_tokens = 0 + self.enable_reuse = False + + def gen_request(self) -> Step: + if self.num_created >= self.num_requests: + raise ValueError("Too many requests created") + + token_id_gen = cast(Iterator[TokenId], self._token_id_gen) + + def gen_length() -> int: + return random.randint(int(self.avg_length * 0.6), int(self.avg_length * 1.4)) + + if self.enable_reuse: + if len(self.past_sequences) >= 32 and random.random() < 0.2: + # continued multi-round dialog + prompt = random.choice(self.past_sequences) + [ + next(token_id_gen) for _ in range(gen_length()) + ] + else: + # new dialog + if len(self.past_sequences) < 32 or random.random() < 0.5: + # completely new prompt + prompt = [next(token_id_gen) for _ in range(gen_length())] + else: + # with reused tokens + reused = random.choice(self.past_sequences) + prompt = reused[: random.randint(0, min(gen_length(), len(reused)))] + [ + next(token_id_gen) for _ in range(gen_length()) + ] + else: + prompt = [next(token_id_gen) for _ in range(gen_length())] + decode_len = gen_length() + lora_task_id = None + kv_cache = self.manager.create_kv_cache( + lora_task_id, prompt[:-1] if self.enable_reuse else None, id=next(self.req_id_gen) + ) + DBG_PRINT and print( # type: ignore[arg-type] + 
f"created {kv_cache.id} with {kv_cache.num_committed_tokens} tokens reused" + ) + history = prompt[: kv_cache.num_committed_tokens] + input = prompt[kv_cache.num_committed_tokens :] + seq_len = len(prompt) + decode_len + self.seq_len_dict[kv_cache] = seq_len + self.num_created += 1 + assert input + self.acc_num_prompt_tokens += len(prompt) + self.acc_num_decode_tokens += decode_len + return Step(kv_cache, input, history) + + def update_batch(self, stream: CudaStream) -> None: + for s in self.batch: + assert s.input + if self.enable_reuse: + s.kv_cache.commit(s.input) + else: + s.kv_cache.history_length += len(s.input) + s.history.extend(s.input) + s.input.clear() + # remove finished requests first + removed = remove_if( + self.batch, + lambda step: len(step.history) >= self.seq_len_dict[step.kv_cache], + ) + for kv_cache, _, _ in removed: + seq_len = self.seq_len_dict[kv_cache] + if seq_len < self.avg_length * 3: + self.past_sequences.append(kv_cache._committed_tokens[:seq_len]) + kv_cache.close() + self.seq_len_dict.pop(kv_cache) + self.num_finished += 1 + # fill input for remaining requests and increase capacity for them + token_id_gen = cast(Iterator[TokenId], self._token_id_gen) + for s in self.batch: + assert not s.input + length = min(self.interval, self.seq_len_dict[s.kv_cache] - len(s.history)) + s.input.extend(next(token_id_gen) for _ in range(length)) + for i in itertools.count(): + if i >= len(self.batch): + break + s = self.batch[i] + while i < len(self.batch) and not s.kv_cache.resize( + len(s.history) + len(s.input), None + ): + last = self.batch.pop() + DBG_PRINT and print(f"suspending {last.kv_cache.id}") # type: ignore[arg-type] + last.kv_cache.suspend() + self.suspended.append(last) + + # try to add new requests + suspended = self.suspended + while suspended or self.num_created < self.num_requests: + if not suspended: + assert self.num_created < self.num_requests + suspended.append(self.gen_request()) + if suspended: + step = suspended[-1] + kv_cache = step.kv_cache + ok = kv_cache.resume(stream) + if ( + ok + and not self.enable_reuse + and kv_cache._commit_state == _KVCache.CommitState.ALLOWED + ): + kv_cache.stop_committing() + ok = ok and kv_cache.resize(len(step.history) + len(step.input), None) + if ok: + DBG_PRINT and print(f"activating {step.kv_cache.id}") # type: ignore[arg-type] + self.batch.append(suspended.pop()) + else: + if kv_cache.status == _KVCache.Status.ACTIVE: + kv_cache.suspend() + break + + DBG_PRINT and print( # type: ignore[arg-type] + f"update_batch: found {len(removed)} finished requests, now with {len(self.batch)} requests" + ) + + @parameterized.expand( + [ + (1000, 1000, 1024, True, 32, 32), + (1000, 1000, 1024, True, 1, 32), + (10000, 1000, 1024, True, 32, 32), + (100, 100, 128, False, 1, 128), + (100, 100, 128, False, 4, 64), + ] + ) + # @assert_no_ref_cycle + def test_inflight_batching( + self, + num_requests: int, + avg_length: int, + gpu_quota_mb: int, + skip_execution: bool, + interval: int, + tokens_per_block: int, + ): + self.prepare( + gpu_quota_mb << 20, 4 << 30, 0 << 30, 36, 128, 0, tokens_per_block=tokens_per_block + ) + self.num_requests = num_requests + self.avg_length = avg_length + self.interval = interval + profile = False + profiler = None + if profile: + import cProfile + + profiler = cProfile.Profile() + profiler.enable() + tic = time.perf_counter() + with TemporaryCudaStream([]) as s, enable_kernel_delay(): + stream = cast(CudaStream, s.handle) + i = itertools.count() + self.update_batch(stream) + while self.num_finished < 
self.num_requests: + DBG_PRINT and print( # type: ignore[arg-type] + f"Executing batch {next(i)} with size {len(self.batch)}" + ) + assert self.batch + if not skip_execution: + self.engine.execute(self.batch, stream) + self.update_batch(stream) + toc = time.perf_counter() + if profiler is not None: + profiler.disable() + profiler.print_stats(sort="cumtime") + profiler.dump_stats("profiler.prof") + if DBG_PRINT or PRINT_TIME: + print( + f"Time taken: {toc - tic} seconds (num_prompt_tokens: {self.acc_num_prompt_tokens}, " + f"num_decode_tokens: {self.acc_num_decode_tokens})" + ) + s.take_finish_event().synchronize() + + +class TestDisagg(TestKVCacheManagerV2): + @parameterized.expand([512]) + # @assert_no_ref_cycle + def test_disagg(self, prompt_len: int) -> None: + self.prepare(128 << 20, 128 << 20, 1 << 30, 36, 128, 0) + lora_task_id = None + prompt = [self.next_token() for _ in range(prompt_len)] + kv_cache = self.manager.create_kv_cache(lora_task_id, prompt) + assert kv_cache.num_committed_tokens == 0 + with TemporaryCudaStream([]) as stream: + success = kv_cache.resume(cast(CudaStream, stream.handle)) + assert success + success = kv_cache.resize(prompt_len, prompt_len) + assert success + + def transfer() -> None: + return None + + transfer() + kv_cache.commit(prompt) + kv_cache.close() + stream.take_finish_event().synchronize() + + +if __name__ == "__main__": + unittest.main()
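
The unit tests above double as the best usage reference for the new manager. For reviewers who prefer a single condensed view, the sketch below distills the request lifecycle exercised by run_naive/run_request: build a config, create a kv_cache against a prompt, resume it on a stream, grow its capacity, write, commit, and close. It is not part of the patch; the quotas, layer count, and buffer sizes are illustrative placeholders, and it assumes the fallback import path tensorrt_llm.runtime.kv_cache_manager_v2 that the tests use when the mypyc-compiled kv_cache_manager_v2 module is not installed.

# Condensed usage sketch distilled from test_kv_cache_manager_v2.py; values are illustrative.
from typing import cast

from tensorrt_llm.runtime.kv_cache_manager_v2 import (
    AttentionLayerConfig,
    BufferConfig,
    CudaStream,
    DataRole,
    GpuCacheTierConfig,
    HostCacheTierConfig,
    KVCacheManager,
    KVCacheManagerConfig,
    LayerId,
    TokenId,
)
from tensorrt_llm.runtime.kv_cache_manager_v2._utils import (
    TemporaryCudaStream,
    init_cuda_once,
    round_up,
    typed_range,
)

init_cuda_once()

# One key buffer and one value buffer per layer, as in _init_cfg() above.
buffers = [
    BufferConfig(role=DataRole("key"), size=8192),
    BufferConfig(role=DataRole("value"), size=8192),
]
cfg = KVCacheManagerConfig(
    tokens_per_block=32,
    vocab_size=4096,
    cache_tiers=[
        GpuCacheTierConfig(quota=256 << 20),   # primary GPU page pool
        HostCacheTierConfig(quota=256 << 20),  # host tier for offloaded pages
    ],
    layers=[
        AttentionLayerConfig(
            layer_id=layer_id,
            buffers=buffers,
            sliding_window_size=None,  # full attention on every layer in this sketch
            num_sink_tokens=None,
        )
        for layer_id in typed_range(LayerId(4))
    ],
)
manager = KVCacheManager(cfg)

prompt = [TokenId(t) for t in range(100)]
kv_cache = manager.create_kv_cache(None, prompt)  # lora_task_id=None; passing the prompt enables block reuse

with TemporaryCudaStream([]) as s:
    stream = cast(CudaStream, s.handle)
    ok = kv_cache.resume(stream)                                   # acquire GPU pages for this request
    assert ok
    ok = kv_cache.resize(round_up(len(prompt), cfg.tokens_per_block))  # grow capacity before prefill
    assert ok
    # ... launch attention kernels against the pages of kv_cache here ...
    reused = kv_cache.num_committed_tokens                         # may be > 0 when blocks were reused
    kv_cache.commit(prompt[reused:])                               # make the newly written blocks reusable
    s.take_finish_event().synchronize()

kv_cache.close()

Note that the sketch grows capacity through resize() rather than by assigning kv_cache.capacity, mirroring the mypyc workaround called out in run_request above.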
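
TestBatching.update_batch also encodes the intended back-pressure protocol: when resize reports that the GPU tier is out of pages, the most recently scheduled request is suspended (releasing its GPU pages) and is resumed and regrown later once capacity frees up. Below is a minimal sketch of that loop, again distilled from the test rather than prescribed by the patch; the container and helper names (active, suspended, needed_capacity) are illustrative, and the sketch operates directly on _KVCache objects instead of the test's Step tuples.

# Back-pressure sketch distilled from TestBatching.update_batch; helper names are illustrative.
from typing import Callable

from tensorrt_llm.runtime.kv_cache_manager_v2 import CudaStream, _KVCache


def shrink_batch(
    active: list[_KVCache],
    suspended: list[_KVCache],
    needed_capacity: Callable[[_KVCache], int],
) -> None:
    """Suspend requests from the back of the batch until the remaining ones fit on the GPU."""
    i = 0
    while i < len(active):
        kv_cache = active[i]
        # Mirror update_batch: while the i-th request cannot grow, evict the last one.
        while i < len(active) and not kv_cache.resize(needed_capacity(kv_cache), None):
            victim = active.pop()     # the most recently scheduled request goes first
            victim.suspend()          # releases its GPU pages; the request stays resumable
            suspended.append(victim)
        i += 1


def refill_batch(
    active: list[_KVCache],
    suspended: list[_KVCache],
    stream: CudaStream,
    needed_capacity: Callable[[_KVCache], int],
) -> None:
    """Bring suspended requests back once GPU pages become available again."""
    while suspended:
        kv_cache = suspended[-1]
        ok = kv_cache.resume(stream)  # re-acquire GPU pages
        ok = ok and kv_cache.resize(needed_capacity(kv_cache), None)
        if not ok:
            # Resumed but could not grow: park it again and retry on a later step.
            if kv_cache.status == _KVCache.Status.ACTIVE:
                kv_cache.suspend()
            break
        active.append(suspended.pop())

Suspending from the back of the batch keeps the oldest requests running, which is the same victim-selection order the test uses.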