[TRTLLM-7738][feat] Adding implementation of KVCacheManagerV2 (#10736)

Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>

KVCacheManagerV2 is a new Python-based implementation of the KV cache manager, featuring a cleaner API, better abstraction, and better code quality without the accumulated legacy of the original implementation.
Yao Yao 2026-01-24 17:48:39 +08:00 committed by GitHub
parent 9fcc93ea7b
commit 6f07fa81d7
54 changed files with 9442 additions and 17 deletions

.gitignore vendored
View File

@ -52,6 +52,8 @@ tensorrt_llm/pg_utils_bindings.*.so
tensorrt_llm/flash_mla/
tensorrt_llm/flash_mla_cpp_tllm.*.so
tensorrt_llm/flash_mla_cpp_tllm.pyi
tensorrt_llm/runtime/kv_cache_manager_v2/**/*.so
**/*__mypyc*.so
tensorrt_llm/scripts
*docs/cpp_docs*
*docs/source/_cpp_gen*

View File

@ -36,6 +36,8 @@ set(SRCS
kvCacheManager.cpp
kvCacheEventManager.cpp
kvCacheTransferManager.cpp
kvCacheManagerV2Utils.cpp
kvCacheManagerV2Utils.cu
llmRequest.cpp
logitsPostProcessor.cpp
loraBuffers.cpp

View File

@ -0,0 +1,163 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h"
#include "tensorrt_llm/common/logger.h"
#include <cassert>
#include <cstdio>
#include <cuda.h>
#include <fcntl.h>
#include <memory>
#include <unistd.h>
#include <vector>
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
template <typename Func>
bool loopedReadWrite(Func&& func, ssize_t size) noexcept
{
ssize_t count = 0;
while (count < size)
{
ssize_t bytes = func(count);
if (bytes <= 0)
{
if (errno == EINTR)
{
continue; // Retry on interrupt
}
TLLM_LOG_ERROR("Disk read/write failed: %s\n", strerror(errno));
return false;
}
count += bytes;
}
assert(count == size);
return true;
}
bool writeAll(int fd, ssize_t pos, void const* data, ssize_t size) noexcept
{
return loopedReadWrite([=](ssize_t finished)
{ return pwrite(fd, static_cast<std::byte const*>(data) + finished, size - finished, pos + finished); },
size);
}
bool readAll(int fd, ssize_t pos, void* data, ssize_t size) noexcept
{
return loopedReadWrite([=](ssize_t finished)
{ return pread(fd, static_cast<std::byte*>(data) + finished, size - finished, pos + finished); },
size);
}
template <typename DstAddr, typename SrcAddr>
struct UserData
{
std::vector<Task<DstAddr, SrcAddr>> tasks;
ssize_t numBytes;
};
CUDA_CB void hostFnDiskToDiskCopy(void* userData) noexcept
{
// @TODO: enable multi-threading with a thread pool
using Data = UserData<DiskAddress, DiskAddress>;
auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
std::vector<std::byte> buffer(data->numBytes);
bool success = true;
for (auto const& t : data->tasks)
{
success = success && readAll(t.src.fd, t.src.pos, buffer.data(), data->numBytes);
success = success && writeAll(t.dst.fd, t.dst.pos, buffer.data(), data->numBytes);
}
if (!success)
{
TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnDiskToDiskCopy failed.\n");
}
}
CUDA_CB void hostFnDiskToHostCopy(void* userData) noexcept
{
// @TODO: enable multi-threading with a thread pool
using Data = UserData<MemAddress, DiskAddress>;
auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
bool success = true;
for (auto const& t : data->tasks)
{
success = success && readAll(t.src.fd, t.src.pos, reinterpret_cast<void*>(t.dst), data->numBytes);
}
if (!success)
{
TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnDiskToHostCopy failed.\n");
}
}
CUDA_CB void hostFnHostToDiskCopy(void* userData) noexcept
{
// @TODO: enable multi-threading with a thread pool
using Data = UserData<DiskAddress, MemAddress>;
auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
bool success = true;
for (auto const& t : data->tasks)
{
success = success && writeAll(t.dst.fd, t.dst.pos, reinterpret_cast<void const*>(t.src), data->numBytes);
}
if (!success)
{
TLLM_LOG_ERROR("[kvCacheManagerV2Utils] hostFnHostToDiskCopy failed.\n");
}
}
CUDA_CB void hostFnHostToHostCopy(void* userData) noexcept
{
// @TODO: enable multi-threading with a thread pool
using Data = UserData<MemAddress, MemAddress>;
auto const data = std::unique_ptr<Data>(static_cast<Data*>(userData));
for (auto const& t : data->tasks)
{
memcpy(reinterpret_cast<void*>(t.dst), reinterpret_cast<void const*>(t.src), data->numBytes);
}
}
CUresult copyDiskToDisk(std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
using Data = UserData<DiskAddress, DiskAddress>;
auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
return cuLaunchHostFunc(stream, hostFnDiskToDiskCopy, data.release());
}
CUresult copyDiskToHost(std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
using Data = UserData<MemAddress, DiskAddress>;
auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
return cuLaunchHostFunc(stream, hostFnDiskToHostCopy, data.release());
}
CUresult copyHostToDisk(std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
using Data = UserData<DiskAddress, MemAddress>;
auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
return cuLaunchHostFunc(stream, hostFnHostToDiskCopy, data.release());
}
CUresult copyHostToHost(std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept
{
using Data = UserData<MemAddress, MemAddress>;
auto data = std::make_unique<Data>(Data{std::move(tasks), numBytes});
return cuLaunchHostFunc(stream, hostFnHostToHostCopy, data.release());
}
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
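
The loopedReadWrite helper above implements the standard POSIX pattern for pread/pwrite: resume after partial transfers and retry on EINTR. Purely as an illustration of that pattern (not part of this change), a Python equivalent of writeAll could look like the sketch below.

import os

def write_all(fd: int, pos: int, data: bytes) -> bool:
    """Keep issuing pwrite() until every byte lands, retrying on EINTR."""
    done = 0
    while done < len(data):
        try:
            n = os.pwrite(fd, data[done:], pos + done)
        except InterruptedError:
            # EINTR: interrupted by a signal, just retry.
            # (Modern Python already retries EINTR internally; kept to mirror the C++ logic.)
            continue
        if n <= 0:
            return False  # genuine I/O error; nothing sensible to retry
        done += n
    return True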

View File

@ -0,0 +1,182 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kvCacheManagerV2Utils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cuda_runtime.h>
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
using Grain = uint4;
constexpr uint32_t ctaSize = 128;
constexpr uint32_t nbBufs = 4;
constexpr uint32_t grainBytes = sizeof(Grain);
using MMTask = Task<MemAddress, MemAddress>;
__device__ __host__ inline uint32_t divUp(uint32_t a, uint32_t b)
{
return (a + b - 1) / b;
}
template <uint32_t N>
__global__ void batchedCopy(std::array<MMTask, N> const __grid_constant__ tasks, uint32_t nbBytes)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("griddepcontrol.launch_dependents;\n");
#endif
assert(nbBytes % sizeof(Grain) == 0);
__shared__ Grain data[nbBufs][ctaSize];
uint32_t const nbTasks = gridDim.y;
assert(nbTasks <= N);
auto const& task = tasks[blockIdx.y];
uint32_t const nbSplits = gridDim.x;
uint32_t const idxSplit = blockIdx.x;
uint32_t const tid = threadIdx.x;
constexpr uint32_t bytesPerIter = grainBytes * ctaSize;
uint32_t const totalIters = divUp(nbBytes, bytesPerIter);
uint32_t const maxItersPerCta = divUp(totalIters, nbSplits);
uint32_t const idxGrainBeg = ctaSize * maxItersPerCta * idxSplit + tid;
uint32_t const idxGrainEnd = std::min(idxGrainBeg + ctaSize * maxItersPerCta, nbBytes / grainBytes);
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("griddepcontrol.wait;\n");
#endif
for (uint32_t i = 0; i < maxItersPerCta + nbBufs; i++)
{
uint32_t const idxBuf = i % nbBufs;
if (i >= nbBufs)
{
uint32_t const stIter = i - nbBufs;
assert(idxBuf == (stIter % nbBufs));
Grain const& src = data[idxBuf][tid];
uint32_t const idxGrain = idxGrainBeg + ctaSize * stIter;
Grain& dst = reinterpret_cast<Grain*>(task.dst)[idxGrain];
asm volatile("cp.async.wait_group %0;\n" ::"n"(nbBufs - 1) : "memory");
if (idxGrain < idxGrainEnd)
{
dst = src;
}
}
uint32_t const ldIter = i;
Grain* const dst = &data[idxBuf][tid];
uint32_t const idxGrain = idxGrainBeg + ctaSize * ldIter;
Grain const* const src = &reinterpret_cast<Grain const*>(task.src)[idxGrain];
if (idxGrain < idxGrainEnd)
{
uint32_t const size = grainBytes;
asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)),
"l"(src), "n"(grainBytes), "r"(size)
: "memory");
}
asm volatile("cp.async.commit_group;\n" : : : "memory");
}
}
template <uint32_t N>
CUresult launchBatchedCopyImpl(
bool lowBandwidth, MMTask const* tasks, uint32_t nbTasks, uint32_t nbBytes, cudaStream_t stream)
{
TLLM_CHECK(nbTasks <= N);
TLLM_CHECK_WITH_INFO(
nbBytes % sizeof(Grain) == 0, "Not implemented case: nbBytes = %d must be a multiple of 16.", nbBytes);
std::array<MMTask, N> const* pTasks;
std::array<MMTask, N> tmp;
if (nbTasks < N)
{
std::copy_n(tasks, nbTasks, tmp.begin());
pTasks = &tmp;
}
else
{
pTasks = reinterpret_cast<std::array<MMTask, N> const*>(tasks);
}
uint32_t const nbSplits = lowBandwidth ? 1 : divUp(nbBytes, grainBytes * ctaSize * 2);
void* args[] = {(void*) pTasks, (void*) &nbBytes};
static CUkernel const kernel = []() -> CUkernel
{
cudaKernel_t kernel = nullptr;
TLLM_CUDA_CHECK(cudaGetKernel(&kernel, reinterpret_cast<void const*>(&batchedCopy<N>)));
return kernel;
}();
return common::CUDADriverWrapper::getInstance()->cuLaunchKernel(reinterpret_cast<CUfunction>(kernel), nbSplits,
nbTasks, 1, // gridDimX, gridDimY, gridDimZ
ctaSize, 1, 1, // blockDimX, blockDimY, blockDimZ
0, // sharedMemBytes
stream, args, nullptr);
}
// When bandwidth is low, e.g. when host memory is involved, we avoid splitting as fewer CTAs should be enough to
// saturate the bandwidth.
CUresult launchBatchedCopy(bool lowBandwidth, std::vector<MMTask> const& tasks, uint32_t nbBytes, cudaStream_t stream)
{
constexpr uint32_t maxN = 256;
uint32_t const nbWholeBatches = tasks.size() / maxN;
for (uint32_t i = 0; i < nbWholeBatches; i++)
{
CUresult const err = launchBatchedCopyImpl<maxN>(lowBandwidth, tasks.data() + maxN * i, maxN, nbBytes, stream);
if (err != CUDA_SUCCESS)
{
return err;
}
}
{
auto const* const pTasks = tasks.data() + maxN * nbWholeBatches;
auto const batchSize = tasks.size() % maxN;
if (batchSize == 0)
{
return CUDA_SUCCESS;
}
if (batchSize > maxN / 2)
{
return launchBatchedCopyImpl<maxN>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
}
if (batchSize > maxN / 4)
{
return launchBatchedCopyImpl<maxN / 2>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
}
if (batchSize > maxN / 8)
{
return launchBatchedCopyImpl<maxN / 4>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
}
return launchBatchedCopyImpl<maxN / 8>(lowBandwidth, pTasks, batchSize, nbBytes, stream);
}
}
CUresult copyHostToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
return launchBatchedCopy(true, tasks, numBytes, stream);
}
CUresult copyDeviceToHost(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
return launchBatchedCopy(true, tasks, numBytes, stream);
}
CUresult copyDeviceToDevice(std::vector<MMTask> const& tasks, ssize_t numBytes, CUstream stream) noexcept
{
return launchBatchedCopy(false, tasks, numBytes, stream);
}
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
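
launchBatchedCopy first issues full batches of maxN = 256 tasks, then dispatches the remaining tail to the smallest template instantiation (256, 128, 64, or 32 tasks) that still fits it, presumably to keep the by-value __grid_constant__ task array small. A quick Python sketch of that tail bucket selection, for illustration only:

MAX_N = 256

def pick_tail_bucket(batch_size: int) -> int:
    """Smallest of {32, 64, 128, 256} that can hold the remaining tasks."""
    assert 0 < batch_size <= MAX_N
    for n in (MAX_N // 8, MAX_N // 4, MAX_N // 2, MAX_N):
        if batch_size <= n:
            return n
    return MAX_N  # unreachable given the assert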

View File

@ -0,0 +1,51 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <cuda.h>
#include <vector>
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
struct DiskAddress
{
int fd;
ssize_t pos;
};
using MemAddress = std::uintptr_t;
template <typename DstAddr, typename SrcAddr>
struct Task
{
DstAddr dst;
SrcAddr src;
};
CUresult copyDiskToDisk(std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyDiskToHost(std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyHostToDisk(std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyHostToHost(std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyHostToDevice(
std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyDeviceToHost(
std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;
CUresult copyDeviceToDevice(
std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, CUstream stream) noexcept;
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2

View File

@ -10,6 +10,7 @@ set(SRCS
batch_manager/cacheTransceiver.cpp
batch_manager/kvCacheConnector.cpp
batch_manager/kvCacheManager.cpp
batch_manager/kvCacheManagerV2Utils.cpp
batch_manager/llmRequest.cpp
common/tllmExceptions.cpp
executor/bindings.cpp

View File

@ -0,0 +1,108 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kvCacheManagerV2Utils.h"
#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h"
#include <nanobind/nanobind.h>
#include <nanobind/stl/vector.h>
namespace nb = nanobind;
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
void KVCacheManagerV2UtilsBindings::initBindings(nb::module_& module)
{
// Bind DiskAddress struct
nb::class_<DiskAddress>(module, "DiskAddress")
.def(nb::init<int, ssize_t>(), nb::arg("fd"), nb::arg("pos"))
.def_rw("fd", &DiskAddress::fd)
.def_rw("pos", &DiskAddress::pos);
// Bind Task template instantiations
nb::class_<Task<DiskAddress, DiskAddress>>(module, "DiskToDiskTask")
.def(nb::init<DiskAddress, DiskAddress>(), nb::arg("dst"), nb::arg("src"))
.def_rw("dst", &Task<DiskAddress, DiskAddress>::dst)
.def_rw("src", &Task<DiskAddress, DiskAddress>::src);
nb::class_<Task<MemAddress, DiskAddress>>(module, "DiskToHostTask")
.def(nb::init<MemAddress, DiskAddress>(), nb::arg("dst"), nb::arg("src"))
.def_rw("dst", &Task<MemAddress, DiskAddress>::dst)
.def_rw("src", &Task<MemAddress, DiskAddress>::src);
nb::class_<Task<DiskAddress, MemAddress>>(module, "HostToDiskTask")
.def(nb::init<DiskAddress, MemAddress>(), nb::arg("dst"), nb::arg("src"))
.def_rw("dst", &Task<DiskAddress, MemAddress>::dst)
.def_rw("src", &Task<DiskAddress, MemAddress>::src);
nb::class_<Task<MemAddress, MemAddress>>(module, "MemToMemTask")
.def(nb::init<MemAddress, MemAddress>(), nb::arg("dst"), nb::arg("src"))
.def_rw("dst", &Task<MemAddress, MemAddress>::dst)
.def_rw("src", &Task<MemAddress, MemAddress>::src);
// Bind copy functions
module.def(
"copy_disk_to_disk",
[](std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDiskToDisk(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
"Copy data from disk to disk using CUDA host function");
module.def(
"copy_disk_to_host",
[](std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDiskToHost(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
"Copy data from disk to host using CUDA host function");
module.def(
"copy_host_to_disk",
[](std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyHostToDisk(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
"Copy data from host to disk using CUDA host function");
module.def(
"copy_host_to_host",
[](std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyHostToHost(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
"Copy data from host to host using CUDA host function");
module.def(
"copy_host_to_device",
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyHostToDevice(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
"Copy data from host to device using CUDA kernels");
module.def(
"copy_device_to_host",
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDeviceToHost(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
"Copy data from device to host using CUDA kernels");
module.def(
"copy_device_to_device",
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDeviceToDevice(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
nb::arg("tasks"), nb::arg("num_bytes"), nb::arg("stream"), nb::call_guard<nb::gil_scoped_release>(),
"Copy data from device to device using CUDA kernels");
}
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
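
Assuming these bindings land under the internal batch_manager submodule, a host-to-host copy could be driven from Python roughly as below. The import path is only inferred from the submodule names registered in the bindings and is not confirmed by this diff; the snippet also assumes a CUDA-capable environment with torch available.

import torch
from tensorrt_llm.bindings.internal.batch_manager import kv_cache_manager_v2_utils as kvu

num_bytes = 4096
src = torch.full((num_bytes,), 7, dtype=torch.uint8).pin_memory()
dst = torch.zeros(num_bytes, dtype=torch.uint8).pin_memory()

# One host-to-host task; addresses are plain integers (MemAddress).
tasks = [kvu.MemToMemTask(dst=dst.data_ptr(), src=src.data_ptr())]
stream = torch.cuda.current_stream().cuda_stream  # raw CUstream handle as an int

err = kvu.copy_host_to_host(tasks, num_bytes, stream)  # returns a CUresult code
assert err == 0
torch.cuda.current_stream().synchronize()  # the memcpy runs in a stream host callback
assert torch.equal(dst, src)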

View File

@ -0,0 +1,31 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nanobind/nanobind.h>
namespace nb = nanobind;
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
class KVCacheManagerV2UtilsBindings
{
public:
static void initBindings(nb::module_& module);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2

View File

@ -38,6 +38,7 @@
#include "tensorrt_llm/nanobind/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/nanobind/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/nanobind/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/nanobind/batch_manager/kvCacheManagerV2Utils.h"
#include "tensorrt_llm/nanobind/batch_manager/llmRequest.h"
#include "tensorrt_llm/nanobind/common/tllmExceptions.h"
#include "tensorrt_llm/nanobind/executor/bindings.h"
@ -131,6 +132,8 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
auto mInternalBatchManagerKvCacheV2Utils
= mInternalBatchManager.def_submodule("kv_cache_manager_v2_utils", "KV Cache Manager V2 Utils bindings");
auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings");
auto mExceptions = m.def_submodule("exceptions", "Exceptions internal bindings");
@ -502,6 +505,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager);
tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager);
tb::CacheTransceiverBindings::initBindings(mInternalBatchManager);
tb::kv_cache_manager_v2::KVCacheManagerV2UtilsBindings::initBindings(mInternalBatchManagerKvCacheV2Utils);
auto mInternalAlgorithms = mInternal.def_submodule("algorithms", "Algorithms internal bindings");
tpb::algorithms::initBindings(mInternalAlgorithms);

View File

@ -315,5 +315,24 @@ struct type_caster<torch::ScalarType>
throw std::runtime_error("from_cpp for torch::ScalarType is not implemented");
}
};
template <>
class type_caster<CUstream>
{
public:
NB_TYPE_CASTER(CUstream, const_name("int"));
bool from_python([[maybe_unused]] handle src, uint8_t flags, cleanup_list* cleanup)
{
value = reinterpret_cast<CUstream>(PyLong_AsVoidPtr(src.ptr()));
return true;
}
static handle from_cpp(CUstream const& src, rv_policy /* policy */, cleanup_list* /* cleanup */)
{
return PyLong_FromVoidPtr(src);
}
};
} // namespace detail
} // namespace NB_NAMESPACE

View File

@ -10,6 +10,7 @@ set(SRCS
batch_manager/cacheTransceiver.cpp
batch_manager/kvCacheConnector.cpp
batch_manager/kvCacheManager.cpp
batch_manager/kvCacheManagerV2Utils.cpp
batch_manager/llmRequest.cpp
executor/bindings.cpp
executor/executor.cpp

View File

@ -0,0 +1,111 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kvCacheManagerV2Utils.h"
#include "tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
namespace py = pybind11;
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
void KVCacheManagerV2UtilsBindings::initBindings(py::module_& module)
{
// Bind DiskAddress struct
py::class_<DiskAddress>(module, "DiskAddress")
.def(py::init<int, ssize_t>(), py::arg("fd"), py::arg("pos"))
.def_readwrite("fd", &DiskAddress::fd)
.def_readwrite("pos", &DiskAddress::pos);
// Bind Task template instantiations
py::class_<Task<DiskAddress, DiskAddress>>(module, "DiskToDiskTask")
.def(py::init<DiskAddress, DiskAddress>(), py::arg("dst"), py::arg("src"))
.def_readwrite("dst", &Task<DiskAddress, DiskAddress>::dst)
.def_readwrite("src", &Task<DiskAddress, DiskAddress>::src);
py::class_<Task<MemAddress, DiskAddress>>(module, "DiskToHostTask")
.def(py::init<MemAddress, DiskAddress>(), py::arg("dst"), py::arg("src"))
.def_readwrite("dst", &Task<MemAddress, DiskAddress>::dst)
.def_readwrite("src", &Task<MemAddress, DiskAddress>::src);
py::class_<Task<DiskAddress, MemAddress>>(module, "HostToDiskTask")
.def(py::init<DiskAddress, MemAddress>(), py::arg("dst"), py::arg("src"))
.def_readwrite("dst", &Task<DiskAddress, MemAddress>::dst)
.def_readwrite("src", &Task<DiskAddress, MemAddress>::src);
py::class_<Task<MemAddress, MemAddress>>(module, "MemToMemTask")
.def(py::init<MemAddress, MemAddress>(), py::arg("dst"), py::arg("src"))
.def_readwrite("dst", &Task<MemAddress, MemAddress>::dst)
.def_readwrite("src", &Task<MemAddress, MemAddress>::src);
// Bind copy functions
module.def(
"copy_disk_to_disk",
[](std::vector<Task<DiskAddress, DiskAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDiskToDisk(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from disk to disk using CUDA host function");
module.def(
"copy_disk_to_host",
[](std::vector<Task<MemAddress, DiskAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDiskToHost(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from disk to host using CUDA host function");
module.def(
"copy_host_to_disk",
[](std::vector<Task<DiskAddress, MemAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyHostToDisk(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from host to disk using CUDA host function");
module.def(
"copy_host_to_host",
[](std::vector<Task<MemAddress, MemAddress>> tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyHostToHost(std::move(tasks), numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from host to host using CUDA host function");
module.def(
"copy_host_to_device",
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyHostToDevice(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from host to device using CUDA kernels");
module.def(
"copy_device_to_host",
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDeviceToHost(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from device to host using CUDA kernels");
module.def(
"copy_device_to_device",
[](std::vector<Task<MemAddress, MemAddress>> const& tasks, ssize_t numBytes, uintptr_t stream) -> int
{ return copyDeviceToDevice(tasks, numBytes, reinterpret_cast<CUstream>(stream)); },
py::arg("tasks"), py::arg("num_bytes"), py::arg("stream"), py::call_guard<py::gil_scoped_release>(),
"Copy data from device to device using CUDA kernels");
}
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2

View File

@ -0,0 +1,29 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <pybind11/pybind11.h>
namespace tensorrt_llm::batch_manager::kv_cache_manager_v2
{
class KVCacheManagerV2UtilsBindings
{
public:
static void initBindings(pybind11::module_& module);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager_v2

View File

@ -33,6 +33,7 @@
#include "tensorrt_llm/pybind/batch_manager/cacheTransceiver.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheConnector.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheManagerV2Utils.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/common/tllmExceptions.h"
#include "tensorrt_llm/pybind/executor/bindings.h"
@ -124,6 +125,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
auto mInternalRuntime = mInternal.def_submodule("runtime", "Runtime internal bindings");
auto mInternalTesting = mInternal.def_submodule("testing", "Testing internal bindings");
auto mInternalBatchManager = mInternal.def_submodule("batch_manager", "Batch manager internal bindings");
auto mInternalBatchManagerKvCacheV2Utils
= mInternalBatchManager.def_submodule("kv_cache_manager_v2_utils", "KV Cache Manager V2 Utils bindings");
auto mInternalThop = mInternal.def_submodule("thop", "Torch op internal bindings");
auto mExceptions = m.def_submodule("exceptions", "Exceptions internal bindings");
@ -490,6 +493,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
tb::kv_cache_manager::KVCacheManagerBindings::initBindings(mInternalBatchManager);
tb::BasePeftCacheManagerBindings::initBindings(mInternalBatchManager);
tb::CacheTransceiverBindings::initBindings(mInternalBatchManager);
tb::kv_cache_manager_v2::KVCacheManagerV2UtilsBindings::initBindings(mInternalBatchManagerKvCacheV2Utils);
auto mInternalAlgorithms = mInternal.def_submodule("algorithms", "Algorithms internal bindings");
tpb::algorithms::initBindings(mInternalAlgorithms);

View File

@ -78,3 +78,6 @@ apache-tvm-ffi==0.1.6 # used for reduce nvidia-cutlass-dsl host overhead
torch-c-dlpack-ext==0.1.3 # used for reduce nvidia-cutlass-dsl host overhead, optional package for improved torch tensor calling perf
mistral-common==1.8.6
torchao>=0.14.1
cuda-core
llist
dynamic_path_manager

View File

@ -457,6 +457,43 @@ def generate_python_stubs_windows(binding_type: str, venv_python: Path,
(pkg_dir / stubgen).unlink()
def build_kv_cache_manager_v2(project_dir, venv_python, use_mypyc=False):
print("-- Building kv_cache_manager_v2...")
kv_cache_mgr_dir = project_dir / "tensorrt_llm/runtime/kv_cache_manager_v2"
runtime_dir = project_dir / "tensorrt_llm/runtime"
# Clean up any existing mypyc artifacts in runtime directory to prevent stale inclusion
# when switching from --mypyc to standard build
if not use_mypyc:
for so_file in runtime_dir.glob("*__mypyc*.so"):
print(f"Removing stale mypyc artifact: {so_file}")
so_file.unlink()
# Also clean up any .so files inside kv_cache_manager_v2
for so_file in kv_cache_mgr_dir.rglob("*.so"):
print(f"Removing stale artifact: {so_file}")
so_file.unlink()
# Build rawref
print("-- Building kv_cache_manager_v2 rawref extension...")
rawref_dir = kv_cache_mgr_dir / "rawref"
build_run(f'"{venv_python}" setup.py build_ext --inplace', cwd=rawref_dir)
if use_mypyc:
# Build mypyc
print("-- Building kv_cache_manager_v2 mypyc extensions...")
# setup_mypyc.py is in kv_cache_manager_v2 but executed from runtime dir
setup_mypyc = kv_cache_mgr_dir / "setup_mypyc.py"
build_run(f'"{venv_python}" "{setup_mypyc}" build_ext --inplace',
cwd=runtime_dir)
# Verify that the shared library was generated
if not list(runtime_dir.glob("*__mypyc*.so")):
raise RuntimeError(
"Failed to build kv_cache_manager_v2: no shared library generated."
)
def main(*,
build_type: str = "Release",
generator: str = "",
@ -487,7 +524,8 @@ def main(*,
skip_stubs: bool = False,
generate_fmha: bool = False,
no_venv: bool = False,
nvrtc_dynamic_linking: bool = False):
nvrtc_dynamic_linking: bool = False,
mypyc: bool = False):
if clean:
clean_wheel = True
@ -967,6 +1005,8 @@ def main(*,
nixl_root is not None or mooncake_root is not None,
binding_lib_file_name)
build_kv_cache_manager_v2(project_dir, venv_python, use_mypyc=mypyc)
if not skip_building_wheel:
if dist_dir is None:
dist_dir = project_dir / "build"
@ -988,6 +1028,15 @@ def main(*,
build_run(
f'\"{venv_python}\" -m build {project_dir} --skip-dependency-check {extra_wheel_build_args} --no-isolation --wheel --outdir "{dist_dir}"'
)
env = os.environ.copy()
if mypyc:
env["TRTLLM_ENABLE_MYPYC"] = "1"
else:
env["TRTLLM_ENABLE_MYPYC"] = "0"
build_run(
f'\"{venv_python}\" -m build {project_dir} --skip-dependency-check --no-isolation --wheel --outdir "{dist_dir}"',
env=env)
if install:
build_run(f"\"{sys.executable}\" -m pip install -e .[devel]")
@ -1135,6 +1184,9 @@ def add_arguments(parser: ArgumentParser):
"--nvrtc_dynamic_linking",
action="store_true",
help="Link against dynamic NVRTC libraries instead of static ones")
parser.add_argument("--mypyc",
action="store_true",
help="Compile kv_cache_manager_v2 with mypyc")
if __name__ == "__main__":

View File

@ -110,21 +110,43 @@ if on_windows:
]
else:
package_data = [
'bin/executorWorker', 'libs/libtensorrt_llm.so', 'libs/libth_common.so',
'bin/executorWorker',
'libs/libtensorrt_llm.so',
'libs/libth_common.so',
'libs/libnvinfer_plugin_tensorrt_llm.so',
'libs/libtensorrt_llm_ucx_wrapper.so', 'libs/libdecoder_attention_0.so',
'libs/libtensorrt_llm_nixl_wrapper.so', 'libs/nixl/**/*',
'libs/libtensorrt_llm_ucx_wrapper.so',
'libs/libdecoder_attention_0.so',
'libs/libtensorrt_llm_nixl_wrapper.so',
'libs/nixl/**/*',
'tensorrt_llm_transfer_agent_binding*.so',
'tensorrt_llm_transfer_agent_binding.pyi',
'libs/libtensorrt_llm_mooncake_wrapper.so', 'libs/ucx/**/*',
'libs/libpg_utils.so', 'libs/libdecoder_attention_1.so',
'libs/nvshmem/License.txt', 'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
'libs/nvshmem/nvshmem_transport_ibgda.so.103', 'bindings.*.so',
'deep_ep/LICENSE', 'deep_ep/*.py', 'deep_ep_cpp_tllm.*.so',
"include/**/*", 'deep_gemm/LICENSE', 'deep_gemm/include/**/*',
'deep_gemm/*.py', 'deep_gemm_cpp_tllm.*.so',
'scripts/install_tensorrt.sh', 'flash_mla/LICENSE', 'flash_mla/*.py',
'flash_mla_cpp_tllm.*.so'
'libs/libtensorrt_llm_mooncake_wrapper.so',
'libs/ucx/**/*',
'libs/libpg_utils.so',
'libs/libdecoder_attention_1.so',
'libs/nvshmem/License.txt',
'libs/nvshmem/nvshmem_bootstrap_uid.so.3',
'libs/nvshmem/nvshmem_transport_ibgda.so.103',
'bindings.*.so',
'deep_ep/LICENSE',
'deep_ep/*.py',
'deep_ep_cpp_tllm.*.so',
"include/**/*",
'deep_gemm/LICENSE',
'deep_gemm/include/**/*',
'deep_gemm/*.py',
'deep_gemm_cpp_tllm.*.so',
'scripts/install_tensorrt.sh',
'flash_mla/LICENSE',
'flash_mla/*.py',
'flash_mla_cpp_tllm.*.so',
'runtime/kv_cache_manager_v2/*.so',
'runtime/kv_cache_manager_v2/**/*.so',
'runtime/kv_cache_manager_v2/*.pyi',
'runtime/kv_cache_manager_v2/**/*.pyi',
'runtime/kv_cache_manager_v2/rawref/*.py',
'runtime/kv_cache_manager_v2/rawref/*.pyi',
'runtime/*__mypyc*.so',
]
package_data += [
@ -268,10 +290,16 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
# Skip .py files EXCEPT for generated C++ extension wrappers
# (deep_gemm, deep_ep, flash_mla Python files are generated during build)
if file.filename.endswith(".py"):
allowed_dirs = ("tensorrt_llm/deep_gemm/",
"tensorrt_llm/deep_ep/",
"tensorrt_llm/flash_mla/")
allowed_dirs = (
"tensorrt_llm/deep_gemm/", "tensorrt_llm/deep_ep/",
"tensorrt_llm/flash_mla/",
"tensorrt_llm/runtime/kv_cache_manager_v2/rawref/__init__.py"
)
if not any(file.filename.startswith(d) for d in allowed_dirs):
# Exclude all .py files in kv_cache_manager_v2 except rawref/__init__.py
if file.filename.startswith("tensorrt_llm/runtime/kv_cache_manager_v2/") and \
not file.filename.endswith("rawref/__init__.py"):
continue
continue
for filename_pattern in package_data:
@ -305,6 +333,38 @@ sanity_check()
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()
# We use find_packages with a custom exclude filter to handle the mypyc compiled modules.
# We want to exclude the .py source files for modules that are compiled to .so.
# We exclude the kv_cache_manager_v2 package entirely from the source list,
# but explicitly add back the rawref subpackage (which is not compiled by mypyc).
# The .so and .pyi files for kv_cache_manager_v2 are added via package_data.
enable_mypyc = os.getenv("TRTLLM_ENABLE_MYPYC", "0") == "1"
if enable_mypyc:
packages = find_packages(exclude=[
"tensorrt_llm.runtime.kv_cache_manager_v2",
"tensorrt_llm.runtime.kv_cache_manager_v2.*",
]) + ["tensorrt_llm.runtime.kv_cache_manager_v2.rawref"]
exclude_package_data = {
"tensorrt_llm": [
"runtime/kv_cache_manager_v2/*.py",
"runtime/kv_cache_manager_v2/**/*.py"
],
"tensorrt_llm.runtime.kv_cache_manager_v2": ["*.py", "**/*.py"],
}
else:
packages = find_packages()
exclude_package_data = {}
# Remove mypyc shared objects from package_data to avoid packaging stale files
package_data = [
p for p in package_data if p not in [
'runtime/kv_cache_manager_v2/*.so',
'runtime/kv_cache_manager_v2/**/*.so', 'runtime/*__mypyc*.so'
]
]
# Ensure rawref is included
package_data.append('runtime/kv_cache_manager_v2/rawref/*.so')
# https://setuptools.pypa.io/en/latest/references/keywords.html
setup(
name='tensorrt_llm',
@ -318,7 +378,8 @@ setup(
author="NVIDIA Corporation",
url="https://github.com/NVIDIA/TensorRT-LLM",
download_url="https://github.com/NVIDIA/TensorRT-LLM/tags",
packages=find_packages(),
packages=packages,
exclude_package_data=exclude_package_data,
# TODO Add windows support for python bindings.
classifiers=[
"Development Status :: 4 - Beta",

View File

@ -12,6 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dynamic_path_manager import DynamicPathManager
# Add current directory to sys.path so kv_cache_manager_v2 can be imported as top-level package.
# This is required because when kv_cache_manager_v2 is compiled with mypyc, it is compiled as
# a top-level package (to avoid complex build paths), but at runtime it is used as a submodule.
# The compiled extension might try to import its submodules using absolute imports based on its
# compiled name.
with DynamicPathManager(os.path.dirname(os.path.abspath(__file__)),
clear_cache=False):
import kv_cache_manager_v2
from .enc_dec_model_runner import EncDecModelRunner
from .generation import SamplingConfig # autoflake: skip
from .generation import (ChatGLMGenerationSession, GenerationSession,
@ -52,4 +65,5 @@ __all__ = [
'EncDecModelRunner',
'MultimodalModelRunner',
'PYTHON_BINDINGS',
'kv_cache_manager_v2',
]

View File

@ -0,0 +1,51 @@
# ##################################################################################################
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ##################################################################################################
# Makefile for kv_cache_manager_v2
# Replaces build_mypyc.sh and rawref/build.sh
PYTHON ?= python3
RUNTIME_DIR ?= ..
.PHONY: all mypyc rawref clean clean_mypyc clean_rawref
# Default target
all: rawref mypyc
# Build the rawref C extension
rawref:
cd rawref && $(PYTHON) setup.py build_ext --inplace
# Build the mypyc extension
# Must be run from the parent directory (runtime) as setup_mypyc.py expects
# kv_cache_manager_v2/ prefixes in module names.
mypyc:
cd $(RUNTIME_DIR) && $(PYTHON) kv_cache_manager_v2/setup_mypyc.py build_ext --inplace
# Clean everything
clean: clean_rawref clean_mypyc
# Clean rawref build artifacts
clean_rawref:
cd rawref && rm -rf build
find rawref -name "*.so" -type f -delete
# Clean mypyc build artifacts
# Cleans build/ directory in runtime (parent) and .so files in this directory
clean_mypyc:
rm -rf $(RUNTIME_DIR)/build
find . -name "*.so" -type f ! -path "./rawref/*" -delete

View File

@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import rawref
from ._block_radix_tree import gen_multi_modal_tokens
from ._common import (
NDEBUG,
CacheLevel,
CacheTier,
CudaStream,
LayerId,
MemAddress,
Priority,
TokenId,
TokenIdExt,
)
from ._config import (
AttentionLayerConfig,
BufferConfig,
CacheTierConfig,
DataRole,
DiskCacheTierConfig,
GpuCacheTierConfig,
HostCacheTierConfig,
KVCacheManagerConfig,
)
from ._core import BeamIndex, KVCacheManager, _KVCache
from ._life_cycle_registry import LayerGroupId, LifeCycleId
__all__ = [
"LifeCycleId",
"LayerGroupId",
"TokenId",
"TokenIdExt",
"KVCacheManager",
"_KVCache",
"BeamIndex",
"LayerId",
"Priority",
"CacheLevel",
"CacheTier",
"CudaStream",
"MemAddress",
"NDEBUG",
"KVCacheManagerConfig",
"AttentionLayerConfig",
"BufferConfig",
"DataRole",
"DiskCacheTierConfig",
"GpuCacheTierConfig",
"HostCacheTierConfig",
"CacheTierConfig",
"gen_multi_modal_tokens",
"rawref",
]

View File

@ -0,0 +1,213 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import array
import enum
from dataclasses import dataclass
from typing import (
Any,
Callable,
ClassVar,
Final,
Iterable,
Iterator,
NewType,
Protocol,
Sequence,
Type,
TypeAlias,
Union,
)
# From _common.py
NDEBUG: Final[int]
class CacheTier(enum.IntEnum):
GPU_MEM = 0
HOST_MEM = 1
DISK = 2
LifeCycleId = NewType("LifeCycleId", int)
LayerGroupId: TypeAlias = LifeCycleId
CacheLevel = NewType("CacheLevel", int)
TokenId = NewType("TokenId", int)
TokenIdExt = Union[TokenId, bytes]
LayerId = NewType("LayerId", int)
CudaStream = NewType("CudaStream", int)
BeamIndex = NewType("BeamIndex", int)
MemAddress = NewType("MemAddress", int)
Priority = NewType("Priority", int)
# From _config.py
DataRole = NewType("DataRole", str)
class CacheTierConfig(Protocol):
quota: int
@property
def tier(self) -> CacheTier: ...
def assert_valid(self) -> None: ...
@dataclass(slots=True)
class GpuCacheTierConfig:
quota: int
@property
def tier(self) -> CacheTier: ...
def assert_valid(self) -> None: ...
@dataclass(slots=True)
class HostCacheTierConfig:
quota: int
@property
def tier(self) -> CacheTier: ...
def assert_valid(self) -> None: ...
@dataclass(slots=True)
class DiskCacheTierConfig:
quota: int
path: str
@property
def tier(self) -> CacheTier: ...
def assert_valid(self) -> None: ...
@dataclass(slots=True)
class BufferConfig:
role: DataRole
size: int
@dataclass(slots=True)
class HelixConfig:
helix_group_size: int
helix_gpu_rank: int
helix_shard_size: int
shared_comm_port: int
@dataclass(slots=True)
class AttentionLayerConfig:
layer_id: LayerId
buffers: list[BufferConfig]
sliding_window_size: int | None = None
num_sink_tokens: int | None = None
@property
def window_size(self) -> int | None: ...
@dataclass(slots=True)
class KVCacheManagerConfig:
tokens_per_block: int
vocab_size: int
cache_tiers: list[CacheTierConfig]
layers: list[AttentionLayerConfig]
max_util_for_resume: float = ...
helix_config: HelixConfig | None = None
# From _block_radix_tree.py
def gen_multi_modal_tokens(
id_offset: int, multi_modal_data_digest: bytes, num_tokens: int
) -> list[TokenIdExt]: ...
# From _core/_kv_cache.py
class _Status(enum.Enum):
ACTIVE = enum.auto()
SUSPENDED = enum.auto()
CLOSED = enum.auto()
IndexSeq = array.array[int] | memoryview[int]
class _KVCache:
Status: ClassVar[Type[_Status]]
id: Any
def __init__(
self,
manager: "KVCacheManager",
lora_task_id: int | None,
input_tokens: Sequence[TokenIdExt] | None,
id: Any,
custom_priority_callback: Callable[[int, Any], Priority],
) -> None: ...
def set_page_index_buf(
self, beam_idx: BeamIndex, layer_group_id: LayerGroupId, buf: memoryview | None
) -> None: ...
@property
def manager(self) -> "KVCacheManager": ...
@property
def cuda_stream(self) -> CudaStream: ...
@cuda_stream.setter
def cuda_stream(self, cuda_stream: CudaStream) -> None: ...
@property
def finish_event(self) -> Any: ...
@property
def num_blocks(self) -> int: ...
def close(self) -> None: ...
@property
def beam_width(self) -> BeamIndex: ...
@beam_width.setter
def beam_width(self, beam_width: BeamIndex) -> None: ...
def get_page_indices(self, layer_group_id: int, beam_id: BeamIndex = ...) -> IndexSeq: ...
def get_all_page_indices(
self, beam_id: BeamIndex, buf_ids: Iterable[tuple[LayerId, DataRole]]
) -> Iterator[IndexSeq]: ...
def resize(self, capacity: int | None, history_length: int | None = None) -> bool: ...
@property
def capacity(self) -> int: ...
@capacity.setter
def capacity(self, capacity: int) -> None: ...
@property
def history_length(self) -> int: ...
@history_length.setter
def history_length(self, history_length: int) -> None: ...
def commit(
self,
accepted_input_tokens: Sequence[TokenIdExt],
beam_search_indices: Sequence[int] | None = None,
) -> None: ...
@property
def num_committed_tokens(self) -> int: ...
def stop_committing(self) -> None: ...
def suspend(self) -> None: ...
def resume(self, cuda_stream: CudaStream | None = None) -> bool: ...
@property
def status(self) -> _Status: ...
@property
def is_active(self) -> bool: ...
@property
def tokens_per_block(self) -> int: ...
# From _core/_kv_cache_manager.py
class KVCacheManager:
def __init__(self, config: KVCacheManagerConfig) -> None: ...
def clear_reusable_blocks(self) -> None: ...
def get_mem_pool_base_address(self, layer_id: LayerId, data_role: DataRole) -> MemAddress: ...
def get_page_stride(self, layer_id: LayerId, data_role: DataRole) -> int: ...
def get_page_index_upper_bound(self, layer_id: LayerId, data_role: DataRole) -> int: ...
def create_kv_cache(
self,
lora_task_id: int | None = None,
input_tokens: Sequence[TokenIdExt] | None = None,
id: Any = None,
custom_priority_callback: Callable[[int, Any], Priority] = ...,
) -> _KVCache: ...
def resize(self, cache_level: CacheLevel, quota: int, best_efforts: bool = False) -> bool: ...
def get_quota(self, cache_level: CacheLevel) -> int: ...
@property
def cache_tier_list(self) -> tuple[CacheTier, ...]: ...
@property
def tokens_per_block(self) -> int: ...
@property
def allow_seq_rebasing(self) -> bool: ...
@property
def enable_partial_match(self) -> bool: ...
def get_layer_group_id(self, layer_id: LayerId) -> int: ...
@property
def layer_grouping(self) -> tuple[tuple[LayerId, ...], ...]: ...
def clamp_max_seq_len_for_mem(self, batch_size: int, model_max_seq_len: int) -> int: ...

View File

@ -0,0 +1,437 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
from typing import TYPE_CHECKING, Iterator, Sequence, TypeVar, cast
from . import rawref
from ._common import NDEBUG, BlockOrdinal, PageStatus, TokenId, TokenIdExt
from ._life_cycle_registry import LifeCycle, LifeCycleId, LifeCycleRegistry
from ._utils import TypedIndexList, chunked, filled_list, unwrap_rawref
if TYPE_CHECKING:
from ._page import CommittedPage
BlockKey = bytes
# id_offset is usually vocab_size
def gen_multi_modal_tokens(
id_offset: int, multi_modal_data_digest: bytes, num_tokens: int
) -> list[TokenIdExt]:
assert num_tokens > 0
# Alternatively, we could also use (multi_modal_data_digest + i.to_bytes(8, 'little')) or its hash
# digest as token id.
# The implementation below is faster and still correct, because KV cache reuse of a token
# requires that all previous tokens also match, so only the first multi-modal token id needs
# to be unique.
return [
multi_modal_data_digest if i == 0 else TokenId(id_offset + i) for i in range(num_tokens)
]
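# Example (illustrative): with vocab_size = 32000 as id_offset and an image that occupies
# 4 token positions, gen_multi_modal_tokens(32000, digest, 4) yields
# [digest, 32001, 32002, 32003]; the digest makes the first position unique per image,
# and the later ids only need to be deterministic because reuse already requires all
# preceding tokens to match.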
class Hasher:
__slots__ = "_hasher"
_hasher: "hashlib._Hash"
def __init__(self, data: int | bytes | None | Sequence[int | bytes] = None) -> None:
self._hasher = hashlib.sha256()
if data is not None:
self.update(data)
# This function is perf-critical. Expect compromised code quality.
def update(self, data: int | bytes | Sequence[int | bytes]) -> "Hasher":
if type(data) is int:
assert NDEBUG or (data >= 0 and data < (1 << 64))
self._hasher.update(data.to_bytes(8, "little"))
elif type(data) is bytes:
self._hasher.update(data)
else:
assert isinstance(data, Sequence)
for item in data:
assert (
NDEBUG or (type(item) is int and (0 <= item < (1 << 64))) or type(item) is bytes
)
self._hasher.update(item.to_bytes(8, "little") if isinstance(item, int) else item)
return self
@property
def digest(self) -> bytes:
return self._hasher.digest()
TokenBlock = list[TokenIdExt]
def sequence_to_blockchain_keys(
tokens_per_block: int, lora_task_id: int | None, tokens: Sequence[TokenIdExt]
) -> Iterator[tuple[TokenBlock, BlockKey]]:
digest = Hasher(lora_task_id).digest
yield [], digest
for token_block in chunked(tokens, tokens_per_block):
digest = Hasher(digest).update(token_block).digest
yield token_block, digest
Child = TypeVar("Child", bound="Block | RootBlock")
Children = dict[BlockKey, Child]
def get_tree(block: "RootBlock | Block") -> "BlockRadixTree":
node = block
while not isinstance(node, BlockRadixTree):
node = node.prev
return node
def remove_subtree(root: "RootBlock | Block") -> list[rawref.ref["CommittedPage"]]:
# taking O(1) space
# remove leaf blocks one by one, in post-order
ret: list[rawref.ref["CommittedPage"]] = []
block: "RootBlock | Block" = root
while True:
if block.next:
block = next(iter(block.next.values()))
else:
if isinstance(block, Block):
ret.extend(p for p in block.storage if p is not None)
block.storage = filled_list(None, block.num_life_cycles)
assert isinstance(block, RootBlock) or all(page is None for page in block.storage), (
"Storage is not cleared, yet"
)
if block._prev() is None:
assert block is root
break
prev_block: Block | RootBlock | BlockRadixTree = block.prev
# Because Block.__del__() may remove RootBlock from BlockRadixTree, we need to check here.
# It may not be in prev_block.next when block is RootBlock.
if block.key in prev_block.next:
prev_block.next.pop(block.key)
if block is root:
break
assert not isinstance(prev_block, BlockRadixTree)
block = prev_block
return ret
def traverse_post_order(root: "Block") -> Iterator["Block"]:
"post-order traversal of the subtree rooted at root"
stack: list[Iterator[Block]] = []
block: Block | None = root
while True:
assert block is not None
if block.next:
child_iter = iter(block.next.values())
stack.append(child_iter)
block = next(child_iter)
else:
yield (last_yielded := block)
while stack and (block := next(stack[-1], None)) is None:
yield (last_yielded := cast(Block, last_yielded.prev))
stack.pop()
if not stack:
break
def find_best_partial_match_in_next_nodes(
block: "Block | RootBlock", tokens: TokenBlock
) -> tuple["Block | None", int]:
"""
Among all child nodes (block.next), finds the one whose tokens have the longest leading match with the given tokens.
Returns a tuple of (best_block, num_matched_tokens).
If no child matches any tokens, returns (None, 0).
"""
if len(block.next) >= 32:
# TODO: build a database to accelerate partial matching. TRTLLM-7784
# For now, it might be too slow to iterate over all children, so let's just skip.
return None, 0
best_block = None
best_match_len = 0
for b in block.next.values():
match_len = b._partial_match_this_node(tokens)
if match_len > best_match_len:
best_match_len = match_len
best_block = b
return best_block, best_match_len
class DuplicateKeyError(Exception):
"Another block with the same key already exists"
key: BlockKey
def __init__(self, key: BlockKey) -> None:
super().__init__(f"Block with key {key.hex()} already exists")
self.key = key
class UselessBlockError(Exception):
block: "Block"
def __init__(self, block: "Block") -> None:
super().__init__(
f"Block is useless because all its tokens are covered by another block with key = {block.key.hex()}"
)
self.block = block
def _add_or_get_existing(
parent: "RootBlock | Block", tokens: Sequence[TokenIdExt]
) -> "Block | None":
try:
return Block(tokens, parent)
except DuplicateKeyError as e:
return parent.next[e.key]
except UselessBlockError:
return None
class RootBlock:
__slots__ = ("_prev", "key", "next", "lora_task_id", "__rawref__")
key: BlockKey
lora_task_id: int | None
_prev: rawref.ref["BlockRadixTree"]
next: Children["Block"]
__rawref__: rawref.ref["RootBlock"]
def __init__(self, lora_task_id: int | None, prev: "BlockRadixTree") -> None:
self.key = self.make_key(lora_task_id)
assert self.key not in prev.next, "Root block already exists"
self.lora_task_id = lora_task_id
self._prev = rawref.ref(prev)
self.next = {}
self.__rawref__ = rawref.NULL
prev.next[self.key] = self
def __del__(self) -> None:
self.__rawref__.invalidate()
@property
def ordinal(self) -> BlockOrdinal:
return BlockOrdinal(-1)
@property
def prev(self) -> "BlockRadixTree":
return unwrap_rawref(self._prev)
@property
def num_life_cycles(self) -> LifeCycleId:
return self.prev.num_life_cycles
@property
def tokens_per_block(self) -> int:
return self.prev.tokens_per_block
@staticmethod
def make_key(lora_task_id: int | None) -> BlockKey:
return Hasher(lora_task_id).digest
class Block:
"""
A block of tokens. Manages data for all layers.
"""
__slots__ = ("key", "tokens", "ordinal", "_prev", "next", "storage", "__rawref__")
key: BlockKey
tokens: Sequence[TokenIdExt]
ordinal: BlockOrdinal
_prev: rawref.ref["Block | RootBlock"]
next: Children["Block"]
__rawref__: rawref.ref["Block"]
# indexed with LifeCycleId
storage: TypedIndexList[LifeCycleId, rawref.ref["CommittedPage"] | None]
@staticmethod
def make_key(prev_key: BlockKey, tokens: Sequence[TokenIdExt]) -> BlockKey:
return Hasher(prev_key).update(tokens).digest
def __init__(self, tokens: Sequence[TokenIdExt], prev: "Block | RootBlock") -> None:
assert prev.tokens_per_block == prev.prev.tokens_per_block, "prev must be a full block"
self.key = self.make_key(prev.key, tokens)
self.tokens = tokens
self.ordinal = BlockOrdinal(prev.ordinal + 1)
self._prev = rawref.ref(prev)
self.next = {}
self.storage = filled_list(None, prev.num_life_cycles)
self.__rawref__ = rawref.NULL
# a Block is useless if all its tokens are covered by a sibling block. Raise UselessBlockError if so.
if self.key in prev.next:
raise UselessBlockError(prev.next[self.key])
if len(tokens) < self.tokens_per_block:
# @TODO: when we have the database for find_best_partial_match_in_next_nodes, we may use
# that for faster check.
for b in prev.next.values():
if b.tokens[: len(tokens)] == tokens:
raise UselessBlockError(b)
# If there are sibling blocks fully covered by this block, remove them.
to_remove = []
for k, b in prev.next.items():
if len(b.tokens) < len(tokens) and tokens[: len(b.tokens)] == b.tokens:
assert NDEBUG or (not b.is_full and b is not self and b.key == k and not b.next)
to_remove.append(k)
for k in to_remove:
b = prev.next.pop(k)
assert b.is_orphan # _KVCache may still hold it.
# prev.next keeps a strong ref to this Block, so there is no need to remove self from prev.next in __del__().
prev.next[self.key] = self
def __del__(self) -> None:
for ref in self.storage:
if ref is not None and ref() is not None:
page = unwrap_rawref(ref)
if page.status == PageStatus.DROPPABLE:
if page.scheduled_for_eviction:
page.manager.exclude_from_eviction(page)
if self._prev() is not None and isinstance(self.prev, RootBlock) and not self.prev.next:
self.prev.prev.next.pop(self.prev.key)
self.__rawref__.invalidate()
def _partial_match_this_node(self, tokens: TokenBlock) -> int:
"""
Returns the number of leading tokens that match between the given tokens and this block's tokens.
"""
for i, (a, b) in enumerate(zip(tokens, self.tokens)):
if a != b:
return i
return min(len(tokens), len(self.tokens))
@property
def num_life_cycles(self) -> LifeCycleId:
return LifeCycleId(len(self.storage))
@property
def prev(self) -> "Block | RootBlock":
return unwrap_rawref(self._prev)
def unset_page(self, lc_idx: LifeCycleId, lc: LifeCycle) -> None:
if self.storage[lc_idx] is None:
return
ordinal = self.ordinal
self.storage[lc_idx] = None
if lc.window_size is None or ordinal < lc.num_sink_blocks:
pages = remove_subtree(self)
for r in pages:
if r() is not None:
page = unwrap_rawref(r)
assert page.status == PageStatus.DROPPABLE
if page.scheduled_for_eviction:
page.manager.exclude_from_eviction(page)
# It's possible to implement more sophisticated logic to remove useless blocks for SWA, e.g.
# check whether the number of consecutive available blocks is sufficient for window_size. (TRTLLM-8802)
# But for simplicity, we leave it as-is for now.
curr = self
while (
(isinstance(curr, Block) and curr.storage[lc_idx] is None)
and not curr.next
and curr._prev() is not None
):
if curr.key in curr.prev.next:
curr.prev.next.pop(curr.key)
curr = curr.prev
@property
def tokens_per_block(self) -> int:
# we assume non-leaf blocks are always full.
prev = self.prev
return prev.tokens_per_block if isinstance(prev, RootBlock) else len(prev.tokens)
@property
def is_full(self) -> bool:
return len(self.tokens) == self.tokens_per_block
@property
def is_orphan(self) -> bool:
return self.key not in self.prev.next or self.prev.next[self.key] is not self
class BlockRadixTree:
__slots__ = ("_life_cycles", "_tokens_per_block", "next", "__rawref__")
_life_cycles: LifeCycleRegistry
_tokens_per_block: int
next: Children[RootBlock]
__rawref__: rawref.ref["BlockRadixTree"]
def __init__(self, life_cycles: LifeCycleRegistry, tokens_per_block: int) -> None:
self._life_cycles = life_cycles
self._tokens_per_block = tokens_per_block
self.next = {}
self.__rawref__ = rawref.NULL
def __del__(self) -> None:
self.__rawref__.invalidate()
def add_or_get_existing(self, lora_task_id: int | None) -> RootBlock:
key = RootBlock.make_key(lora_task_id)
if key in self.next:
return self.next[key]
return RootBlock(lora_task_id, self)
@property
def tokens_per_block(self) -> int:
return self._tokens_per_block
@property
def life_cycles(self) -> LifeCycleRegistry:
return self._life_cycles
@property
def num_life_cycles(self) -> LifeCycleId:
return self.life_cycles.size
def clear(self) -> list[rawref.ref["CommittedPage"]]:
# taking O(1) space
# remove leaf blocks one by one, in post-order
ret: list[rawref.ref["CommittedPage"]] = []
while self.next:
block = next(iter(self.next.values()))
ret.extend(remove_subtree(block))
assert not self.next
return ret
# yields tuples of (block, num_matched_tokens). num_matched_tokens equals tokens_per_block for
# every yielded block except possibly the last one.
def match(
self,
lora_task_id: int | None,
tokens: Sequence[TokenIdExt],
enable_partial_match: bool = False,
) -> Iterator[tuple[Block, int]]:
block: Block | RootBlock | BlockRadixTree = self
mismatched_token_block: TokenBlock = []
for token_block, key in sequence_to_blockchain_keys(
self._tokens_per_block, lora_task_id, tokens
):
if key in block.next:
block = block.next[key]
if token_block:
assert isinstance(block, Block)
yield block, len(token_block)
else:
mismatched_token_block = token_block
break
if mismatched_token_block and enable_partial_match:
partial_block, match_len = find_best_partial_match_in_next_nodes(
cast(Block | RootBlock, block), mismatched_token_block
)
if partial_block is not None:
block = partial_block
yield block, match_len
def _check_sanity(self) -> bool:
raise NotImplementedError(
"[KVCacheManager] Check if there are any unusable blocks that should have been removed."
)
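# Illustrative-only sketch of the matching API; it is not called anywhere in this change.
# `life_cycles` is assumed to be a LifeCycleRegistry built from the same KVCacheManagerConfig
# that supplies `tokens_per_block`.
def _example_prefix_match(
    life_cycles: "LifeCycleRegistry", tokens_per_block: int, tokens: "Sequence[TokenIdExt]"
) -> int:
    tree = BlockRadixTree(life_cycles, tokens_per_block)
    tree.add_or_get_existing(lora_task_id=None)  # make sure a root exists for lora_task_id=None
    # Each yielded block covers a full block of tokens, except possibly the last one when partial
    # matching is enabled. The sum is the number of prefix tokens that can reuse cached blocks.
    return sum(n for _block, n in tree.match(None, tokens, enable_partial_match=True))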

View File

@ -0,0 +1,85 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import os
from dataclasses import dataclass
from typing import Final, NewType
NDEBUG: Final[int] = int(os.environ.get("TLLM_KV_CACHE_MANAGER_V2_DEBUG", "0")) == 0
class PageStatus(enum.IntEnum):
LOCKED = 0 # Required in GPU. Eviction/dropping not allowed
HELD = 1 # Allow eviction but not dropping
DROPPABLE = 2 # Allow eviction and dropping
# Can extend to more tiers in the future, e.g. object storage like AWS S3.
class CacheTier(enum.IntEnum):
GPU_MEM = 0
HOST_MEM = 1
DISK = 2
CacheLevel = NewType("CacheLevel", int)
GPU_LEVEL: Final[CacheLevel] = CacheLevel(0)
# Normal token id that falls in the tokenizer vocabulary.
TokenId = NewType("TokenId", int)
# For multi-modal tokens, we can handle it in either of the following ways:
# 1. Hash combine image digest and local_token_id, then use digest for every multi-modal token.
# 2. Use digest only for the first multi-modal token, and use int(vocab_size + local_token_id) for the rest.
# 3. Hash the multi-modal token embedding data and use the digest as TokenIdExt for every multi-modal token.
# If we do this, we can't skip the encoder.
TokenIdExt = TokenId | bytes
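# Illustrative-only sketch of option 1 above: combine the image digest with the local token id so
# every multi-modal token gets a stable bytes-typed TokenIdExt. The helper name and the choice of
# sha256 are assumptions, not an API defined in this change.
def _example_multi_modal_token_id(image_digest: bytes, local_token_id: int) -> TokenIdExt:
    import hashlib

    h = hashlib.sha256(image_digest)
    h.update(local_token_id.to_bytes(8, "little"))
    return h.digest()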
BlockOrdinal = NewType("BlockOrdinal", int)
BlockOrdinalT = type(BlockOrdinal(0))
LayerId = NewType("LayerId", int)
CudaStream = NewType("CudaStream", int)
BeamIndex = NewType("BeamIndex", int)
UserId = NewType("UserId", int)
MemAddress = NewType("MemAddress", int)
FileDescriptor = NewType("FileDescriptor", int)
BAD_FILE_DESCRIPTOR: Final[FileDescriptor] = FileDescriptor(-1)
PageIndex = NewType("PageIndex", int)
BAD_PAGE_INDEX: Final[PageIndex] = PageIndex(-1)
@dataclass(slots=True, frozen=True)
class DiskAddress:
fd: FileDescriptor
pos: int
Address = MemAddress | DiskAddress
SlidingWindowSize = int | None
Priority = NewType("Priority", int)
PRIORITY_MIN: Final[Priority] = Priority(0)
PRIORITY_MAX: Final[Priority] = Priority(100)
PRIORITY_DEFAULT: Final[Priority] = Priority(35)

View File

@ -0,0 +1,148 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Currently, our nvfp4 kernels require that KV data and its corresponding KV block scale use the same
# block index, but different base address.
# As the ratio between KV data size and KV block scale size is fixed, we can simply use a pool with
# smaller block size and the same number of blocks for block scale.
import os
from dataclasses import dataclass, field
from typing import NewType, Protocol
from ._common import CacheTier, LayerId
# The data role of a buffer inside one layer.
# Must be unique for each buffer inside a layer.
# Examples: "key", "value", "key_block_quant", "value_block_quant".
DataRole = NewType("DataRole", str)
class CacheTierConfig(Protocol):
"""Protocol for cache tier configuration."""
quota: int # in bytes
@property
def tier(self) -> CacheTier: ...
def assert_valid(self) -> None: ...
@dataclass(slots=True)
class GpuCacheTierConfig:
quota: int # in bytes
@property
def tier(self) -> CacheTier:
return CacheTier.GPU_MEM
def assert_valid(self) -> None:
assert self.quota > 0, "Quota must be positive"
@dataclass(slots=True)
class HostCacheTierConfig:
quota: int # in bytes
@property
def tier(self) -> CacheTier:
return CacheTier.HOST_MEM
def assert_valid(self) -> None:
assert self.quota > 0, "Quota must be positive"
@dataclass(slots=True)
class DiskCacheTierConfig:
quota: int # in bytes
path: str # a folder where we will store data as files
@property
def tier(self) -> CacheTier:
return CacheTier.DISK
def assert_valid(self) -> None:
assert self.quota > 0, "Quota must be positive"
assert os.path.isdir(self.path), (
f"Disk path {self.path} does not exist or is not a directory"
)
@dataclass(slots=True)
class BufferConfig:
role: DataRole
size: int
@dataclass(slots=True)
class AttentionLayerConfig:
layer_id: LayerId
# Each page can have multiple sub-pages, e.g. separate K and V data, block quantization scales for K and/or V, etc.
# The KV cache manager automatically groups sub-pages of the same size and redirects sub-pages of
# different sizes to different memory pools.
# BufferConfig.role must be unique within a layer.
buffers: list[BufferConfig]
# Note that we use None to represent "no sliding window". Sink tokens are excluded.
sliding_window_size: int | None = None
num_sink_tokens: int | None = None
@property
def window_size(self) -> int | None:
return self.sliding_window_size
def __post_init__(self) -> None:
assert len(set(buffer.role for buffer in self.buffers)) == len(self.buffers), (
"duplicate buffer role"
)
@dataclass(slots=True)
class HelixConfig:
helix_group_size: int
helix_gpu_rank: int
# number of tokens in one helix shard
helix_shard_size: int
# must be the same for all ranks in the same helix group and different for different helix groups.
shared_comm_port: int
@dataclass(slots=True)
class KVCacheManagerConfig:
"""
Configuration for the KV cache manager.
"""
tokens_per_block: int
# if you have p-tuning tokens, include them. Only needed for multi-modal.
vocab_size: int
# cache tiers are sorted from warm to cold. The first one must be GPU memory.
cache_tiers: list[CacheTierConfig]
# AttentionLayerConfig.layer_id should not duplicate
layers: list[AttentionLayerConfig]
# When memory utilization is above this threshold, KV cache resuming will fail. This helps reserve
# some memory for KV cache growth and avoids frequent suspend/resume with dynamic batch sizes.
max_util_for_resume: float = field(default=0.9)
# not supported yet
helix_config: HelixConfig | None = field(default=None)
def __post_init__(self) -> None:
assert self.cache_tiers and self.cache_tiers[0].tier == CacheTier.GPU_MEM
assert len(set(layer.layer_id for layer in self.layers)) == len(self.layers), (
"duplicate layer id"
)
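# Illustrative-only sketch of building a config for a toy two-layer model with one sliding-window
# layer. The quotas, buffer sizes and window/sink sizes are made-up example numbers, not
# recommended values.
def _example_config() -> KVCacheManagerConfig:
    kv_buffers = [
        BufferConfig(role=DataRole("key"), size=2 << 20),
        BufferConfig(role=DataRole("value"), size=2 << 20),
    ]
    return KVCacheManagerConfig(
        tokens_per_block=32,
        vocab_size=32000,
        cache_tiers=[
            GpuCacheTierConfig(quota=8 << 30),  # the warmest tier must be GPU memory
            HostCacheTierConfig(quota=32 << 30),
        ],
        layers=[
            AttentionLayerConfig(layer_id=LayerId(0), buffers=list(kv_buffers)),
            AttentionLayerConfig(
                layer_id=LayerId(1),
                buffers=list(kv_buffers),
                sliding_window_size=4096,
                num_sink_tokens=4,
            ),
        ],
    )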

View File

@ -0,0 +1,374 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import sys
import threading
from _thread import LockType
from collections.abc import Callable, Iterator
from dataclasses import dataclass
# avoid importing the whole tensorrt_llm module, which takes time during debugging.
from importlib.util import find_spec
from pathlib import Path
from typing import ClassVar, NamedTuple, Sequence, cast
import cuda.bindings.driver as drv
from dynamic_path_manager import DynamicPathManager
from ._common import Address, CacheTier, CudaStream, MemAddress
from ._utils import CachedCudaEvent, HomoTuple, HostMem, _unwrap, div_up, stream_wait_events
if "tensorrt_llm" in sys.modules:
from tensorrt_llm.bindings.internal.batch_manager.kv_cache_manager_v2_utils import ( # noqa # type: ignore
DiskAddress,
DiskToDiskTask,
DiskToHostTask,
HostToDiskTask,
MemToMemTask,
copy_device_to_device,
copy_device_to_host,
copy_disk_to_disk,
copy_disk_to_host,
copy_host_to_device,
copy_host_to_disk,
copy_host_to_host,
)
else:
# fast path for dev, avoids importing the whole tensorrt_llm module
spec = find_spec("kv_cache_manager_v2")
assert spec is not None and spec.origin is not None
with DynamicPathManager(str(Path(spec.origin).parent.parent.parent), clear_cache=False):
from bindings.internal.batch_manager.kv_cache_manager_v2_utils import ( # noqa
DiskAddress,
DiskToDiskTask,
DiskToHostTask,
HostToDiskTask,
MemToMemTask,
copy_device_to_device,
copy_device_to_host,
copy_disk_to_disk,
copy_disk_to_host,
copy_host_to_device,
copy_host_to_disk,
copy_host_to_host,
)
class CopyTask(NamedTuple):
dst: Address
src: Address
def _copy_gpu_to_gpu(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream):
_unwrap(
drv.CUresult(
copy_device_to_device([MemToMemTask(dst, src) for dst, src in tasks], num_bytes, stream)
)
)
def _copy_host_to_host(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream):
_unwrap(
drv.CUresult(
copy_host_to_host([MemToMemTask(dst, src) for dst, src in tasks], num_bytes, stream)
)
)
def _copy_disk_to_disk(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream):
_unwrap(
drv.CUresult(
copy_disk_to_disk(
[
DiskToDiskTask(
DiskAddress(
cast(DiskAddress, dst).fd,
cast(DiskAddress, dst).pos,
),
DiskAddress(
cast(DiskAddress, src).fd,
cast(DiskAddress, src).pos,
),
)
for dst, src in tasks
],
num_bytes,
stream,
)
)
)
def _copy_gpu_to_host(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream):
_unwrap(
drv.CUresult(
copy_device_to_host([MemToMemTask(dst, src) for dst, src in tasks], num_bytes, stream)
)
)
def _copy_host_to_gpu(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream):
_unwrap(
drv.CUresult(
copy_host_to_device([MemToMemTask(dst, src) for dst, src in tasks], num_bytes, stream)
)
)
def _copy_disk_to_host(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream):
_unwrap(
drv.CUresult(
copy_disk_to_host(
[
DiskToHostTask(
cast(MemAddress, dst),
DiskAddress(cast(DiskAddress, src).fd, cast(DiskAddress, src).pos),
)
for dst, src in tasks
],
num_bytes,
stream,
)
)
)
def _copy_host_to_disk(tasks: Sequence[CopyTask], num_bytes: int, stream: CudaStream):
_unwrap(
drv.CUresult(
copy_host_to_disk(
[
HostToDiskTask(
DiskAddress(
cast(DiskAddress, dst).fd,
cast(DiskAddress, dst).pos,
),
cast(MemAddress, src),
)
for dst, src in tasks
],
num_bytes,
stream,
)
)
)
Copier = Callable[[Sequence[CopyTask], int, CudaStream], None]
def get_copier(dst: CacheTier, src: CacheTier) -> Copier | HomoTuple[Copier]:
copiers: HomoTuple[HomoTuple[Copier | HomoTuple[Copier]]] = (
# dst = GPU_MEM
(
_copy_gpu_to_gpu, # src = GPU_MEM
_copy_host_to_gpu, # src = HOST_MEM
(_copy_disk_to_host, _copy_host_to_gpu), # src = DISK
),
# dst = HOST_MEM
(
_copy_gpu_to_host, # src = GPU_MEM
_copy_host_to_host, # src = HOST_MEM
_copy_disk_to_host, # src = DISK
),
# dst = DISK
(
(_copy_gpu_to_host, _copy_host_to_disk), # src = GPU_MEM
_copy_host_to_disk, # src = HOST_MEM
_copy_disk_to_disk, # src = DISK
),
)
return copiers[dst][src]
@dataclass(slots=True)
class GrainMetadata:
mutex: LockType
ready_event: CachedCudaEvent # protects the buffer grain.
class StagingBuffer:
__slots__ = ("manager", "min_size", "max_size", "_size", "start_grain", "stream")
manager: "StagingBufferManager"
min_size: int
max_size: int
_size: int
start_grain: int
stream: CudaStream
def __init__(
self, manager: "StagingBufferManager", min_size: int, max_size: int, stream: CudaStream
):
self.manager = manager
self.min_size = min_size
self.max_size = max_size
self.stream = stream
@property
def address(self) -> MemAddress:
return MemAddress(self.manager.buffer.address + self.start_grain * self.manager.GRANULARITY)
@property
def size(self) -> int:
return self._size
@property
def num_grains(self) -> int:
return div_up(self._size, self.manager.GRANULARITY)
@property
def grains(self) -> list[GrainMetadata]:
return self.manager.grains[self.start_grain : self.start_grain + self.num_grains]
def __enter__(self) -> "StagingBuffer":
manager = self.manager
if self.min_size > manager.size:
raise ValueError(f"Requested min_size {self.min_size} is too large for the manager")
with manager.mutex:
self._size = min(self.max_size, manager._suggest_next_max_size_unsafe())
self.start_grain = manager.next
manager.next += self.num_grains
assert manager.next <= manager.num_grains
if manager.next == manager.num_grains:
manager.next = 0
def lock_and_consume_events() -> Iterator[CachedCudaEvent]:
for grain in self.grains:
grain.mutex.acquire()
yield grain.ready_event
grain.ready_event = CachedCudaEvent.NULL
stream_wait_events(self.stream, lock_and_consume_events())
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
event = CachedCudaEvent(self.stream)
for grain in reversed(self.grains):
grain.ready_event = event
grain.mutex.release()
class StagingBufferManager:
__slots__ = ("mutex", "buffer", "grains", "next")
GRANULARITY: ClassVar[int] = 1 << 20
mutex: LockType
buffer: HostMem
grains: list[GrainMetadata]
next: int
def __init__(self, size: int) -> None:
assert size % self.GRANULARITY == 0
self.mutex = threading.Lock()
num_grains = size // self.GRANULARITY
self.buffer = HostMem(size)
self.grains = [
GrainMetadata(threading.Lock(), CachedCudaEvent.NULL) for _ in range(num_grains)
]
self.next = 0
@property
def size(self) -> int:
"Requesting more than this will fail."
assert len(self.grains) * self.GRANULARITY == self.buffer.size
return self.buffer.size
@property
def num_grains(self) -> int:
return len(self.grains)
def _suggest_next_max_size_unsafe(self) -> int:
"Requesting more than this may degrade performance. Must be called with self.mutex held."
return self.GRANULARITY * (self.num_grains - self.next)
# max_size is just a hint; the actual size may be smaller.
def new(self, min_size: int, max_size: int, stream: CudaStream) -> StagingBuffer:
"""
min_size is the minimum required size. max_size is best-effort. You should query the actual
size after entering the context.
"""
return StagingBuffer(self, min_size, max_size, stream)
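# Illustrative-only sketch of how the copy paths below use the staging buffer: enter the context to
# claim a contiguous range of grains on `stream`, then query the size actually granted. The 16 MiB
# pool size and the stream handle are placeholders; a CUDA context must already be initialized.
def _example_staging(stream: CudaStream) -> tuple[MemAddress, int]:
    manager = StagingBufferManager(16 << 20)
    with manager.new(min_size=1 << 20, max_size=4 << 20, stream=stream) as buf:
        # Real callers enqueue copies of up to buf.size bytes starting at buf.address on buf.stream.
        return buf.address, buf.size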
class CopyEngine:
__slots__ = ("_staging_buffer_manager",)
_staging_buffer_manager: StagingBufferManager | None
def __init__(self) -> None:
self._staging_buffer_manager = None
def close(self) -> None:
self._staging_buffer_manager = None
@property
def staging_buffer_manager(self) -> StagingBufferManager:
if self._staging_buffer_manager is None:
self._staging_buffer_manager = StagingBufferManager(64 << 20)
return self._staging_buffer_manager
# @TODO: Use a dedicated stream for each different Copier, take set[CachedCudaEvent] instead of
# stream, and return a new CachedCudaEvent.
def transfer(
self,
dst_cache_tier: CacheTier,
src_cache_tier: CacheTier,
num_bytes: int,
tasks: Sequence[CopyTask],
stream: CudaStream,
) -> None:
copier = get_copier(dst_cache_tier, src_cache_tier)
if not isinstance(copier, tuple):
return copier(tasks, num_bytes, stream)
assert len(copier) == 2, "for now, we only support 2 copiers via host memory"
manager = self.staging_buffer_manager
remaining = tasks
while remaining:
with manager.new(num_bytes, num_bytes * len(remaining), stream) as buf:
addr = buf.address
n = buf.size // num_bytes
assert n <= len(remaining)
batch = remaining[:n]
copier[0](
[
CopyTask(MemAddress(addr + num_bytes * i), t.src)
for i, t in enumerate(batch)
],
num_bytes,
buf.stream,
)
copier[1](
[
CopyTask(t.dst, MemAddress(addr + num_bytes * i))
for i, t in enumerate(batch)
],
num_bytes,
buf.stream,
)
remaining = remaining[n:]
_copy_engine = CopyEngine()
atexit.register(_copy_engine.close)
def batched_copy(
dst_cache_tier: CacheTier,
src_cache_tier: CacheTier,
num_bytes: int,
tasks: Sequence[CopyTask],
stream: CudaStream,
) -> None:
_copy_engine.transfer(dst_cache_tier, src_cache_tier, num_bytes, tasks, stream)
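# Illustrative-only sketch: stage a batch of equally sized host-resident pages into GPU memory on
# one stream. The addresses and the stream handle are placeholders; in the real manager they come
# from the storage pools. Multi-hop pairs (e.g. DISK -> GPU_MEM) are split into two hops through
# the shared staging buffer inside CopyEngine.transfer().
def _example_host_to_gpu_copy(
    gpu_dst: Sequence[MemAddress], host_src: Sequence[MemAddress], stream: CudaStream
) -> None:
    tasks = [CopyTask(dst, src) for dst, src in zip(gpu_dst, host_src)]
    batched_copy(
        CacheTier.GPU_MEM, CacheTier.HOST_MEM, num_bytes=1 << 20, tasks=tasks, stream=stream
    )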

View File

@ -0,0 +1,20 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .._common import BeamIndex
from ._kv_cache import _KVCache
from ._kv_cache_manager import KVCacheManager
__all__ = ["KVCacheManager", "_KVCache", "BeamIndex"]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,218 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Sequence
from typing import Any, cast
from .._block_radix_tree import BlockRadixTree
from .._common import (
GPU_LEVEL,
PRIORITY_DEFAULT,
BlockOrdinal,
CacheLevel,
CacheTier,
LayerId,
MemAddress,
PageStatus,
Priority,
TokenIdExt,
)
from .._config import DataRole, KVCacheManagerConfig
from .._life_cycle_registry import LayerGroupId, LifeCycle, LifeCycleId, LifeCycleRegistry
from .._storage._config import create_storage_config
from .._storage._core import PoolGroupIndex
from .._storage_manager import StorageManager
from .._utils import (
HomoTuple,
TypedIndexList,
div_up,
exact_div,
filled_list,
init_cuda_once,
typed_enumerate,
typed_range,
unwrap_rawref,
)
from ._kv_cache import _KVCache
class KVCacheManager:
__slots__ = ("_init_config", "_life_cycles", "_radix_tree", "_storage")
_init_config: KVCacheManagerConfig
_life_cycles: LifeCycleRegistry
_radix_tree: BlockRadixTree
_storage: StorageManager
def __init__(self, config: KVCacheManagerConfig) -> None:
init_cuda_once()
self._init_config = config
self._life_cycles = LifeCycleRegistry(config)
self._radix_tree = BlockRadixTree(self._life_cycles, config.tokens_per_block)
storage_config = create_storage_config(config)
self._storage = StorageManager(self._life_cycles, storage_config)
def __del__(self) -> None:
self.clear_reusable_blocks()
def clear_reusable_blocks(self) -> None:
for ref in self._radix_tree.clear():
assert unwrap_rawref(ref).status == PageStatus.DROPPABLE
self._storage.exclude_from_eviction(unwrap_rawref(ref))
for level in self._storage._levels:
for pg_idx in typed_range(level.storage.num_pool_groups):
assert level.controller.num_evictable_pages(pg_idx) == 0
def get_mem_pool_base_address(self, layer_id: LayerId, data_role: DataRole) -> MemAddress:
"""
Get the base address of the memory pool holding pages for the given layer and data role.
It's guaranteed that for one layer, multiple buffers of the same size have the same base address.
"""
return self._storage.get_mem_pool_base_address(layer_id, data_role)
# Currently always equal to the page size. In the future, this will change when kernels support page strides.
def get_page_stride(self, layer_id: LayerId, data_role: DataRole) -> int:
attr = self._storage.get_buffer_attr(layer_id, data_role)
return attr.size
def get_page_index_upper_bound(self, layer_id: LayerId, data_role: DataRole) -> int:
"""
The upper bound of page indices for the given layer and data role.
Note that this is not the same as the max number of pages available for this layer and data role.
Internally, multiple buffers may share one memory pool. The purpose of this API is just in case
users want to wrap the memory pool as a tensor with known shape.
"""
storage = self._storage
lc_id = storage._layer_to_life_cycle_ids[layer_id]
pg_idx = storage.get_pool_group_index(lc_id)
pool_group = storage._levels[GPU_LEVEL].storage._pool_groups[pg_idx]
num_slots = pool_group.num_slots
attr = storage.get_buffer_attr(layer_id, data_role)
pool_idx = attr.pool_index
slot_size = pool_group.slot_size[pool_idx]
return exact_div(slot_size, attr.size) * num_slots - exact_div(attr.offset, attr.size)
def create_kv_cache(
self,
lora_task_id: int | None = None,
input_tokens: Sequence[TokenIdExt] | None = None,
id: Any = None,
custom_priority_callback: Callable[[BlockOrdinal, LifeCycle], Priority] = lambda _,
__: PRIORITY_DEFAULT,
) -> _KVCache:
"""
lora_task_id: matched before any tokens are matched.
custom_priority_callback: takes the block ordinal and the layer's life cycle (sliding window size
and sink blocks) and returns a priority. For reused blocks, if the returned priority is higher
than the existing priority, the block priority is updated.
A newly created KV cache is suspended. You need to call resume() with a CUDA stream to make it
active and ready in that stream.
Returns None if suspended=False and we don't have enough resources.
This call will attempt to reuse KV cache blocks.
It's the user's responsibility to remove the last token from the prompt if we need to re-compute
the token generated by prefill.
"""
return _KVCache(self, lora_task_id, input_tokens, id, custom_priority_callback)
def resize(self, cache_level: CacheLevel, quota: int, best_efforts: bool = False) -> bool:
"""
If best_efforts is True, we will try to resize the quota to the largest possible value that is
still <= quota, and return False only when we cannot resize the quota at all.
If best_efforts is False, we will resize the quota to exactly quota, and give up if that is
not possible.
"""
raise NotImplementedError("Not implemented")
def get_quota(self, cache_level: CacheLevel) -> int:
return self._storage._levels[cache_level].storage.total_quota
# sorted by CacheLevel from warm to cold
@property
def cache_tier_list(self) -> HomoTuple[CacheTier]:
return self._storage.cache_tiers
@property
def tokens_per_block(self) -> int:
return self._radix_tree.tokens_per_block
@property
def allow_seq_rebasing(self) -> bool:
"""
If True, when we commit a full block, we will try to find an existing reusable block with the
same tokens and reuse that block instead to save some memory. Intra-batch reuse will be enabled
if this is True.
"""
return True
@property
def enable_partial_match(self) -> bool:
return True
def get_layer_group_id(self, layer_id: LayerId) -> LayerGroupId:
return self._storage._layer_to_life_cycle_ids[layer_id]
@property
def layer_grouping(self) -> HomoTuple[HomoTuple[LayerId]]:
layer_to_life_cycle_ids = self._storage._layer_to_life_cycle_ids
num_life_cycles = self._life_cycles.size
grouping = dict[LifeCycleId, list[LayerId]]({i: [] for i in typed_range(num_life_cycles)})
for layer_id, life_cycle_id in typed_enumerate(layer_to_life_cycle_ids):
grouping[life_cycle_id].append(layer_id)
return tuple(tuple(grouping[i]) for i in typed_range(num_life_cycles))
# @TODO: need updating when dynamic resizing is supported.
def clamp_max_seq_len_for_mem(self, batch_size: int, model_max_seq_len: int) -> int:
"Get the max possible sequence length limited by the GPU memory pools."
assert batch_size > 0
tokens_per_block = self.tokens_per_block
life_cycles = self._life_cycles
storage = self._storage
num_pool_groups = storage.num_pool_groups
remaining_slots = cast(
TypedIndexList[PoolGroupIndex, int],
[storage.num_slots(pg) for pg in typed_range(num_pool_groups)],
)
lc_to_pg_idx = storage._life_cycle_grouping
def get_num_slots(seq_len: int) -> TypedIndexList[PoolGroupIndex, int]:
ret = filled_list(0, num_pool_groups)
for lc_id, lc in life_cycles.items():
stale_range = _KVCache._get_stale_range(tokens_per_block, seq_len, lc)
num_stale_blocks = stale_range[1] - stale_range[0]
num_slots = div_up(seq_len, tokens_per_block) - num_stale_blocks
pg_idx = lc_to_pg_idx[lc_id]
ret[pg_idx] += num_slots
return ret
for pg in typed_range(num_pool_groups):
remaining_slots[pg] -= get_num_slots(1)[pg] * (batch_size - 1)
assert remaining_slots[pg] >= 0
def is_enough(num_blocks: int) -> bool:
return all(
cnt <= rem
for cnt, rem in zip(get_num_slots(num_blocks * tokens_per_block), remaining_slots)
)
assert is_enough(1)
lb = 1
ub = div_up(model_max_seq_len, tokens_per_block)
if is_enough(ub):
return model_max_seq_len
while lb < ub:
mid = (lb + ub) // 2
if is_enough(mid):
lb = mid
else:
ub = mid - 1
return min(lb * tokens_per_block, model_max_seq_len)
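# Illustrative-only sketch of the intended call flow; it is not invoked anywhere in this change.
# The config is assumed to be built as in _config.py, CUDA is assumed to be available, and the
# returned _KVCache is assumed to be resumed on a CUDA stream by the caller before use.
def _example_create_kv_cache(
    config: KVCacheManagerConfig, prompt: Sequence[TokenIdExt]
) -> _KVCache:
    manager = KVCacheManager(config)
    # Attempts to reuse committed blocks whose tokens match a prefix of the prompt.
    return manager.create_kv_cache(lora_task_id=None, input_tokens=prompt)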

View File

@ -0,0 +1,184 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Type
import cuda.bindings.driver as drv
from ._common import MemAddress
from ._utils import ItemHolderWithSharedPool, PooledFactoryBase, _unwrap, div_up
# Physical memory
class NativePhysMemAllocator:
__slots__ = ("_device_id", "_size", "_prop", "_outstanding_handles")
_device_id: int
_size: int
_prop: drv.CUmemAllocationProp
_outstanding_handles: set[int] # allocated but not released
def __init__(self, size: int) -> None:
self._device_id = int(_unwrap(drv.cuCtxGetDevice())) # pyright: ignore
self._size = size
prop = drv.CUmemAllocationProp()
prop.type = drv.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
prop.location.type = drv.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
prop.location.id = self._device_id
self._prop = prop
self._outstanding_handles = set()
def allocate(self) -> drv.CUmemGenericAllocationHandle:
handle: drv.CUmemGenericAllocationHandle = _unwrap(
drv.cuMemCreate(self._size, self._prop, 0)
)
int_handle = int(handle) # pyright: ignore
assert (int_handle not in self._outstanding_handles) and int_handle != 0
self._outstanding_handles.add(int_handle)
return handle
def release(self, handle: drv.CUmemGenericAllocationHandle) -> None:
if handle == drv.CUmemGenericAllocationHandle(0):
return
assert int(handle) in self._outstanding_handles
self._outstanding_handles.remove(int(handle))
try:
_unwrap(drv.cuMemRelease(handle))
except Exception:
print(
f"Failed to release handle {handle}. num_outstanding = {len(self._outstanding_handles)}"
)
raise
@property
def device_id(self) -> int:
return self._device_id
@property
def size(self) -> int:
return self._size
class PhysMem(ItemHolderWithSharedPool[drv.CUmemGenericAllocationHandle]):
__slots__ = ()
class PooledPhysMemAllocator(PooledFactoryBase[drv.CUmemGenericAllocationHandle, PhysMem]):
_Holder: Type[PhysMem] = PhysMem
__slots__ = ("device_id", "phys_mem_size")
device_id: int
phys_mem_size: int
def __init__(self, phys_mem_size: int) -> None:
raw_alloc = NativePhysMemAllocator(phys_mem_size)
self.device_id = raw_alloc.device_id
self.phys_mem_size = phys_mem_size
super().__init__(lambda: raw_alloc.allocate(), lambda handle: raw_alloc.release(handle))
# Virtual memory
class VirtMem:
__slots__ = ("_vm_size", "_allocator", "_address", "_pm_stack", "_access_desc")
_vm_size: int
_allocator: PooledPhysMemAllocator
_address: drv.CUdeviceptr
_pm_stack: list[PhysMem]
_access_desc: drv.CUmemAccessDesc
def __init__(
self, vm_size: int, phys_mem_allocator: PooledPhysMemAllocator, init_num_phys_mem: int = 0
):
assert vm_size % phys_mem_allocator.phys_mem_size == 0
self._allocator = phys_mem_allocator
device_id = phys_mem_allocator.device_id
self._address = _unwrap(drv.cuMemAddressReserve(vm_size, 0, 0, 0))
self._vm_size = vm_size
self._pm_stack = []
self._access_desc = drv.CUmemAccessDesc()
self._access_desc.location.type = drv.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
self._access_desc.location.id = device_id
self._access_desc.flags = drv.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
self.extend(init_num_phys_mem)
@property
def phys_mem_size(self) -> int:
return self._allocator.phys_mem_size
def destroy(self) -> None:
if self._vm_size == 0:
return
while self._pm_stack:
self._pop().close()
_unwrap(drv.cuMemAddressFree(self._address, self._vm_size))
self._address = drv.CUdeviceptr(0)
self._vm_size = 0
def __del__(self) -> None:
self.destroy()
def extend(self, num_phys_mem: int) -> None:
old_num_phys_mem = self.num_phys_mem
try:
for _ in range(num_phys_mem):
self._push(self._allocator.create())
except Exception:
# to make extend() behave like a normal realloc, we need to roll back if out of memory
while self.num_phys_mem > old_num_phys_mem:
self._pop().close()
raise
def shrink(self, num_phys_mem: int) -> None:
for _ in range(num_phys_mem):
self._pop().close()
# Different from normal realloc, this function never changes the pointer.
def realloc(self, num_bytes: int) -> None:
required_num_phys_mem = div_up(num_bytes, self.phys_mem_size)
if required_num_phys_mem > self.num_phys_mem:
self.extend(required_num_phys_mem - self.num_phys_mem)
elif required_num_phys_mem < self.num_phys_mem:
self.shrink(self.num_phys_mem - required_num_phys_mem)
def _push(self, phy_mem: PhysMem) -> None:
phys_mem_size = self.phys_mem_size
assert phys_mem_size * (len(self._pm_stack) + 1) <= self._vm_size
vm_ptr = drv.CUdeviceptr(self.address + phys_mem_size * len(self._pm_stack))
_unwrap(drv.cuMemMap(vm_ptr, phys_mem_size, 0, phy_mem.handle, 0))
_unwrap(drv.cuMemSetAccess(vm_ptr, phys_mem_size, (self._access_desc,), 1))
self._pm_stack.append(phy_mem)
def _pop(self) -> PhysMem:
assert self._pm_stack
phys_mem_size = self.phys_mem_size
vm_ptr = drv.CUdeviceptr(self.address + phys_mem_size * (len(self._pm_stack) - 1))
_unwrap(drv.cuMemUnmap(vm_ptr, phys_mem_size))
return self._pm_stack.pop()
@property
def mapped_bytes(self) -> int:
return self.phys_mem_size * self.num_phys_mem
@property
def virtual_bytes(self) -> int:
return self._vm_size
@property
def num_phys_mem(self) -> int:
return len(self._pm_stack)
@property
def address(self) -> MemAddress:
return MemAddress(int(self._address))
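# Illustrative-only sketch: grow and shrink a virtual address range backed by pooled physical
# memory. The 2 MiB granule and 1 GiB reservation are made-up example numbers, and a CUDA context
# must already be current on the calling thread.
def _example_virt_mem() -> None:
    allocator = PooledPhysMemAllocator(phys_mem_size=2 << 20)
    vm = VirtMem(vm_size=1 << 30, phys_mem_allocator=allocator)
    vm.realloc(64 << 20)  # map enough physical granules for 64 MiB; vm.address stays fixed
    vm.realloc(16 << 20)  # unmap the excess granules back to the pool
    vm.destroy()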

View File

@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._eviction_controller import ( # noqa: E402
EvictablePage,
EvictionPolicy,
NodeRef,
PerLevelEvictionController,
)
__all__ = ["EvictionPolicy", "PerLevelEvictionController", "EvictablePage", "NodeRef"]

View File

@ -0,0 +1,228 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Protocol, cast
from llist import sllist, sllistnode
from .._common import NDEBUG, CacheLevel, PageStatus, Priority
from .._exceptions import OutOfPagesError
from .._life_cycle_registry import LifeCycleId
from .._storage._core import PoolGroupIndex
from .._utils import (
TypedIndexList,
assert_critical,
make_typed,
noexcept,
typed_enumerate,
typed_len,
unwrap_optional,
)
# @runtime_checkable
class EvictablePage(Protocol):
@property
def cache_level(self) -> CacheLevel: ...
@property
def priority(self) -> Priority: ...
@property
def life_cycle(self) -> LifeCycleId: ...
@property
def status(self) -> PageStatus: ...
def is_committed(self) -> bool: ...
node_ref: "NodeRef | None"
# @runtime_checkable
class NodeRef(Protocol):
@property
def value(self) -> EvictablePage: ...
# @runtime_checkable
class EvictionPolicy(Protocol):
def push(self, page: EvictablePage, evict_first: bool = False) -> NodeRef: ...
def pop(self) -> EvictablePage: ...
# Remove a node so we no longer consider it for eviction. Like pop() but allow removing a node
# that is not the first.
def remove(self, node: NodeRef) -> EvictablePage: ...
def __len__(self) -> int: ...
class LRUEvictionPolicy:
__slots__ = ("_queue",)
_queue: sllist
def __init__(self) -> None:
self._queue = sllist()
def push(self, page: EvictablePage, evict_first: bool = False) -> sllistnode:
assert page.node_ref is None
return self._queue.appendleft(page) if evict_first else self._queue.append(page)
def pop(self) -> EvictablePage:
victim = self._queue.first
assert victim is not None
page = victim.value
self.remove(victim)
return page
def remove(self, node: sllistnode) -> EvictablePage:
# assert isinstance(node, NodeRef) # mypyc does not support runtime_checkable
assert node == node.value.node_ref
return self._queue.remove(node)
def __len__(self) -> int:
return len(self._queue)
# helper class that adds support for priority-based eviction
class PrioritizedEvictionPolicy:
__slots__ = (
"_policy_creator",
"_policies",
)
_policy_creator: Callable[[Priority], EvictionPolicy]
_policies: dict[Priority, EvictionPolicy]
def __init__(self, policy_creator: Callable[[Priority], EvictionPolicy]) -> None:
self._policy_creator = policy_creator
self._policies = {}
def __len__(self) -> int:
return sum(len(policy) for policy in self._policies.values())
def get_policy(self, priority: Priority) -> EvictionPolicy:
if priority not in self._policies:
self._policies[priority] = self._policy_creator(priority)
self._policies = dict(sorted(self._policies.items()))
return self._policies[priority]
def _front_policy(self) -> EvictionPolicy:
return next(iter(self._policies.values()))
def push(self, page: EvictablePage, evict_first: bool = False) -> NodeRef:
return self.get_policy(page.priority).push(page, evict_first)
def pop(self) -> EvictablePage:
return self._front_policy().pop()
def remove(self, node: NodeRef) -> EvictablePage:
page = node.value
policy = self._policies[page.priority]
policy.remove(node)
if not policy:
self._policies.pop(page.priority)
return page
class PrioritizedLRUEvictionPolicy(PrioritizedEvictionPolicy):
__slots__ = ()
def __init__(self) -> None:
super().__init__(lambda priority: LRUEvictionPolicy())
class PerLevelEvictionController: # for one cache level
__slots__ = ("_life_cycle_grouping", "_policies", "_cache_level")
_life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex]
_policies: TypedIndexList[PoolGroupIndex, EvictionPolicy]
_cache_level: CacheLevel
def __init__(
self,
life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex],
cache_level: CacheLevel,
):
self._cache_level = cache_level
self._life_cycle_grouping = life_cycle_grouping
num_pool_groups = max(life_cycle_grouping) + 1
assert num_pool_groups == len(set(life_cycle_grouping))
self._policies = cast(
TypedIndexList, [PrioritizedLRUEvictionPolicy() for _ in range(num_pool_groups)]
)
def __del__(self) -> None:
if not NDEBUG:
assert_critical(
all(len(p) == 0 for p in self._policies), "Eviction controller is not empty"
)
def _get_policy(self, life_cycle: LifeCycleId) -> EvictionPolicy:
pg_idx = self._life_cycle_grouping[life_cycle]
return self._policies[pg_idx]
def schedule_for_eviction(self, page: EvictablePage, evict_first: bool = False):
assert page.node_ref is None and page.cache_level == self._cache_level
page.node_ref = self._get_policy(page.life_cycle).push(page, evict_first)
assert unwrap_optional(page.node_ref).value is page
# If evicting a node makes some other nodes useless, those nodes will be returned as well.
# One example: for SWA, if the number of blocks just makes up one window size, then evicting any of
# them makes the remaining blocks useless.
# Raises if there are not enough pages to evict; in that case, the pages are returned to the eviction queue.
def evict(
self, min_num_pages: TypedIndexList[PoolGroupIndex, int]
) -> TypedIndexList[PoolGroupIndex, list[EvictablePage]]:
assert NDEBUG or len(min_num_pages) == self.num_pool_groups
ret = make_typed(lambda: list[EvictablePage](), self.num_pool_groups)
try:
for pg_idx, count in typed_enumerate(min_num_pages):
policy = self._policies[pg_idx]
if (len(policy) + len(ret[pg_idx])) < count:
raise OutOfPagesError(f"Not enough pages to evict in group {pg_idx}")
while len(ret[pg_idx]) < count:
page = policy.pop()
page.node_ref = None
ret[pg_idx].append(page)
for a, b in zip(ret, self._evict_dependencies(page)):
a.extend(b)
except Exception:
for p in reversed(sum(ret, [])):
self.schedule_for_eviction(p, evict_first=True)
raise
assert all(p.cache_level == self._cache_level for p in sum(ret, [])), (
"Corrupted eviction controller"
)
return ret
def remove(self, node: NodeRef) -> None:
page = node.value
assert page.node_ref == node
self._get_policy(page.life_cycle).remove(node)
page.node_ref = None
# @TODO: implement this
@noexcept
def _evict_dependencies(
self, page: EvictablePage
) -> TypedIndexList[PoolGroupIndex, list[EvictablePage]]:
return make_typed(lambda: list[EvictablePage](), self.num_pool_groups)
def num_evictable_pages(self, pg_idx: PoolGroupIndex) -> int:
return len(self._policies[pg_idx])
@property
def num_pool_groups(self) -> PoolGroupIndex:
return typed_len(self._policies)
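# Illustrative-only sketch of the policy layer on its own: pages with a lower Priority value are
# evicted first, and pages of equal priority follow LRU order. `page_a` and `page_b` stand in for
# objects implementing the EvictablePage protocol with node_ref initially None.
def _example_eviction_policy(page_a: EvictablePage, page_b: EvictablePage) -> EvictablePage:
    policy = PrioritizedLRUEvictionPolicy()
    page_a.node_ref = policy.push(page_a)
    page_b.node_ref = policy.push(page_b)
    victim = policy.pop()  # the least recently pushed page among those with the lowest priority
    victim.node_ref = None
    return victim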

View File

@ -0,0 +1,60 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cuda.bindings.driver as drv
class OutOfMemoryError(Exception):
pass
class HostOOMError(OutOfMemoryError):
pass
class DiskOOMError(OutOfMemoryError):
pass
class CuOOMError(OutOfMemoryError):
pass
class LogicError(Exception):
"""
This exception indicates a bug in the code.
"""
def __init__(self, message: str) -> None:
super().__init__(message)
class CuError(Exception):
error_code: drv.CUresult
def __init__(self, error_code: drv.CUresult) -> None:
self.error_code = error_code
err, err_str = drv.cuGetErrorString(error_code)
if err != drv.CUresult.CUDA_SUCCESS:
err_str = "<Failed to get error string with cuGetErrorString>"
super().__init__(f"CUDA driver error: {error_code} ({err_str})")
class ResourceBusyError(Exception):
pass
class OutOfPagesError(Exception):
pass

View File

@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterator, NamedTuple, NewType, TypeAlias, cast
from ._common import SlidingWindowSize
from ._config import KVCacheManagerConfig
from ._utils import TypedIndexList, div_up, typed_enumerate
class LifeCycle(NamedTuple):
window_size: SlidingWindowSize
num_sink_blocks: int # div_up(num_sink_tokens, tokens_per_block)
@staticmethod
def make(
window_size: SlidingWindowSize, num_sink_tokens: int | None, tokens_per_block: int
) -> "LifeCycle":
assert tokens_per_block > 0
assert window_size is None or window_size > 0
assert num_sink_tokens is None or num_sink_tokens >= 0
assert num_sink_tokens in (None, 0) or window_size is not None
num_sink_blocks = div_up(num_sink_tokens or 0, tokens_per_block)
return LifeCycle(window_size, num_sink_blocks)
LifeCycleId = NewType("LifeCycleId", int)
# For public exposure
LayerGroupId: TypeAlias = LifeCycleId
class LifeCycleRegistry:
__slots__ = ("_life_cycle_list", "_life_cycle_id_dict")
_life_cycle_list: TypedIndexList[LifeCycleId, LifeCycle]
_life_cycle_id_dict: dict[LifeCycle, LifeCycleId]
def __init__(self, config: KVCacheManagerConfig) -> None:
self._life_cycle_list = cast(TypedIndexList[LifeCycleId, LifeCycle], [])
self._life_cycle_id_dict = dict[LifeCycle, LifeCycleId]()
for layer in config.layers:
details = LifeCycle.make(
layer.window_size, layer.num_sink_tokens, config.tokens_per_block
)
if details not in self._life_cycle_id_dict:
assert len(self._life_cycle_id_dict) == len(self._life_cycle_list), (
"corrupted life cycle registry"
)
self._life_cycle_list.append(details)
self._life_cycle_id_dict[details] = LifeCycleId(len(self._life_cycle_list) - 1)
def get_life_cycle(self, id: LifeCycleId) -> LifeCycle:
return self._life_cycle_list[id]
def get_id(self, life_cycle_details: LifeCycle) -> LifeCycleId:
return self._life_cycle_id_dict[life_cycle_details]
@property
def size(self) -> LifeCycleId:
assert len(self._life_cycle_list) == len(self._life_cycle_id_dict), (
"corrupted life cycle registry"
)
return LifeCycleId(len(self._life_cycle_list))
def __iter__(self) -> Iterator[LifeCycle]:
return iter(self._life_cycle_list)
def __getitem__(self, idx: LifeCycleId) -> LifeCycle:
return self._life_cycle_list[idx]
def items(self) -> Iterator[tuple[LifeCycleId, LifeCycle]]:
return typed_enumerate(self.get())
def get(self) -> TypedIndexList[LifeCycleId, LifeCycle]:
return self._life_cycle_list
def __contains__(self, lc: LifeCycle) -> bool:
return lc in self._life_cycle_id_dict
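# Illustrative-only check of the sink-block arithmetic above: with 32 tokens per block, a
# 4096-token sliding window and 5 sink tokens round up to a single sink block.
def _example_life_cycle() -> LifeCycle:
    lc = LifeCycle.make(window_size=4096, num_sink_tokens=5, tokens_per_block=32)
    assert lc == LifeCycle(window_size=4096, num_sink_blocks=1)  # div_up(5, 32) == 1
    return lc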

View File

@ -0,0 +1,463 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple, cast
from . import rawref
from ._block_radix_tree import Block
from ._common import (
BAD_PAGE_INDEX,
GPU_LEVEL,
NDEBUG,
BeamIndex,
BlockOrdinal,
CacheLevel,
PageIndex,
PageStatus,
Priority,
TokenIdExt,
)
if TYPE_CHECKING:
from ._core._kv_cache import _KVCache
from ._storage_manager import StorageManager
from ._eviction_controller import NodeRef
from ._exceptions import LogicError
from ._life_cycle_registry import LifeCycleId
from ._storage._core import Slot
from ._utils import (
CachedCudaEvent,
assert_critical,
filled_list,
get_uniform_attribute,
merge_events,
partition,
stream_wait_events,
unwrap_rawref,
)
ReferenceType = rawref.ReferenceType
# We will have a huge number of pages when storage capacity is large,
# so we prefer inheritance over composition to save some memory.
@dataclass(slots=True)
class Page(Slot):
_manager: ReferenceType["StorageManager"]
life_cycle: LifeCycleId
cache_level: CacheLevel
_priority: Priority
# _holder is either None or a valid rawref.
_holder: ReferenceType["_PageHolder"] | None
node_ref: NodeRef | None
def __del__(self) -> None:
if not NDEBUG:
assert_critical(self.status == PageStatus.DROPPABLE and not self.scheduled_for_eviction)
if self.has_valid_slot:
self.manager.release_slot(self.life_cycle, self.cache_level, self)
@property
def manager(self) -> "StorageManager":
return unwrap_rawref(self._manager)
@property
def priority(self) -> Priority:
return self._priority
# prevent dropping
def hold(self) -> "_PageHolder":
if self._holder is not None:
return unwrap_rawref(self._holder)
holder = _PageHolder(self)
self._holder = rawref.ref(holder)
controller = self.manager
if self.scheduled_for_eviction and not controller.is_evictable(self):
controller.exclude_from_eviction(self)
assert not self.scheduled_for_eviction
return holder
# Prevent eviction. You need to migrate the page to GPU later.
def lock(
self,
kv_cache: "_KVCache",
beam_index: BeamIndex,
ordinal: BlockOrdinal,
life_cycle: LifeCycleId,
skip_wait: bool = False,
) -> "_SharedPageLock":
"If skip wait, you are responsible for making the page ready in kv_cache.cuda_stream."
return self.hold().lock(kv_cache, beam_index, ordinal, life_cycle, skip_wait)
@property
def status(self) -> PageStatus:
if self._holder is None:
return PageStatus.DROPPABLE
lock_ref = unwrap_rawref(self._holder)._lock
if lock_ref is None:
return PageStatus.HELD
assert unwrap_rawref(lock_ref) is not None
return PageStatus.LOCKED
@property
def scheduled_for_eviction(self) -> bool:
return self.node_ref is not None
def is_committed(self) -> bool:
raise LogicError("Unexpected call to this implementation.")
@dataclass(slots=True)
class UncommittedPage(Page):
# @TODO: consider moving this to _PageHolder
kv_cache: rawref.ref["_KVCache"]
ordinal: BlockOrdinal
beam_index: BeamIndex
tokens: list[TokenIdExt] = field(default_factory=list)
def is_committed(self) -> bool:
return False
def __init__(
self,
kv_cache: "_KVCache",
ordinal: BlockOrdinal,
life_cycle: LifeCycleId,
cache_level: CacheLevel,
slot: Slot,
beam_index: BeamIndex = BeamIndex(0),
):
self.kv_cache = rawref.ref(kv_cache)
self.ordinal = ordinal
self.beam_index = beam_index
manager = kv_cache.manager
priority = kv_cache._get_priority(ordinal, manager._life_cycles.get_life_cycle(life_cycle))
Page.__init__(
self,
None,
CachedCudaEvent.NULL,
rawref.ref(manager._storage),
life_cycle,
cache_level,
priority,
None,
None,
)
self.set_slot(slot)
def convert_to_committed(self, block: Block) -> "CommittedPage":
"""
Moves the slot to a new committed page and add the new page to the block.
The uncommitted page becomes invalid.
"""
assert not self.scheduled_for_eviction
assert block.storage[self.life_cycle] is None
# If you hit this assertion failure, it's likely because you are using debugpy, which delayed GC
# for _KVCache._take_uncommitted_page(). Disable breakpoints on exceptions to avoid this issue.
assert self.status == PageStatus.DROPPABLE, "Release holder/lock first"
committed_page = CommittedPage(
self.manager, block, self.life_cycle, self.cache_level, self, self.priority
)
self._slot_id = None
self.ready_event = CachedCudaEvent.NULL
assert committed_page.has_valid_slot
block.storage[self.life_cycle] = rawref.ref(committed_page)
return committed_page
def __del__(self) -> None:
def check_page(p: "BlockPage") -> bool:
return p is None or isinstance(p.page, CommittedPage)
if not NDEBUG:
assert_critical(
len(unwrap_rawref(self.kv_cache)._blocks) <= self.ordinal
or check_page(
unwrap_rawref(self.kv_cache)
._blocks[self.ordinal]
.pages[self.beam_index][self.life_cycle]
)
)
Page.__del__(self)
@dataclass(slots=True)
class CommittedPage(Page):
block: rawref.ref["Block"]
__rawref__: rawref.ref["CommittedPage"]
def is_committed(self) -> bool:
return True
def __init__(
self,
storage: "StorageManager",
block: Block,
life_cycle: LifeCycleId,
cache_level: CacheLevel,
slot: Slot,
priority: Priority,
):
self.block = rawref.ref(block)
self.__rawref__ = rawref.NULL
Page.__init__(
self,
None,
CachedCudaEvent.NULL,
rawref.ref(storage),
life_cycle,
cache_level,
priority,
None,
None,
)
self.set_slot(slot)
def __del__(self) -> None:
block = self.block()
# block may be None when rebase happens, i.e. another block with the same key is committed,
# replacing it, but the page is still used by a _KVCache.
if block is not None:
block.unset_page(
self.life_cycle,
self.manager._life_cycles.get_life_cycle(self.life_cycle),
)
Page.__del__(self)
self.__rawref__.invalidate()
@dataclass(slots=True)
class _PageHolder:
"Prevents pages from being dropped."
page: Page
_lock: rawref.ref["_UniqPageLock"] | None = None
__rawref__: rawref.ref["_PageHolder"] = field(default_factory=lambda: rawref.NULL)
def __init__(self, page: Page) -> None:
self.page = page
self._lock = None
self.__rawref__ = rawref.NULL
def __del__(self) -> None:
if not NDEBUG:
assert_critical(self._lock is None)
page = self.page
page._holder = None
# If a held page was in the last-level cache, it was not scheduled for eviction.
if page.is_committed():
page = cast(CommittedPage, page)
if not page.scheduled_for_eviction:
page.manager.schedule_for_eviction(page)
block = page.block()
if block is None or block.is_orphan:
page.manager.exclude_from_eviction(page)
elif page.scheduled_for_eviction:
page = cast(UncommittedPage, self.page)
page.manager.exclude_from_eviction(self.page)
self.__rawref__.invalidate()
# Prevent eviction. You need to migrate the page to GPU later.
def lock(
self,
kv_cache: "_KVCache",
beam_index: BeamIndex,
ordinal: BlockOrdinal,
life_cycle: LifeCycleId,
skip_wait: bool = False,
) -> "_SharedPageLock":
if self._lock is None:
lock = _UniqPageLock(self)
self._lock = rawref.ref(lock)
else:
lock = unwrap_rawref(self._lock)
if self.page.scheduled_for_eviction:
manager = self.page.manager
manager.exclude_from_eviction(self.page)
assert not self.page.scheduled_for_eviction
return lock.share(kv_cache, beam_index, ordinal, life_cycle, skip_wait)
@dataclass(slots=True)
class _UniqPageLock:
"Locks pages to prevent eviction."
holder: _PageHolder | None
finish_events: list[CachedCudaEvent]
__rawref__: rawref.ref["_UniqPageLock"] = field(default_factory=lambda: rawref.NULL)
def __init__(self, holder: _PageHolder) -> None:
if holder.page.cache_level != CacheLevel(0):
raise ValueError("Lock can be applied only on GPU memory pages.")
self.holder = holder
self.finish_events = []
self.__rawref__ = rawref.NULL
def share(
self,
kv_cache: "_KVCache",
beam_index: BeamIndex,
ordinal: BlockOrdinal,
life_cycle: LifeCycleId,
skip_wait: bool,
) -> "_SharedPageLock":
ret = _SharedPageLock(self, kv_cache, beam_index, ordinal, life_cycle, skip_wait)
return ret
@property
def page(self) -> Page:
assert self.holder is not None
return self.holder.page
def __del__(self) -> None:
page = self.page
if not NDEBUG:
assert_critical(page.cache_level == CacheLevel(0) and not page.scheduled_for_eviction)
page.ready_event = merge_events(self.finish_events)
assert self.holder is not None
self.holder._lock = None
# Simple (disabled) reference code path, kept for documentation:
if False:
if page.manager.is_evictable(page):
page.manager.schedule_for_eviction(page)
else:
# Optimized code path:
# delete the holder first, so that if nobody holds the page elsewhere, it becomes droppable
# immediately, before we hand it over to the eviction controller.
self.holder = None
# If it's not droppable, then self.holder = None had no impact, and we need to schedule it
# for eviction as usual.
if page.status != PageStatus.DROPPABLE and page.manager.is_evictable(page):
page.manager.schedule_for_eviction(page)
self.__rawref__.invalidate()
class LockOwner(NamedTuple):
kv_cache: rawref.ref["_KVCache"]
beam_index: BeamIndex
ordinal: BlockOrdinal
life_cycle: LifeCycleId
@dataclass(slots=True, init=False)
class _SharedPageLock:
_uniq_lock: _UniqPageLock | None
_user: LockOwner
@property
def page(self) -> Page:
assert self._uniq_lock is not None
return self._uniq_lock.page
@property
def holder(self) -> _PageHolder:
assert self._uniq_lock is not None
assert self._uniq_lock.holder is not None
return self._uniq_lock.holder
def __hash__(self) -> int:
return hash(id(self))
def __eq__(self, other: object) -> bool:
return self is other
def __init__(
self,
uniq_lock: _UniqPageLock,
kv_cache: "_KVCache",
beam_index: BeamIndex,
ordinal: BlockOrdinal,
life_cycle: LifeCycleId,
skip_wait: bool,
) -> None:
self._uniq_lock = uniq_lock
if not skip_wait:
self.page.ready_event.wait_in_stream(kv_cache.cuda_stream)
self._user = LockOwner(rawref.ref(kv_cache), beam_index, ordinal, life_cycle)
new_index = self._get_page_index()
old_index = kv_cache._update_page_index(beam_index, ordinal, life_cycle, new_index)
assert old_index == BAD_PAGE_INDEX
def __del__(self) -> None:
if self._uniq_lock is not None:
self.unlock()
def unlock(self) -> Page:
assert self._uniq_lock is not None
page = self.page
self._uniq_lock.finish_events.append(unwrap_rawref(self._user.kv_cache).finish_event)
new_index = BAD_PAGE_INDEX
old_index = unwrap_rawref(self._user.kv_cache)._update_page_index(
self._user.beam_index, self._user.ordinal, self._user.life_cycle, new_index
)
assert NDEBUG or old_index == self._get_page_index()
self._uniq_lock = None
return page
def _get_page_index(self) -> PageIndex:
storage = unwrap_rawref(self._user.kv_cache).manager._storage
num_buffers_per_slot = storage._slot_to_page_indices[self._user.life_cycle]
return PageIndex(self.page.slot_id * num_buffers_per_slot)
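# Illustrative note (hypothetical numbers): if a slot coalesces 4 buffers of equal size, the
# scale stored in _slot_to_page_indices is 4, so slot_id 7 maps to page index 28; the kernels
# are then assumed to address that slot's buffers at indices 28..31 within the pool.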
BlockPage = _SharedPageLock | _PageHolder | None
class BatchedLockTarget(NamedTuple):
page: Page
beam_index: BeamIndex
ordinal: BlockOrdinal
life_cycle: LifeCycleId
def batched_lock_to_gpu(
kv_cache: "_KVCache", tasks: Sequence[BatchedLockTarget]
) -> list["_SharedPageLock"]:
"Lock pages after migrating all pages to GPU. If migration fails, no locking happens."
storage = kv_cache.manager._storage
assert not tasks or storage is get_uniform_attribute(tasks, lambda p: p.page.manager)
requirements = filled_list(0, storage.num_pool_groups)
scheduled_for_eviction = [t.page.scheduled_for_eviction for t in tasks]
for t, e in zip(tasks, scheduled_for_eviction):
if e:
storage.exclude_from_eviction(t.page)
if t.page.cache_level == GPU_LEVEL:
continue
requirements[storage.get_pool_group_index(t.life_cycle)] += 1
try:
storage.prepare_free_slots(GPU_LEVEL, requirements)
partitioned = partition(
tasks, lambda p: (p.page.cache_level, storage.get_pool_group_index(p.life_cycle))
)
for (lvl, pg_idx), part in partitioned.items():
if lvl == GPU_LEVEL:
continue
storage._batched_migrate(
pg_idx, GPU_LEVEL, lvl, [p.page for p in part], update_src=True
)
except Exception:
for t, e in zip(tasks, scheduled_for_eviction):
if e:
storage.schedule_for_eviction(t.page)
raise
stream_wait_events(kv_cache.cuda_stream, (p.page.ready_event for p in tasks))
return [
page.lock(kv_cache, beam_index, ordinal, life_cycle, skip_wait=True)
for page, beam_index, ordinal, life_cycle in tasks
]
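# Illustrative usage sketch (hypothetical caller code, not part of this module): `kv` is a
# _KVCache, `p0`/`p1` are pages needed by the next step, and `lc` is their LifeCycleId;
# BeamIndex/BlockOrdinal are assumed to be NewType wrappers over int.
#
#     locks = batched_lock_to_gpu(kv, [
#         BatchedLockTarget(p0, BeamIndex(0), BlockOrdinal(0), lc),
#         BatchedLockTarget(p1, BeamIndex(0), BlockOrdinal(1), lc),
#     ])
#     ... # run kernels against the locked GPU slots
#     for lock in locks:
#         lock.unlock()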

View File

@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._core import CacheLevelStorage
# These are re-exported for external use
__all__ = ["CacheLevelStorage"]

View File

@ -0,0 +1,225 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from dataclasses import dataclass
from typing import NamedTuple, cast
from .._common import LayerId
from .._config import CacheTierConfig, DataRole, KVCacheManagerConfig
from .._life_cycle_registry import LifeCycle, LifeCycleId, LifeCycleRegistry
from .._storage._core import PoolGroupIndex, PoolIndex
from .._utils import (
HomoTuple,
TypedIndexList,
exact_div,
filled_list,
get_uniform_attribute,
is_sorted,
typed_range,
)
class BufferId(NamedTuple):
layer_id: LayerId
role: DataRole
@dataclass(slots=True)
class CoalescedBuffer:
life_cycle_id: LifeCycleId
single_buffer_size: int # identical for all buffers in the same coalesced buffer
buffer_ids: list[BufferId]
@property
def size(self) -> int:
return self.single_buffer_size * len(self.buffer_ids)
@dataclass(slots=True)
class PageConfig:
"""
A page is a group of coalesced buffers. Each coalesced buffer has multiple buffers with the same
size. Multiple coalesced buffers can be in the same page if they share the same life cycle and
coalesced size.
"""
coalesced_buffers: list[CoalescedBuffer]
@property
def _coalesced_size(self) -> int:
return get_uniform_attribute(self.coalesced_buffers, lambda b: b.size)
@property
def slot_size(self) -> int:
return self._coalesced_size * len(self.coalesced_buffers)
@property
def life_cycle_id(self) -> LifeCycleId:
return get_uniform_attribute(self.coalesced_buffers, lambda b: b.life_cycle_id)
@dataclass(slots=True, frozen=True)
class SlotConfig:
"A group of pages for the same life cycle."
pages: HomoTuple[PageConfig]
def __post_init__(self) -> None:
assert is_sorted(self.pages, key=lambda s: s.slot_size, reverse=True)
assert all(
len(p.coalesced_buffers) == len(self.pages[0].coalesced_buffers) for p in self.pages
)
@property
def life_cycle_id(self) -> LifeCycleId:
return get_uniform_attribute(self.pages, lambda s: s.life_cycle_id)
@property
def slot_size_list(self) -> HomoTuple[int]:
return tuple(s.slot_size for s in self.pages)
@dataclass(slots=True, frozen=True)
class PoolGroupConfig:
"""
A group of pools may contain slots (page groups) with different life cycles. They have identical
slot size lists, so we can put them in the same group of memory pools.
"""
slots: HomoTuple[SlotConfig]
@property
def slot_size_list(self) -> HomoTuple[int]:
return get_uniform_attribute(self.slots, lambda s: s.slot_size_list)
@dataclass(slots=True, frozen=True)
class BufferAttr:
life_cycle_id: LifeCycleId
pool_index: PoolIndex
offset: int
size: int
@dataclass(slots=True, frozen=True)
class StorageConfig:
cache_tiers: HomoTuple[CacheTierConfig]
pool_groups: HomoTuple[PoolGroupConfig]
@property
def num_life_cycles(self) -> LifeCycleId:
return LifeCycleId(sum(len(pg.slots) for pg in self.pool_groups))
def life_cycle_grouping(self) -> TypedIndexList[LifeCycleId, PoolGroupIndex]:
ret = filled_list(PoolGroupIndex(-1), self.num_life_cycles)
for pg_idx, pg in enumerate(self.pool_groups):
pg_idx = PoolGroupIndex(pg_idx)
for s in pg.slots:
ret[s.life_cycle_id] = pg_idx
return ret
def buffer_attributes(self) -> dict[BufferId, BufferAttr]:
ret = dict[BufferId, BufferAttr]()
for pg in self.pool_groups:
for slot in pg.slots:
life_cycle_id = slot.life_cycle_id
for pool, page in enumerate(slot.pages):
offset = 0
for cb in page.coalesced_buffers:
for b in cb.buffer_ids:
ret[b] = BufferAttr(
life_cycle_id, PoolIndex(pool), offset, cb.single_buffer_size
)
offset += cb.single_buffer_size
return ret
def slot_to_page_indices(self) -> TypedIndexList[LifeCycleId, int]:
ret = filled_list(0, self.num_life_cycles)
for pg in self.pool_groups:
for slot in pg.slots:
life_cycle = slot.life_cycle_id
assert len(slot.pages) == 1
page = slot.pages[0]
assert len(page.coalesced_buffers) == 1
scale = exact_div(page.slot_size, page.coalesced_buffers[0].single_buffer_size)
ret[life_cycle] = scale
return ret
def layer_to_life_cycle_ids(self) -> TypedIndexList[LayerId, LifeCycleId]:
map = dict[LayerId, LifeCycleId]()
for (layer_id, _), attr in self.buffer_attributes().items():
lc_id = map.setdefault(layer_id, attr.life_cycle_id)
assert lc_id == attr.life_cycle_id
assert len(map) == max(map.keys()) + 1
return cast(
TypedIndexList[LayerId, LifeCycleId],
[map[LayerId(layer_id)] for layer_id in typed_range(len(map))],
)
def __post_init__(self) -> None:
groups = [tuple(s.life_cycle_id for s in pg.slots) for pg in self.pool_groups]
all_life_cycle_ids = sum((g for g in groups), ())
assert len(all_life_cycle_ids) == len(set(all_life_cycle_ids))
def create_storage_config(config: KVCacheManagerConfig) -> StorageConfig:
# group buffers first by life cycle, then by single buffer size.
buffer_groups = defaultdict[LifeCycleId, defaultdict[int, list[BufferId]]](
lambda: defaultdict[int, list[BufferId]](list[BufferId])
)
life_cycle_registry = LifeCycleRegistry(config)
for layer in config.layers:
life_cycle = LifeCycle.make(
layer.window_size, layer.num_sink_tokens, config.tokens_per_block
)
life_cycle_id = life_cycle_registry.get_id(life_cycle)
size_to_buffers = buffer_groups[life_cycle_id]
for buffer in layer.buffers:
size_to_buffers[buffer.size].append(BufferId(layer.layer_id, buffer.role))
# Create one slot group for each life cycle.
# It's possible that buffers with different sizes form coalesced buffers with the same coalesced size.
# @TODO: add test for this case.
slot_groups: list[SlotConfig] = []
for life_cycle_id, size_to_buffers in buffer_groups.items():
assert len(set(len(buffer_ids) for buffer_ids in size_to_buffers.values())) == 1, (
"Not yet supported. While we can support this easily, we need to know whether the kernels "
"need to share page indices or not. We haven't seen such models, yet. So we leave this as a "
"future work."
)
size_to_coalesced_buffers = defaultdict[int, list[CoalescedBuffer]](list[CoalescedBuffer])
for size, buffer_ids in size_to_buffers.items():
coalesced_size = size * len(buffer_ids)
coalesced_buffers = size_to_coalesced_buffers[coalesced_size]
coalesced_buffers.append(
CoalescedBuffer(
life_cycle_id=life_cycle_id, single_buffer_size=size, buffer_ids=buffer_ids
)
)
slots = [
PageConfig(coalesced_buffers)
for coalesced_buffers in size_to_coalesced_buffers.values()
]
slots.sort(key=lambda p: p.slot_size, reverse=True)
slot_groups.append(SlotConfig(tuple(slots)))
# Merge slot groups with the same slot_size_list
pool_groups_by_slot_size_list = defaultdict[HomoTuple[int], list[SlotConfig]](list[SlotConfig])
for slot_group in slot_groups:
pool_groups_by_slot_size_list[slot_group.slot_size_list].append(slot_group)
pool_groups = [
PoolGroupConfig(tuple(slot_groups))
for slot_groups in pool_groups_by_slot_size_list.values()
]
return StorageConfig(cache_tiers=tuple(config.cache_tiers), pool_groups=tuple(pool_groups))
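# Illustrative worked example (hypothetical sizes, not part of this module): suppose two layers
# share one life cycle (same window size and sink tokens) and each layer has a K and a V buffer
# of 64 KiB. Grouping by life cycle and then by single-buffer size yields four BufferIds of size
# 64 KiB, which form a single CoalescedBuffer of coalesced size 4 * 64 KiB = 256 KiB. That
# produces one PageConfig with slot_size 256 KiB, one SlotConfig and, after merging groups with
# equal slot_size_list, a single PoolGroupConfig.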

View File

@ -0,0 +1,936 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import errno
import os
import sys
import tempfile
import warnings
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import ClassVar, NewType, final
if sys.version_info[:2] >= (3, 12):
from typing import override
else:
from typing_extensions import override
from .._common import (
BAD_FILE_DESCRIPTOR,
NDEBUG,
Address,
CacheTier,
DiskAddress,
FileDescriptor,
MemAddress,
)
from .._cuda_virt_mem import PooledPhysMemAllocator, VirtMem
from .._exceptions import LogicError, OutOfPagesError, ResourceBusyError
from .._utils import (
CachedCudaEvent,
DynamicBitset,
HomoTuple,
HostMem,
assert_critical,
div_up,
query_total_gpu_memory,
remove_if,
resize_file,
round_down,
round_up,
)
PoolGroupIndex = NewType("PoolGroupIndex", int)
PoolIndex = NewType("PoolIndex", int)
SlotId = NewType("SlotId", int)
class SlotPoolBase(abc.ABC):
_slot_size: int
@property
def slot_size(self) -> int:
return self._slot_size
@property
@abc.abstractmethod
def num_slots(self) -> int: ...
@property
def num_bytes(self) -> int:
return self.slot_size * self.num_slots
def __init__(self, slot_size: int) -> None:
self._slot_size = slot_size
@abc.abstractmethod
def destroy(self) -> None:
pass
@abc.abstractmethod
def resize(self, new_num_slots: int) -> None:
pass
@abc.abstractmethod
def slot_address(self, slot: int) -> Address:
pass
def __del__(self) -> None:
self.destroy()
@final
class GpuSlotPool(SlotPoolBase):
__slots__ = ("_vm",)
_vm: VirtMem
def __init__(
self,
slot_size: int,
vm_size: int,
shared_phys_mem_pool: PooledPhysMemAllocator,
num_slots: int,
):
super().__init__(slot_size)
assert vm_size % shared_phys_mem_pool.phys_mem_size == 0
self._vm = VirtMem(vm_size, shared_phys_mem_pool)
self.resize(num_slots)
@override
def destroy(self) -> None:
self._vm.destroy()
@override
def resize(self, new_num_slots: int) -> None:
new_num_phys_mem = self._compute_num_phys_mem(
self.slot_size, new_num_slots, self._vm.phys_mem_size
)
self._vm.realloc(self._vm.phys_mem_size * new_num_phys_mem)
def extend_by_one_phys_mem(self) -> int:
self._vm.extend(1)
return self.num_slots
@override
def slot_address(self, slot: int) -> MemAddress:
return MemAddress(int(self._vm.address) + self.slot_size * slot)
@property
@override
def num_slots(self) -> int:
return self._compute_num_slots(
self.slot_size, self._vm.num_phys_mem, self._vm.phys_mem_size
)
@staticmethod
def _compute_num_phys_mem(slot_size: int, num_slots: int, phys_mem_size: int) -> int:
return div_up(num_slots * slot_size, phys_mem_size)
@staticmethod
def _compute_num_slots(slot_size: int, num_phys_mem: int, phys_mem_size: int) -> int:
return num_phys_mem * phys_mem_size // slot_size
class HostSlotPool(SlotPoolBase):
__slots__ = ("_host_mem",)
_host_mem: HostMem
def __init__(self, slot_size: int, num_slots: int) -> None:
super().__init__(slot_size)
self._host_mem = HostMem(self.aligned_size(num_slots))
@override
def destroy(self) -> None:
self._host_mem.destroy()
@override
def resize(self, new_num_slots: int) -> None:
self._host_mem.resize(self.aligned_size(new_num_slots))
@override
def slot_address(self, slot: int) -> MemAddress:
return MemAddress(self._host_mem._address + self.slot_size * slot)
@property
@override
def num_slots(self) -> int:
return self._host_mem.size // self.slot_size
def aligned_size(self, num_slots: int) -> int:
return round_up(num_slots * self.slot_size, HostMem.ALIGNMENT)
class DiskSlotPool(SlotPoolBase):
__slots__ = ("_filename", "_fd")
# Currently only used to get the parent folder where we create temporary files.
# You won't find file with this name.
filename: str
_fd: FileDescriptor
def __init__(self, filename: str, slot_size: int, num_slots: int) -> None:
super().__init__(slot_size)
self.filename = filename
folder = os.path.dirname(filename)
assert os.path.isdir(folder), f"Folder {folder} does not exist"
try:
fd = os.open(folder, os.O_TMPFILE | os.O_RDWR | os.O_EXCL, 0o664)
except OSError as e:
if e.errno != errno.EOPNOTSUPP:
raise
# Fallback for filesystems/architectures not supporting O_TMPFILE
fd, path = tempfile.mkstemp(dir=folder)
try:
os.unlink(path)
except OSError:
os.close(fd)
raise
self._fd = FileDescriptor(fd)
self.resize(num_slots)
@override
def destroy(self) -> None:
if self.fd == BAD_FILE_DESCRIPTOR:
return
os.close(self.fd)
self._fd = BAD_FILE_DESCRIPTOR
@property
def fd(self) -> FileDescriptor:
return self._fd
@property
def file_size(self) -> int:
return os.lseek(self.fd, 0, os.SEEK_END)
@override
def resize(self, new_num_slots: int) -> None:
file_size = new_num_slots * self.slot_size
resize_file(self.fd, file_size)
@override
def slot_address(self, slot: int) -> DiskAddress:
assert slot < self.num_slots
return DiskAddress(self.fd, slot * self.slot_size)
@property
@override
def num_slots(self) -> int:
return self.file_size // self.slot_size
@dataclass(slots=True)
class Slot:
# ready_event indicates whether the slot is ready for use.
# For newly allocated BlockData, it marks the finish of the last usage by the previous owners of
# the slot (who returned the slot to the pool).
# After data migration, it marks the finish of the data migration.
# When passed to release(), it marks the finish of usage by the current owners of the slot.
_slot_id: SlotId | None
ready_event: CachedCudaEvent
@property
def slot_id(self) -> SlotId:
assert self._slot_id is not None
return self._slot_id
def query_ready(self) -> bool:
ev = self.ready_event
if ev is CachedCudaEvent.NULL:
return True
ret = ev.query_complete()
if ret:
self.ready_event = CachedCudaEvent.NULL
return ret
@property
def has_valid_slot(self) -> bool:
return self._slot_id is not None
def move_to_new_slot(self) -> "Slot":
ret = Slot(None, CachedCudaEvent.NULL)
ret.set_slot(self)
return ret
def set_slot(self, slot: "Slot") -> None:
if self.has_valid_slot:
raise LogicError("Slot is already set.")
self._slot_id = slot.slot_id
self.ready_event = slot.ready_event
slot._slot_id = None
slot.ready_event = CachedCudaEvent.NULL
def __del__(self) -> None:
if self.has_valid_slot:
warnings.warn("[KVCacheManager] slot is not freed before deletion")
class SlotAllocator:
__slots__ = (
"_capacity",
"_num_active_slots",
"_recycled_slots",
"_num_ready_recycled_slots",
"_occupied_mask",
"_target_capacity",
"_overflow_slots",
"_num_ready_overflow_slots",
)
_capacity: int
_num_active_slots: int # active slots are either in use or recycled.
_recycled_slots: deque[
Slot
] # only store recycled slots to avoid excessive memory usage on program start
_num_ready_recycled_slots: int # number of recycled slots that are ready to be used immediately
# (no need for sync or wait in stream), i.e. their ready events are triggered.
_occupied_mask: DynamicBitset
# for scheduled shrinking resize
_target_capacity: (
int # _target_capacity <= _capacity; they differ if a shrinking resize is in progress.
)
_overflow_slots: list[
Slot
] # slots that will be out of range after an in-progress resize; scheduled for removal.
_num_ready_overflow_slots: int # similar to _num_ready_recycled_slots, but for _overflow_slots.
def __init__(self, capacity: int) -> None:
self._capacity = capacity
self._num_active_slots = 0
self._recycled_slots = deque[Slot]()
self._num_ready_recycled_slots = 0
self._occupied_mask = DynamicBitset(capacity)
self._target_capacity = capacity
self._overflow_slots = []
self._num_ready_overflow_slots = 0
def __del__(self) -> None:
assert_critical(
self._num_ready_recycled_slots == len(self._recycled_slots)
and self._num_ready_overflow_slots == len(self._overflow_slots),
"did you call synchronize()?",
)
assert_critical(
self._target_capacity == self._capacity and not self._overflow_slots,
"resize is in progress",
)
assert_critical(self._occupied_mask.num_set_bits == 0, "some slots are still in use")
assert_critical(
len(self._recycled_slots) == self._num_active_slots, "some slots are not free"
)
@property
def num_free_slots(self) -> int:
return len(self._recycled_slots) + max(self._target_capacity - self._num_active_slots, 0)
@property
def num_occupied_slots(self) -> int:
return self._occupied_mask.num_set_bits
def allocate(self) -> Slot:
if self.num_free_slots == 0:
raise OutOfPagesError("No free slots")
self._scrub_events()
# preference: ready recycled slots > new slots > recycled slots that are not ready
if self._num_ready_recycled_slots > 0:
assert self._recycled_slots
slot = self._recycled_slots.popleft()
assert slot.has_valid_slot
self._num_ready_recycled_slots -= 1
assert slot.ready_event is CachedCudaEvent.NULL
elif self._num_active_slots < self.num_slots:
slot = Slot(SlotId(self._num_active_slots), CachedCudaEvent.NULL)
self._num_active_slots += 1
else:
slot = self._recycled_slots.popleft()
assert slot.has_valid_slot
self._occupied_mask.set(slot.slot_id)
return slot
# We don't simply call allocate() multiple times because the caller needs all-or-none semantics:
# if there are not enough free slots, we would have to free the newly allocated slots by appending
# them to the back of the recycled slot queue, which may hurt performance.
# See the illustrative sketch after this method.
def allocate_multiple(self, num_slots: int) -> list[Slot]:
if self.num_free_slots < num_slots:
raise OutOfPagesError("Not enough free slots")
return [self.allocate() for _ in range(num_slots)]
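# Illustrative sketch (hypothetical standalone usage; assumes the CUDA utilities imported by
# this module are available):
#
#     alloc = SlotAllocator(capacity=4)
#     slots = alloc.allocate_multiple(3) # succeeds: 4 free slots are available
#     alloc.allocate_multiple(2)         # raises OutOfPagesError: only 1 free slot remains
#     for s in slots:
#         alloc.release(s)               # released slots are recycled for later allocations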
def release(self, slot: Slot) -> None:
assert slot.has_valid_slot
slot = slot.move_to_new_slot()
if slot.slot_id >= self._capacity or not self._occupied_mask.get(slot.slot_id):
raise LogicError(f"Slot {slot.slot_id} is not occupied")
assert type(slot) is Slot and slot.has_valid_slot
if slot.slot_id < self._target_capacity:
self._recycled_slots.append(slot)
else:
self._overflow_slots.append(slot)
self._try_trigger_shrink()
self._occupied_mask.clear(slot.slot_id)
self._scrub_events()
assert NDEBUG or self._check()
@property
def num_slots(self) -> int:
return self._capacity
def resize(self, new_num_slots: int) -> None:
if self._target_capacity != self._capacity:
self.cancel_scheduled_resize()
assert NDEBUG or self._check()
old_num_slots = self.num_slots
if new_num_slots < self.num_slots and self._occupied_mask.any_set(
new_num_slots, self.num_slots
):
raise ResourceBusyError("resize cannot remove occupied slots")
self._occupied_mask.resize(new_num_slots)
self._capacity = new_num_slots
self._num_active_slots = min(self._num_active_slots, new_num_slots)
if new_num_slots < old_num_slots:
new_recycled_slots = deque[Slot]()
new_num_ready_recycled_slots = 0
for idx_recycled, slot in enumerate(self._recycled_slots):
assert type(slot) is Slot and slot.has_valid_slot
if slot.slot_id >= new_num_slots:
slot.ready_event.synchronize()
slot._slot_id = None
slot.ready_event = CachedCudaEvent.NULL
else:
new_recycled_slots.append(slot)
if idx_recycled < self._num_ready_recycled_slots:
new_num_ready_recycled_slots += 1
self._recycled_slots = new_recycled_slots
self._num_ready_recycled_slots = new_num_ready_recycled_slots
self._scrub_events()
self._target_capacity = self._capacity
assert NDEBUG or self._check()
def schedule_resize(self, new_num_slots: int) -> None:
assert NDEBUG or self._check()
if new_num_slots >= self.num_slots:
self.cancel_scheduled_resize()
self.resize(new_num_slots)
return
old_target_capacity = self._target_capacity
if new_num_slots > old_target_capacity:
self._recycled_slots.extend(
remove_if(
self._overflow_slots,
lambda slot: old_target_capacity <= slot.slot_id < new_num_slots,
)
)
self._num_ready_overflow_slots = 0
if new_num_slots < old_target_capacity:
self._overflow_slots.extend(
remove_if(self._recycled_slots, lambda slot: slot.slot_id >= new_num_slots)
)
self._num_ready_recycled_slots = 0
self._target_capacity = new_num_slots
self._try_trigger_shrink()
self._scrub_events()
assert NDEBUG or self._check()
def cancel_scheduled_resize(self) -> None:
assert NDEBUG or self._check()
self._target_capacity = self._capacity
self._recycled_slots.extend(remove_if(self._overflow_slots, lambda slot: True))
self._num_ready_overflow_slots = 0
def shrink_in_progress(self) -> bool:
"Indicates if a scheduled shrink is in progress."
assert self._target_capacity <= self._capacity
return self._target_capacity < self._capacity
def get_slots_blocking_shrink(self) -> HomoTuple[SlotId]:
return tuple(
SlotId(id)
for id in range(self._target_capacity, self._capacity)
if self._occupied_mask.get(id)
)
def _try_trigger_shrink(self) -> bool:
assert NDEBUG or self._check()
if (
self.shrink_in_progress()
and self._target_capacity + len(self._overflow_slots) == self._capacity
):
assert len(set(s.slot_id for s in self._overflow_slots)) == len(self._overflow_slots)
for slot in self._overflow_slots:
slot.ready_event.synchronize()
slot.ready_event = CachedCudaEvent.NULL
slot._slot_id = None # the slot is intentionally discarded by the shrink; mirrors resize()
self._overflow_slots.clear()
self._num_ready_overflow_slots = 0
self._capacity = self._target_capacity
self._num_active_slots = min(self._num_active_slots, self._capacity)
self._scrub_events()
assert NDEBUG or self._check()
return True
return False
def _scrub_events(self) -> None:
self._num_ready_recycled_slots = self._scrub_events_impl(
self._recycled_slots, self._num_ready_recycled_slots
)
self._num_ready_overflow_slots = self._scrub_events_impl(
self._overflow_slots, self._num_ready_overflow_slots
)
def _check(self) -> bool:
return (
self._num_active_slots <= self._capacity
and self._target_capacity <= self._capacity
and (self.shrink_in_progress() or len(self._overflow_slots) == 0)
and all(
self._target_capacity <= slot.slot_id < self._capacity
for slot in self._overflow_slots
)
and len(self._recycled_slots) + len(self._overflow_slots) + self.num_occupied_slots
== self._num_active_slots
)
@staticmethod
def _scrub_events_impl(slots: Sequence[Slot], num_ready: int) -> int:
assert num_ready <= len(slots)
for i in range(num_ready, len(slots)):
slot = slots[i]
if slot.ready_event.query_complete():
slot.ready_event = CachedCudaEvent.NULL
num_ready += 1
else:
break
return num_ready
def _synchronize(self) -> None:
"synchronize the events of all unused slots"
while self._num_ready_recycled_slots != len(
self._recycled_slots
) or self._num_ready_overflow_slots != len(self._overflow_slots):
self._scrub_events()
class PoolGroupBase:
__slots__ = ("_slot_allocator", "_pools")
_slot_allocator: SlotAllocator
_pools: HomoTuple[SlotPoolBase]
def __init__(self, num_slots: int) -> None:
self._slot_allocator = SlotAllocator(num_slots)
def __del__(self) -> None:
self.destroy()
def destroy(self) -> None:
if self._slot_allocator._capacity == 0:
return
self._slot_allocator._synchronize()
for pool in self._pools:
pool.destroy()
self._slot_allocator.resize(0)
@property
def num_pools(self) -> PoolIndex:
return PoolIndex(len(self._pools))
@property
def num_slots(self) -> int:
num_slots = self._slot_allocator._capacity
assert num_slots <= self._get_num_slots_from_pools()
return num_slots
@property
def num_free_slots(self) -> int:
return self._slot_allocator.num_free_slots
@property
def num_bytes(self) -> int:
return sum(pool.num_bytes for pool in self._pools)
def resize_slot_allocator(self, new_num_slots: int | None) -> None:
"""
Resize the slot allocator, but not the pools. If new_num_slots is None, make the slot allocator match the pool sizes.
"""
if new_num_slots is None:
new_num_slots = self._get_num_slots_from_pools()
self._slot_allocator.resize(new_num_slots)
assert NDEBUG or self._check(True)
def resize_pools(self, new_num_slots: int | None) -> None:
"""
Resize the pools, but not the slot allocator. If new_num_slots is None, make pool sizes match
the slot allocator.
If an exception is raised, the pool sizes may be imbalanced. Call resize_pools() again with None or
self._get_num_slots_from_pools() to fix it.
"""
if new_num_slots is None:
new_num_slots = self._slot_allocator.num_slots
for pool in self._pools:
pool.resize(new_num_slots)
assert NDEBUG or self._check(True)
def allocate(self) -> Slot:
return self._slot_allocator.allocate()
def allocate_multiple(self, num_slots: int) -> list[Slot]:
return self._slot_allocator.allocate_multiple(num_slots)
def release(self, slot: Slot) -> None:
self._slot_allocator.release(slot)
def slot_address(self, slot_id: SlotId) -> HomoTuple[Address]:
return tuple(pool.slot_address(slot_id) for pool in self._pools)
@property
def slot_size(self) -> HomoTuple[int]:
return tuple(pool.slot_size for pool in self._pools)
def _check(self, allow_mismatch: bool = False) -> bool:
pool_num_slots = self._get_num_slots_from_pools()
return (
self._slot_allocator.num_slots <= pool_num_slots
if allow_mismatch
else self._slot_allocator.num_slots == pool_num_slots
)
def _get_num_slots_from_pools(self) -> int:
return min(p.num_slots for p in self._pools)
@staticmethod
def _compute_num_phys_mem(
slot_size_list: Sequence[int], num_slots: int, phys_mem_size: int
) -> HomoTuple[int]:
return tuple(
GpuSlotPool._compute_num_phys_mem(slot_size, num_slots, phys_mem_size)
for slot_size in slot_size_list
)
class GpuPoolGroup(PoolGroupBase):
__slots__ = ()
def __init__(
self,
num_slots: int,
slot_size_list: Sequence[int],
shared_phys_mem_pool: PooledPhysMemAllocator,
):
super().__init__(num_slots)
total_gpu_memory = query_total_gpu_memory()
max_slot_size = max(slot_size_list)
phys_mem_size = shared_phys_mem_pool.phys_mem_size
self._pools = tuple(
GpuSlotPool(
slot_size,
round_down(int(total_gpu_memory * slot_size / max_slot_size), phys_mem_size),
shared_phys_mem_pool,
num_slots,
)
for slot_size in slot_size_list
)
class HostPoolGroup(PoolGroupBase):
__slots__ = ()
def __init__(self, num_slots: int, slot_size_list: Sequence[int]):
super().__init__(num_slots)
self._pools = tuple(HostSlotPool(slot_size, num_slots) for slot_size in slot_size_list)
class DiskPoolGroup(PoolGroupBase):
__slots__ = ()
def __init__(self, num_slots: int, slot_size_list: Sequence[int], filename_template: str):
super().__init__(num_slots)
self._pools = tuple(
DiskSlotPool(filename_template.format(i), slot_size, num_slots)
for i, slot_size in enumerate(slot_size_list)
)
class CacheLevelStorage:
TIER: ClassVar[CacheTier]
__slots__ = ("_total_quota", "_ratio_list", "_pool_groups")
_total_quota: int # fixme: remove _total_quota and _ratio_list and compute from _pool_groups
_ratio_list: HomoTuple[float]
_pool_groups: HomoTuple[PoolGroupBase]
def __init__(self, total_quota: int, ratio_list: Sequence[float]) -> None:
if not hasattr(self.__class__, "TIER"):
raise ValueError(f"{self.__class__.__name__} must define 'TIER' as a class variable")
self._total_quota = total_quota
self._ratio_list = tuple(ratio_list)
def __del__(self) -> None:
self.destroy()
@property
def cache_tier(self) -> CacheTier:
return self.TIER
def destroy(self) -> None:
if self._total_quota == 0:
return
for pg in self._pool_groups:
pg.destroy()
self._total_quota = 0
self._ratio_list = ()
def allocate(self, pool_group_index: PoolGroupIndex) -> Slot:
return self._pool_groups[pool_group_index].allocate()
def allocate_multiple(self, pool_group_index: PoolGroupIndex, num_slots: int) -> list[Slot]:
return self._pool_groups[pool_group_index].allocate_multiple(num_slots)
def release(self, pool_group_index: PoolGroupIndex, slot: Slot) -> None:
self._pool_groups[pool_group_index].release(slot)
@property
def total_quota(self) -> int:
return self._total_quota
@property
def ratio_list(self) -> HomoTuple[float]:
return self._ratio_list
def num_slots(self, pool_group_index: PoolGroupIndex) -> int:
return self._pool_groups[pool_group_index].num_slots
def get_num_free_slots(self, pool_group_index: PoolGroupIndex) -> int:
return self._pool_groups[pool_group_index].num_free_slots
@property
def slot_count_list(self) -> HomoTuple[int]:
"""
The number of slots in each pool group.
"""
return tuple(pg.num_slots for pg in self._pool_groups)
def slot_size(self, pool_group_index: PoolGroupIndex) -> HomoTuple[int]:
"""
The slot sizes of each pool in the pool group.
"""
return self._pool_groups[pool_group_index].slot_size
@property
def slot_size_lists(self) -> HomoTuple[HomoTuple[int]]:
"""
A tuple of tuples, each containing the slot sizes for a pool group.
"""
return tuple(tuple(p.slot_size for p in pg._pools) for pg in self._pool_groups)
@property
def num_pool_groups(self) -> PoolGroupIndex:
return PoolGroupIndex(len(self._pool_groups))
def slot_address(
self, pool_group_index: PoolGroupIndex, pool_index: PoolIndex, slot_id: SlotId
) -> Address:
return self._pool(pool_group_index, pool_index).slot_address(slot_id)
def resize(
self, new_total_quota: int | None = None, new_ratio_list: Sequence[float] | None = None
) -> None:
new_slot_count_list = self._compute_slot_count_list(new_total_quota, new_ratio_list)
self._resize_impl(new_slot_count_list)
if new_total_quota is not None:
self._total_quota = new_total_quota
if new_ratio_list is not None:
self._ratio_list = tuple(new_ratio_list)
def _resize_impl(self, new_slot_count_list: Sequence[int]) -> None:
old_slot_count_list = self.slot_count_list
assert old_slot_count_list == self._compute_slot_count_list(
self.total_quota, self.ratio_list
)
try:
# shrink first to avoid intermediate state with excessive memory usage
for pg, new_slot_count, old_slot_count in zip(
self._pool_groups, new_slot_count_list, old_slot_count_list
):
if new_slot_count < old_slot_count:
pg.resize_slot_allocator(
new_slot_count
) # shrink the slot allocator first, since shrinking can fail (e.g. on occupied slots)
pg.resize_pools(new_slot_count)
for pg, new_slot_count, old_slot_count in zip(
self._pool_groups, new_slot_count_list, old_slot_count_list
):
if new_slot_count > old_slot_count:
pg.resize_pools(
new_slot_count
) # expand the pools first, since expanding can fail (e.g. on allocation failure)
pg.resize_slot_allocator(new_slot_count)
except Exception:
self._resize_impl(old_slot_count_list)
raise
def _pool(self, pool_group_index: PoolGroupIndex, pool_index: PoolIndex) -> SlotPoolBase:
return self._pool_groups[pool_group_index]._pools[pool_index]
# Calculate how many slots there will be in each pool group with the given total_quota and
# ratio_list. Use _ratio_to_slot_count_list for initialization.
def _compute_slot_count_list(
self, total_quota: int | None = None, ratio_list: Sequence[float] | None = None
) -> HomoTuple[int]:
if total_quota is None:
total_quota = self.total_quota
if ratio_list is None:
ratio_list = self.ratio_list
assert len(ratio_list) == len(self._pool_groups), (
f"Wrong ratio_list length. Expected {len(self._pool_groups)}, got {len(ratio_list)}"
)
return self._ratio_to_slot_count_list(
total_quota, self.slot_size_lists, ratio_list, self.pool_size_granularity
)
@staticmethod
def _ratio_to_slot_count_list(
total_quota: int,
slot_size_lists: Sequence[Sequence[int]],
ratio_list: Sequence[float],
pool_size_granularity: int,
) -> HomoTuple[int]:
num_pool_groups = len(ratio_list)
assert num_pool_groups == len(slot_size_lists)
assert total_quota % pool_size_granularity == 0
total_grains = total_quota // pool_size_granularity
assert total_grains >= sum(len(sizes) for sizes in slot_size_lists)
remaining_grains = total_grains
granularity = pool_size_granularity
slot_cnt_list = [0] * num_pool_groups
# Divide total_quota among the pool groups based on ratio_list, then divide the quota of each
# pool group among its pools based on slot_size. See the worked example after this method.
pg_idx_lst = sorted(range(len(ratio_list)), key=lambda i: ratio_list[i])
for i, pg in enumerate(pg_idx_lst):
slot_size_list = slot_size_lists[pg]
min_pool_grains = [div_up(s, granularity) for s in slot_size_list]
pct: float = ratio_list[pg] / sum(ratio_list[j] for j in pg_idx_lst[i:])
pg_grains = max(round(remaining_grains * pct), sum(min_pool_grains))
num_slots: int = 1 << 63
remaining_pg_grains = pg_grains
pool_idx_lst = sorted(range(len(slot_size_list)), key=lambda i: slot_size_list[i])
for j, pool in enumerate(pool_idx_lst):
slot_size = slot_size_list[pool]
pool_grains = max(
min_pool_grains[pool],
round(
remaining_pg_grains
* (slot_size / sum(slot_size_list[k] for k in pool_idx_lst[j:]))
),
)
num_slots = min(num_slots, pool_grains * granularity // slot_size)
remaining_pg_grains -= pool_grains
assert remaining_pg_grains == 0
assert num_slots > 0
slot_cnt_list[pg] = num_slots
remaining_grains -= pg_grains
assert remaining_grains == 0
return tuple(slot_cnt_list)
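# Illustrative worked example (hypothetical numbers): with total_quota=64,
# slot_size_lists=[[8], [4]], ratio_list=[0.75, 0.25] and pool_size_granularity=4, the quota is
# split into 16 grains. The group with the smaller ratio is handled first and gets 4 grains
# (16 bytes -> 4 slots of size 4); the other group gets the remaining 12 grains
# (48 bytes -> 6 slots of size 8), so the result is (6, 4):
#
#     CacheLevelStorage._ratio_to_slot_count_list(64, [[8], [4]], [0.75, 0.25], 4) # -> (6, 4)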
@property
def pool_size_granularity(self) -> int:
return 2 << 20
class GpuCacheLevelStorage(CacheLevelStorage):
TIER: ClassVar[CacheTier] = CacheTier.GPU_MEM
__slots__ = ("shared_phys_mem_pool",)
shared_phys_mem_pool: PooledPhysMemAllocator
def __init__(
self,
total_quota: int,
slot_size_lists: Sequence[Sequence[int]],
init_ratio: Sequence[float],
phys_mem_size: int,
):
assert len(slot_size_lists) == len(init_ratio), (
"slot_size_lists and init_ratio must have the same length"
)
super().__init__(total_quota, init_ratio)
slot_count_list = self._ratio_to_slot_count_list(
total_quota, slot_size_lists, init_ratio, phys_mem_size
)
self.shared_phys_mem_pool = PooledPhysMemAllocator(phys_mem_size)
self._pool_groups = tuple(
GpuPoolGroup(num_slots, slot_size_list, self.shared_phys_mem_pool)
for slot_size_list, num_slots in zip(slot_size_lists, slot_count_list)
)
@override
def resize(
self, new_total_quota: int | None = None, new_ratio_list: Sequence[float] | None = None
):
super().resize(new_total_quota, new_ratio_list)
self.shared_phys_mem_pool.clear() # clear cached unused phys mem
@property
def pool_size_granularity(self) -> int:
return self.shared_phys_mem_pool.phys_mem_size
class HostCacheLevelStorage(CacheLevelStorage):
TIER: ClassVar[CacheTier] = CacheTier.HOST_MEM
POOL_SIZE_GRANULARITY: ClassVar[int] = HostMem.ALIGNMENT
__slots__ = ()
def __init__(
self,
total_quota: int,
slot_size_lists: Sequence[Sequence[int]],
init_ratio: Sequence[float],
):
super().__init__(total_quota, init_ratio)
slot_count_list = self._ratio_to_slot_count_list(
total_quota, slot_size_lists, init_ratio, self.pool_size_granularity
)
self._pool_groups = tuple(
HostPoolGroup(num_slots, slot_size_list)
for slot_size_list, num_slots in zip(slot_size_lists, slot_count_list)
)
@property
def pool_size_granularity(self) -> int:
return self.POOL_SIZE_GRANULARITY
class DiskCacheLevelStorage(CacheLevelStorage):
__slots__ = ()
TIER: ClassVar[CacheTier] = CacheTier.DISK
POOL_SIZE_GRANULARITY: ClassVar[int] = 2 << 20
def __init__(
self,
total_quota: int,
slot_size_lists: Sequence[Sequence[int]],
init_ratio: Sequence[float],
filename_template: str,
):
super().__init__(total_quota, init_ratio)
slot_count_list = self._ratio_to_slot_count_list(
total_quota, slot_size_lists, init_ratio, self.pool_size_granularity
)
self._pool_groups = tuple(
DiskPoolGroup(num_slots, slot_size_list, filename_template.format(pg_idx, "{}"))
for pg_idx, (slot_size_list, num_slots) in enumerate(
zip(slot_size_lists, slot_count_list)
)
)
@property
def pool_size_granularity(self) -> int:
return self.POOL_SIZE_GRANULARITY

View File

@ -0,0 +1,518 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import warnings
from dataclasses import dataclass
from typing import Iterator, Sequence, cast
from . import rawref
from ._common import (
GPU_LEVEL,
NDEBUG,
Address,
CacheLevel,
CacheTier,
LayerId,
MemAddress,
PageIndex,
PageStatus,
)
from ._config import CacheTierConfig, DataRole, DiskCacheTierConfig
from ._copy_engine import CopyTask, batched_copy
from ._eviction_controller import EvictablePage, PerLevelEvictionController
from ._exceptions import OutOfPagesError
from ._life_cycle_registry import LifeCycleId, LifeCycleRegistry
from ._page import Page
from ._storage import CacheLevelStorage
from ._storage._config import BufferAttr, BufferId, StorageConfig
from ._storage._core import (
DiskCacheLevelStorage,
GpuCacheLevelStorage,
HostCacheLevelStorage,
PoolGroupBase,
PoolGroupIndex,
PoolIndex,
Slot,
SlotId,
)
from ._utils import (
Array2D,
CachedCudaEvent,
HomoTuple,
TemporaryCudaStream,
TypedIndexList,
filled_array2d,
filled_list,
get_uniform_attribute,
make_typed,
map_optional,
partition,
remove_if,
round_up,
typed_enumerate,
typed_range,
)
class CacheLevelManager:
__slots__ = ("cache_level", "storage", "controller")
cache_level: CacheLevel
storage: CacheLevelStorage
controller: PerLevelEvictionController
@property
def cache_tier(self) -> CacheTier:
return self.storage.cache_tier
def __init__(
self,
life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex],
cache_level: CacheLevel,
config: CacheTierConfig,
slot_size_lists: Sequence[Sequence[int]],
init_ratio: Sequence[float],
):
self.cache_level = cache_level
self.storage = self._create_cache_level_storage(config, slot_size_lists, init_ratio)
self.controller = PerLevelEvictionController(life_cycle_grouping, cache_level)
@property
def num_pool_groups(self) -> PoolGroupIndex:
assert self.storage.num_pool_groups == self.controller.num_pool_groups
return self.storage.num_pool_groups
@staticmethod
def _create_cache_level_storage(
config: CacheTierConfig,
slot_size_lists: Sequence[Sequence[int]],
init_ratio: Sequence[float],
) -> CacheLevelStorage:
quota = config.quota
num_pools = sum(len(sizes) for sizes in slot_size_lists)
def adjust_quota(quota: int, granularity: int) -> int:
return max(granularity * num_pools, round_up(quota, granularity))
if config.tier == CacheTier.GPU_MEM:
page_size = 2 << 20
phys_mem_size = page_size << min(4, max(0, int(math.log(quota / (page_size * 512), 2))))
quota = adjust_quota(quota, phys_mem_size)
return GpuCacheLevelStorage(quota, slot_size_lists, init_ratio, phys_mem_size)
elif config.tier == CacheTier.HOST_MEM:
quota = adjust_quota(quota, HostCacheLevelStorage.POOL_SIZE_GRANULARITY)
return HostCacheLevelStorage(quota, slot_size_lists, init_ratio)
elif config.tier == CacheTier.DISK:
assert isinstance(config, DiskCacheTierConfig)
assert os.path.isdir(config.path), (
f"Disk path {config.path} does not exist or is not a directory"
)
quota = adjust_quota(quota, DiskCacheLevelStorage.POOL_SIZE_GRANULARITY)
filename_template = os.path.join(config.path, "g{}p{}.bin")
return DiskCacheLevelStorage(quota, slot_size_lists, init_ratio, filename_template)
else:
raise ValueError(f"Invalid cache tier: {config.tier}")
@dataclass(slots=True, frozen=True)
class StorageStatistics:
"All in number of slots, for one pool group"
slot_size: HomoTuple[int]
total: int
free: int
evictable: int
@property
def available(self) -> int:
return self.free + self.evictable
@property
def unavailable(self) -> int:
return self.total - self.available
class StorageManager:
__slots__ = (
"_life_cycles",
"_layer_to_life_cycle_ids",
"_slot_to_page_indices",
"_buffer_attr",
"_life_cycle_grouping",
"_levels",
"_cached_num_pool_groups",
"__rawref__",
)
_life_cycles: LifeCycleRegistry
_layer_to_life_cycle_ids: TypedIndexList[LayerId, LifeCycleId]
_slot_to_page_indices: TypedIndexList[LifeCycleId, int]
_buffer_attr: dict[BufferId, BufferAttr]
_life_cycle_grouping: TypedIndexList[LifeCycleId, PoolGroupIndex]
_levels: TypedIndexList[CacheLevel, CacheLevelManager]
_cached_num_pool_groups: PoolGroupIndex
__rawref__: rawref.ref["StorageManager"]
def __init__(self, life_cycles: LifeCycleRegistry, config: StorageConfig) -> None:
self.__rawref__ = rawref.NULL
assert config.cache_tiers[GPU_LEVEL].tier == CacheTier.GPU_MEM, (
"The first cache tier must be GPU memory"
)
self._life_cycles = life_cycles
self._layer_to_life_cycle_ids = config.layer_to_life_cycle_ids()
self._slot_to_page_indices = config.slot_to_page_indices()
self._buffer_attr = config.buffer_attributes()
self._life_cycle_grouping = config.life_cycle_grouping()
slot_size_lists = [pg.slot_size_list for pg in config.pool_groups]
# @TODO: accept an optional avg_seq_len param and consider sliding window.
init_ratio = [float(sum(pg.slot_size_list) * len(pg.slots)) for pg in config.pool_groups]
total = sum(init_ratio)
init_ratio = [x / total for x in init_ratio]
num_levels = CacheLevel(len(config.cache_tiers))
self._levels = cast(
TypedIndexList,
[
CacheLevelManager(
self._life_cycle_grouping, i, config.cache_tiers[i], slot_size_lists, init_ratio
)
for i in typed_range(num_levels)
],
)
self._cached_num_pool_groups = get_uniform_attribute(
self._levels, lambda level: level.storage.num_pool_groups
)
def __del__(self) -> None:
self.__rawref__.invalidate()
def get_pool_group_index(self, life_cycle: LifeCycleId) -> PoolGroupIndex:
return self._life_cycle_grouping[life_cycle]
def new_gpu_slots(
self, num_slots: TypedIndexList[LifeCycleId, int]
) -> TypedIndexList[LifeCycleId, list[Slot]]:
return self.new_slots(GPU_LEVEL, num_slots)
def new_slots(
self, level: CacheLevel, num_slots: TypedIndexList[LifeCycleId, int]
) -> TypedIndexList[LifeCycleId, list[Slot]]:
pg_num_slots = filled_list(0, self.num_pool_groups)
for lc in typed_range(self.num_life_cycles):
pg_num_slots[self.get_pool_group_index(lc)] += num_slots[lc]
storage = self._levels[level].storage
if any(
pg_num_slots[pg] > storage.get_num_free_slots(pg)
for pg in typed_range(self.num_pool_groups)
):
self.prepare_free_slots(level, pg_num_slots)
assert all(
pg_num_slots[pg] <= storage.get_num_free_slots(pg)
for pg in typed_range(self.num_pool_groups)
)
ret = filled_list(list[Slot](), self.num_life_cycles)
try:
for life_cycle in typed_range(self.num_life_cycles):
pg_idx = self.get_pool_group_index(life_cycle)
ret[life_cycle] = storage.allocate_multiple(pg_idx, num_slots[life_cycle])
except Exception:
warnings.warn("Exception not expected here. Please report a bug.")
for lc, slots in typed_enumerate(ret):
pg_idx = self.get_pool_group_index(lc)
for s in slots:
storage.release(pg_idx, s)
raise
return ret
@property
def life_cycles(self) -> LifeCycleRegistry:
return self._life_cycles
@property
def num_life_cycles(self) -> LifeCycleId:
return LifeCycleId(len(self._life_cycle_grouping))
@property
def num_pool_groups(self) -> PoolGroupIndex:
return self._cached_num_pool_groups
@property
def num_cache_levels(self) -> CacheLevel:
return CacheLevel(len(self._levels))
def is_last_level(self, level: CacheLevel) -> bool:
return level == self.num_cache_levels - 1
@property
def cache_tiers(self) -> HomoTuple[CacheTier]:
return tuple(cache_level.cache_tier for cache_level in self._levels)
def is_evictable(self, page: EvictablePage, level: CacheLevel | None = None) -> bool:
"""
Check if a page is evictable. If level is specified, check if the page will be evictable after
migrating to the given level.
"""
status = page.status
level = page.cache_level if level is None else level
# droppable pages that are not committed should be dropped immediately.
# held pages in last level cache can't be evicted.
return (status == PageStatus.DROPPABLE and page.is_committed()) or (
status == PageStatus.HELD and level < self.num_cache_levels - 1
)
def prepare_free_slots(
self, level: CacheLevel, requirements: TypedIndexList[PoolGroupIndex, int]
) -> None:
goals = filled_array2d(self.num_cache_levels, self.num_pool_groups, 0)
for pg in typed_range(self.num_pool_groups):
goals[level, pg] = requirements[pg]
fallen_pages = make_typed(lambda: list[Page](), self.num_pool_groups)
self._prepare_free_slots(goals, level, fallen_pages)
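# Illustrative sketch (hypothetical numbers): with three cache levels (GPU -> host -> disk),
# requesting 8 free GPU slots for a pool group may evict up to 8 evictable GPU pages. Those
# pages "fall" to the host level, whose own shortfall is resolved recursively by evicting host
# pages to disk. At the last level, evicted droppable pages are simply dropped and fallen held
# pages must be accepted; if that is impossible, OutOfPagesError is raised.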
def _prepare_free_slots(
self,
goals: Array2D[CacheLevel, PoolGroupIndex, int],
lvl_id: CacheLevel,
fallen_pages: TypedIndexList[PoolGroupIndex, list[Page]],
) -> None:
assert NDEBUG or goals.rows == self.num_cache_levels and goals.cols == self.num_pool_groups
assert NDEBUG or all(
all(p.cache_level < lvl_id for p in pages) for pages in fallen_pages
), "Fallen pages must come from upper cache levels"
storage = self._levels[lvl_id].storage
ctrl = self._levels[lvl_id].controller
num_to_evict = filled_list(0, self.num_pool_groups)
held_pages = make_typed(lambda: list[Page](), self.num_pool_groups)
for pg_idx in typed_range(self.num_pool_groups):
goal = goals[lvl_id, pg_idx]
fallen = len(fallen_pages[pg_idx])
old_free_cnt = storage.get_num_free_slots(pg_idx)
evictable_cnt = ctrl.num_evictable_pages(pg_idx)
num_to_evict[pg_idx] = max(0, min(goal + fallen - old_free_cnt, evictable_cnt))
fallen_held_cnt = 0 # fallen held pages we must accept in the current level.
if self.is_last_level(lvl_id):
held_pages[pg_idx] = remove_if(
fallen_pages[pg_idx], lambda p: p.status == PageStatus.HELD
)
fallen_held_cnt = len(held_pages[pg_idx])
if fallen_held_cnt > old_free_cnt + evictable_cnt:
# Do we need to revert the eviction we did before? Maybe not.
raise OutOfPagesError(
"Too many held pages are being evicted to the last-level cache for group {pg_idx}"
)
if old_free_cnt + evictable_cnt - fallen_held_cnt < goal:
raise OutOfPagesError(
"Impossible to meet the goal ({goal} free slots) for group {pg_idx}"
)
evicted = ctrl.evict(num_to_evict)
accepted_pages = make_typed(lambda: list[Page](), self.num_pool_groups)
is_last_level = self.is_last_level(lvl_id)
if is_last_level:
for pg_idx in typed_range(self.num_pool_groups):
old_free_cnt = storage.get_num_free_slots(pg_idx)
num_evicted = len(evicted[pg_idx])
assert NDEBUG or all(p.status == PageStatus.DROPPABLE for p in evicted[pg_idx])
if not NDEBUG:
dbg_rawrefs = [rawref.ref(p) for p in evicted[pg_idx]]
evicted[pg_idx].clear()
if not NDEBUG:
assert all(p() is None for p in dbg_rawrefs) # pyright: ignore
new_free_cnt = storage.get_num_free_slots(pg_idx)
# GC of some pages may trigger removal of radix tree blocks and some other pages.
assert new_free_cnt >= num_evicted + old_free_cnt
assert len(held_pages[pg_idx]) <= new_free_cnt
fallen_pages[pg_idx].extend(held_pages[pg_idx])
held_pages[pg_idx].clear()
goal = goals[lvl_id, pg_idx]
num_accepted = min(new_free_cnt - goal, len(fallen_pages[pg_idx]))
assert num_accepted >= 0
accepted_pages[pg_idx] = (
fallen_pages[pg_idx][-num_accepted:] if num_accepted > 0 else []
)
fallen_pages[pg_idx].clear()
else:
assert all(len(g) == 0 for g in held_pages)
for pg_idx in typed_range(self.num_pool_groups):
old_free_cnt = storage.get_num_free_slots(pg_idx)
e = evicted[pg_idx]
num_evicted = len(e)
fallen_pages[pg_idx][:0] = cast(list[Page], e)
e.clear()
num_accepted = min(
old_free_cnt + num_evicted - goals[lvl_id, pg_idx], len(fallen_pages[pg_idx])
)
assert num_accepted >= 0
if num_accepted > 0:
accepted_pages[pg_idx] = fallen_pages[pg_idx][-num_accepted:]
del fallen_pages[pg_idx][-num_accepted:]
self._prepare_free_slots(goals, CacheLevel(lvl_id + 1), fallen_pages)
assert all(len(f) == 0 for f in fallen_pages)
# migrate pages
for pg_idx in typed_range(self.num_pool_groups):
partitioned = partition(
accepted_pages[pg_idx],
lambda p: (p.cache_level, self.get_pool_group_index(p.life_cycle)),
)
accepted_pages[pg_idx].clear()
for (src_lvl, pg_idx), pages in partitioned.items():
dst_lvl = lvl_id
self._batched_migrate(pg_idx, dst_lvl, src_lvl, pages, update_src=True)
for p in pages:
if is_last_level and p.status == PageStatus.HELD:
continue
self._levels[dst_lvl].controller.schedule_for_eviction(p)
return
def _batched_migrate(
self,
pool_group_index: PoolGroupIndex,
dst_level: CacheLevel,
src_level: CacheLevel,
src_pages: Sequence[Page],
update_src: bool,
) -> Sequence[Slot] | None:
"Free slots must be prepared before calling this function."
assert dst_level != src_level, "dst_level and src_level must be different"
num_slots = len(src_pages)
num_pools = self.num_pools(pool_group_index)
src_pool_group = self._pool_group(src_level, pool_group_index)
dst_pool_group = self._pool_group(dst_level, pool_group_index)
if dst_pool_group.num_free_slots < num_slots:
raise OutOfPagesError("Not enough free slots")
dst_slots = dst_pool_group.allocate_multiple(num_slots)
try:
assert len(dst_slots) == num_slots
prior_events: set[CachedCudaEvent] = set()
tasks_per_pool: list[list[CopyTask]] = [[] for _ in range(num_pools)] # one independent task list per pool
for src, dst in zip(src_pages, dst_slots):
assert src.node_ref is None
prior_events.update((dst.ready_event, src.ready_event))
dst_addresses = dst_pool_group.slot_address(dst.slot_id)
src_addresses = src_pool_group.slot_address(src.slot_id)
for pool_idx in range(num_pools):
tasks_per_pool[pool_idx].append(
CopyTask(dst_addresses[pool_idx], src_addresses[pool_idx])
)
dst_tier = self._levels[dst_level].cache_tier
src_tier = self._levels[src_level].cache_tier
with TemporaryCudaStream(prior_events) as stream:
slot_sizes = self.slot_size(pool_group_index)
for pool_idx, tasks in enumerate(tasks_per_pool):
batched_copy(dst_tier, src_tier, slot_sizes[pool_idx], tasks, stream.get())
finish_event = stream.take_finish_event()
for src, dst in zip(src_pages, dst_slots):
dst.ready_event = finish_event
src.ready_event = (
finish_event # required so the next owner that gets this slot from the pool waits for this copy.
)
if update_src:
scheduled_for_eviction = src.scheduled_for_eviction
if scheduled_for_eviction:
self.exclude_from_eviction(src)
src_pool_group.release(src)
src.set_slot(dst)
src.cache_level = dst_level
if scheduled_for_eviction:
self.schedule_for_eviction(src)
return None if update_src else dst_slots
except Exception:
for s in dst_slots:
dst_pool_group.release(s)
raise
def _pool_group(
self, cache_level: CacheLevel, pool_group_index: PoolGroupIndex
) -> PoolGroupBase:
return self._levels[cache_level].storage._pool_groups[pool_group_index]
def num_pools(self, pool_group_index: PoolGroupIndex) -> PoolIndex:
return get_uniform_attribute(
self._levels, lambda level: level.storage._pool_groups[pool_group_index].num_pools
)
def slot_size(self, pool_group_index: PoolGroupIndex) -> HomoTuple[int]:
return get_uniform_attribute(
self._levels, lambda level: level.storage.slot_size(pool_group_index)
)
def num_slots(
self, pool_group_index: PoolGroupIndex, cache_level: CacheLevel = GPU_LEVEL
) -> int:
return self._levels[cache_level].storage.num_slots(pool_group_index)
def release_slot(self, life_cycle: LifeCycleId, cache_level: CacheLevel, slot: Slot) -> None:
pg_idx = self.get_pool_group_index(life_cycle)
self._levels[cache_level].storage.release(pg_idx, slot)
def schedule_for_eviction(self, page: EvictablePage) -> None:
if self.is_evictable(page):
self._levels[page.cache_level].controller.schedule_for_eviction(page)
def exclude_from_eviction(self, page: EvictablePage) -> None:
assert page.node_ref is not None
self._levels[page.cache_level].controller.remove(page.node_ref)
def get_mem_pool_base_address(self, layer_id: LayerId, data_role: DataRole) -> MemAddress:
storage = self._levels[GPU_LEVEL].storage
attr = self.get_buffer_attr(layer_id, data_role)
pg_idx = self.get_pool_group_index(attr.life_cycle_id)
return MemAddress(
cast(int, storage.slot_address(pg_idx, attr.pool_index, SlotId(0))) + attr.offset
)
def get_page_indices_ref(
self, lc_id: LifeCycleId, pages: Iterator[Page | None]
) -> Iterator[int | None]:
"Reference implementation. Not fast enough for production."
scale = self._slot_to_page_indices[lc_id]
return (map_optional(page, lambda p: scale * int(p.slot_id)) for page in pages)
def get_buffer_attr(self, layer_id: LayerId, data_role: DataRole) -> BufferAttr:
return self._buffer_attr[BufferId(layer_id, data_role)]
def slot_address(
self, level: CacheLevel, pg_idx: PoolGroupIndex, slot_id: SlotId, pool_idx: PoolIndex
) -> Address:
return self._levels[level].storage.slot_address(pg_idx, pool_idx, slot_id)
def get_page_indices_for_slot(self, life_cycle: LifeCycleId, slot_id: SlotId) -> PageIndex:
scale = self._slot_to_page_indices[life_cycle]
return PageIndex(scale * slot_id)
def get_statistics(
self, level: CacheLevel = GPU_LEVEL
) -> TypedIndexList[PoolGroupIndex, StorageStatistics]:
ret = make_typed(lambda: StorageStatistics((), 0, 0, 0), self.num_pool_groups)
for pg_idx in typed_range(self.num_pool_groups):
pg = self._pool_group(level, pg_idx)
evictable_cnt = self._levels[level].controller.num_evictable_pages(pg_idx)
ret[pg_idx] = StorageStatistics(
pg.slot_size, pg.num_slots, pg.num_free_slots, evictable_cnt
)
return ret
def get_utilization(
self, level: CacheLevel = GPU_LEVEL
) -> TypedIndexList[PoolGroupIndex, float]:
ret = make_typed(lambda: 0.0, self.num_pool_groups)
stats = self.get_statistics(level)
for pg_idx in typed_range(self.num_pool_groups):
ret[pg_idx] = stats[pg_idx].unavailable / stats[pg_idx].total
return ret
def get_overall_utilization(self, level: CacheLevel = GPU_LEVEL) -> float:
stats = self.get_statistics(level)
return sum(sum(s.slot_size) * s.unavailable for s in stats) / sum(
sum(s.slot_size) * s.total for s in stats
)
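# Illustrative sketch (hypothetical numbers, not part of the original file):
# get_overall_utilization() is a bytes-weighted average over all pool groups, whereas
# get_utilization() reports each group separately. With per-group slot_size/total/unavailable
# of (1024,)/100/60 and (512,)/200/50:
#   used  = 1024 * 60  + 512 * 50  = 87040
#   total = 1024 * 100 + 512 * 200 = 204800
#   overall utilization = 87040 / 204800 = 0.425
# while the per-group utilizations would be 0.6 and 0.25.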

View File

@ -0,0 +1,945 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import array
import ctypes
import errno
import functools
import itertools
import operator
import os
import platform
import traceback
import warnings
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Set
from ctypes.util import find_library
from itertools import pairwise
from typing import (
Any,
Callable,
ClassVar,
Final,
Generic,
Iterable,
Iterator,
MutableSequence,
Protocol,
Reversible,
Sequence,
Type,
TypeVar,
cast,
)
import cuda.bindings.driver as drv
import cuda.bindings.runtime as cudart
from . import rawref
from ._common import NDEBUG, CudaStream
from ._exceptions import CuError, CuOOMError, DiskOOMError, HostOOMError
T = TypeVar("T")
U = TypeVar("U")
Index = TypeVar("Index", bound=int, contravariant=True)
IndexO = TypeVar("IndexO", bound=int, covariant=True)
Row = TypeVar("Row", bound=int)
Col = TypeVar("Col", bound=int)
def _unwrap(
ret: drv.CUresult
| tuple[
drv.CUresult,
T,
]
| tuple[drv.CUresult, T, U],
):
if isinstance(ret, drv.CUresult):
if int(ret) != int(drv.CUresult.CUDA_SUCCESS): # pyright: ignore
if int(ret) == int(drv.CUresult.CUDA_ERROR_OUT_OF_MEMORY): # pyright: ignore
raise CuOOMError()
raise CuError(ret)
else:
_unwrap(ret[0])
return ret[1] if len(ret) == 2 else ret[1:]
def div_up(x: int, y: int) -> int:
return (x + y - 1) // y
def round_up(x: int, y: int) -> int:
return div_up(x, y) * y
def round_down(x: int, y: int) -> int:
return x // y * y
def in_range(x: int, lower: int, upper: int) -> bool:
return lower <= x < upper
def exact_div(x: int, y: int) -> int:
assert x % y == 0
return x // y
def overlap(a: tuple[Index, Index], b: tuple[Index, Index]) -> tuple[Index, Index] | tuple[()]:
"Returns the overlap of two ranges, or an empty tuple if they do not overlap."
return (max(a[0], b[0]), min(a[1], b[1])) if a[0] < b[1] and b[0] < a[1] else ()
def value_or(opt: T | None, default: T) -> T:
return default if opt is None else opt
def unwrap_optional(value: T | None) -> T:
if value is not None:
return value
raise ValueError("Expected non-None value")
def unwrap_weakref(ref: weakref.ref[T]) -> T:
obj = ref()
if obj is not None:
return obj
raise ValueError("Dereferencing a dangling weakref")
def unwrap_rawref(ref: rawref.ref[T]) -> T:
obj = ref()
if obj is not None:
return obj
raise ValueError("Dereferencing a dangling rawref")
def map_optional(value: T | None, func: Callable[[T], U]) -> U | None:
return func(value) if value is not None else None
def remove_if(original: MutableSequence[T], predicate: Callable[[T], bool]) -> list[T]:
"Remove items from original that satisfy the predicate and return the removed items."
removed = []
for idx, item in enumerate(original):
if predicate(item):
removed.append(item)
else:
original[idx - len(removed)] = item
del original[len(original) - len(removed) :]
return removed
def chunked(iterable: Iterable[T], size: int) -> Iterator[list[T]]:
iterator = iter(iterable)
while True:
chunk = list(itertools.islice(iterator, size))
if not chunk:
break
yield chunk
def partition(original: Iterable[T], classifier: Callable[[T], U]) -> defaultdict[U, list[T]]:
ret = defaultdict(list)
for item in original:
ret[classifier(item)].append(item)
return ret
def get_uniform_attribute(iterable: Iterable[T], attribute_func: Callable[[T], U]) -> U:
ret = attribute_func(next(iter(iterable)))
assert NDEBUG or all(attribute_func(item) == ret for item in iterable)
return ret
def assert_critical(condition: bool, message: str | None = None) -> None:
"Similar to assert, but instead of raising an exception, it terminates the process, even if inside __del__()."
if not condition:
warnings.warn(value_or(message, "Critical assertion failed"))
traceback.print_stack()
os._exit(1)
def noexcept(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> T:
try:
return func(*args, **kwargs)
except Exception as e:
raise AssertionError(f"Function {func.__name__} raised an exception: {e}") from e
return wrapper
def not_implemented(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> T:
raise NotImplementedError(f"The function '{func.__name__}' is not implemented yet.")
return wrapper
def expect_type(ExpectedType: Type[T], value: Any) -> T:
"Similar to typing.cast, but does runtime checking with assert."
assert isinstance(value, ExpectedType), f"Expected {ExpectedType}, got {type(value)}"
return value
def is_sorted(
iterable: Iterable[T], key: Callable[[T], Any] = lambda x: x, reverse: bool = False
) -> bool:
comp = operator.ge if reverse else operator.le
return all(comp(key(a), key(b)) for a, b in pairwise(iterable))
HomoTuple = tuple[T, ...]
class TypedIndexList(Protocol[Index, T]):
"""
A protocol representing a list-like container with a strongly typed integer index.
Useful for enforcing index types like NewType wrappers around int.
"""
def __getitem__(self, index: Index) -> T: ...
def __setitem__(self, index: Index, value: T) -> None: ...
def __delitem__(self, index: Index | slice) -> None: ...
def __iter__(self) -> Iterator[T]: ...
def __len__(self) -> int: ...
def __reversed__(self) -> Iterator[T]: ...
def clear(self) -> None: ...
def pop(self) -> T: ...
def append(self, value: T) -> None: ...
# @TODO: use this where applicable.
def to_typed(index_type: Type[Index], lst: list[T]) -> TypedIndexList[Index, T]:
"""
Casts a standard list to a TypedIndexList with a strongly typed integer index.
Parameters:
index_type: A type alias for the NewType index, e.g. type(BlockOrdinal(0)) or a concrete class derived from int.
lst: The list to cast
Returns:
A TypedIndexList[Index, T] with the specified index type
"""
return cast(TypedIndexList[Index, T], lst)
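# Usage sketch (hypothetical names, not part of the original file):
#   BlockOrdinal = NewType("BlockOrdinal", int)
#   blocks: TypedIndexList[BlockOrdinal, str] = to_typed(type(BlockOrdinal(0)), ["a", "b"])
#   blocks[BlockOrdinal(1)]   # accepted by the type checker
#   blocks[1]                 # rejected: a plain int is not a BlockOrdinal
# At runtime to_typed() is a no-op cast; the index typing is enforced statically only.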
def filled_list(value: T, count: Index) -> TypedIndexList[Index, T]:
"Note that all elements will be the same value. Do not use mutable values."
return cast(TypedIndexList[Index, T], [value] * int(count))
def make_typed(generator: Callable[[], T], count: Index) -> TypedIndexList[Index, T]:
return cast(TypedIndexList[Index, T], [generator() for _ in range(int(count))])
def typed_len(iterable: TypedIndexList[IndexO, T]) -> IndexO:
return cast(IndexO, len(iterable))
def typed_enumerate(iterable: TypedIndexList[Index, T]) -> Iterator[tuple[Index, T]]:
return cast(Iterator[tuple[Index, T]], enumerate(iterable))
def typed_map(
iterable: TypedIndexList[Index, T], func: Callable[[T], U]
) -> TypedIndexList[Index, U]:
return cast(TypedIndexList[Index, U], [func(item) for item in iterable])
class Array2D(Generic[Row, Col, T]):
__slots__ = ("_data", "_cols")
_data: list[T]
_cols: int
def __init__(self, rows: Row, cols: Col, init_val: Iterable[T]) -> None:
self._data = list(init_val)
self._cols = cols
def __getitem__(self, index: tuple[Row, Col]) -> T:
return self._data[index[0] * self._cols + index[1]]
def __setitem__(self, index: tuple[Row, Col], value: T) -> None:
self._data[index[0] * self._cols + index[1]] = value
@property
def rows(self) -> int:
assert len(self._data) % self._cols == 0
return len(self._data) // self._cols
def row(self, row: Row) -> TypedIndexList[Col, T]:
return cast(TypedIndexList[Col, T], self._data[row * self._cols : (row + 1) * self._cols])
def col(self, col: Col) -> TypedIndexList[Row, T]:
return cast(TypedIndexList[Row, T], self._data[col :: self._cols])
@property
def cols(self) -> int:
return self._cols
def __len__(self) -> int:
return len(self._data)
def __iter__(self) -> Iterator[T]:
return iter(self._data)
def __reversed__(self) -> Iterator[T]:
return reversed(self._data)
def filled_array2d(rows: Row, cols: Col, val: T) -> Array2D[Row, Col, T]:
return Array2D(rows, cols, [val] * rows * cols)
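# Usage sketch (hypothetical values, not part of the original file): Array2D is a dense
# row-major wrapper, so arr[(r, c)] maps to _data[r * cols + c]; row() and col() return
# copies (slices), not views.
#   arr = filled_array2d(2, 3, 0)
#   arr[(1, 2)] = 7
#   arr.row(1)   # -> [0, 0, 7]
#   arr.col(2)   # -> [0, 7]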
def typed_range(*args: Index) -> Reversible[Index]:
return cast(Reversible[Index], range(*args))
def find(seq: Sequence[T], predicate: Callable[[T], bool], default: U) -> T | U:
return next((item for item in seq if predicate(item)), default)
def find_index(seq: Iterable[T], predicate: Callable[[T], bool]) -> int:
"Return the index of the first item satisfying the predicate, or len(seq) if none does."
i = -1
for i, item in enumerate(seq):
if predicate(item):
return i
return i + 1
mem_alignment: Final[int] = 2 << 20 # 2MB
_libc = ctypes.CDLL(find_library("c"), use_errno=True)  # use_errno is needed for ctypes.get_errno() in _madvise
_libc.aligned_alloc.restype = ctypes.c_void_p
_libc.aligned_alloc.argtypes = [ctypes.c_size_t, ctypes.c_size_t]
_libc.madvise.restype = ctypes.c_int
_libc.madvise.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]
_libc.realloc.restype = ctypes.c_void_p
_libc.realloc.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
_libc.free.restype = None
_libc.free.argtypes = [ctypes.c_void_p]
_libc.posix_fallocate.restype = ctypes.c_int
_libc.posix_fallocate.argtypes = [ctypes.c_int, ctypes.c_longlong, ctypes.c_longlong]
def _aligned_alloc(alignment: int, size: int) -> int:
"""
Allocates size bytes of uninitialized storage whose alignment is specified by alignment.
Returns the address as an integer.
Raises HostOOMError on failure.
"""
assert size % alignment == 0
memptr = _libc.aligned_alloc(ctypes.c_size_t(alignment), ctypes.c_size_t(size))
if not memptr:  # with restype c_void_p, a NULL result comes back as None/0
raise HostOOMError("aligned_alloc failed")
return int(memptr)
def _madvise(ptr: int, size: int, advice: int) -> None:
if os.name == "nt":
return
ret = _libc.madvise(ctypes.c_void_p(ptr), ctypes.c_size_t(size), ctypes.c_int(advice))
if ret != 0:
error_code = ctypes.get_errno()
error_msg = f"madvise failed with errno {error_code}: {errno.errorcode.get(error_code, 'Unknown error')}"
raise HostOOMError(error_msg)
MADV_HUGEPAGE: Final[int] = 14
def _realloc(ptr: int, size: int) -> int:
"""
Reallocates size bytes of storage whose alignment is specified by alignment.
Returns the address as an integer.
Raises OSError on failure.
"""
ret = _libc.realloc(ctypes.c_void_p(ptr), ctypes.c_size_t(size))
if ret == ctypes.c_void_p(0):
raise HostOOMError("realloc failed.")
return int(ret)
def _free(ptr: int) -> None:
_libc.free(ctypes.c_void_p(ptr))
def _posix_fallocate(fd: int, offset: int, length: int) -> None:
ret = _libc.posix_fallocate(
ctypes.c_int(fd), ctypes.c_longlong(offset), ctypes.c_longlong(length)
)
if ret != 0:
raise DiskOOMError(ret, "posix_fallocate failed")
class HostMem:
"""
Host memory aligned to 2MB, reallocable for low-cost resizing and registered to CUDA as page-locked memory.
Resizing will keep the original memory content, like `realloc` in C.
"""
ALIGNMENT: ClassVar[int] = 2 << 20
__slots__ = ("_address", "_size", "_num_registered_chunks")
_address: int
_size: int
# If True and _size > 2GB, use multiple chunks to register pinned memory due to a Linux kernel
# 6.11/6.12/6.13 bug preventing pinning more than 2GB of host memory in one operation.
_CHUNKED_REGISTRATION: ClassVar[bool] = platform.system() == "Linux" and platform.release()[
:4
] in ["6.11", "6.12", "6.13"]
_CHUNK_SIZE: ClassVar[int] = 2 << 30
_num_registered_chunks: int
@property
def address(self) -> int:
return self._address
@property
def size(self) -> int:
return self._size
def __init__(self, size: int) -> None:
self._num_registered_chunks = 0
if size == 0:
self._address = 0
self._size = 0
return
self._address = _aligned_alloc(mem_alignment, size)
self._size = size
_madvise(self._address, size, MADV_HUGEPAGE)
self._register_to_cuda()
def resize(self, new_size: int) -> None:
self._unregister_from_cuda()
try:
self._address = _realloc(self._address, new_size)
self._size = new_size
_madvise(self._address, new_size, MADV_HUGEPAGE)
finally:
self._register_to_cuda()
def destroy(self) -> None:
if self._address == 0:
return
self._unregister_from_cuda()
_free(self._address)
self._address = 0
self._size = 0
def __del__(self) -> None:
self.destroy()
def _register_to_cuda(self) -> None:
assert self._num_registered_chunks == 0
for addr, size in self._iterate_chunks():
_unwrap(
drv.cuMemHostRegister(
addr, size, drv.CU_MEMHOSTREGISTER_PORTABLE | drv.CU_MEMHOSTREGISTER_DEVICEMAP
)
)
self._num_registered_chunks += 1
def _unregister_from_cuda(self) -> None:
for addr, _ in self._iterate_chunks():
if self._num_registered_chunks == 0:
break
_unwrap(drv.cuMemHostUnregister(addr))
self._num_registered_chunks -= 1
assert self._num_registered_chunks == 0
def _iterate_chunks(self) -> Iterator[tuple[int, int]]:
start = self._address
end = start + self._size
chunk_size = self._CHUNK_SIZE if self._CHUNKED_REGISTRATION else self._size
for addr in range(start, end, chunk_size):
yield addr, min(end - addr, chunk_size)
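# Usage sketch (not part of the original file): HostMem behaves like a pinned, realloc-able
# buffer. resize() unregisters the old range from CUDA, reallocs while preserving the
# existing content, then re-registers the new range (chunked on affected Linux kernels).
#   buf = HostMem(4 << 20)   # 4 MiB, 2 MiB-aligned, registered as page-locked memory
#   buf.resize(8 << 20)      # may move, but the first 4 MiB of content are preserved
#   addr, size = buf.address, buf.size
#   buf.destroy()            # unregister and free (also triggered by __del__)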
def resize_file(fd: int, new_size: int) -> None:
old_size = os.lseek(fd, 0, os.SEEK_END)
if new_size > old_size:
_posix_fallocate(fd, old_size, new_size - old_size)
elif new_size < old_size:
os.truncate(fd, new_size)
class DynamicBitset:
"""
A memory efficient bitset that can be resized.
"""
__slots__ = ("_bits", "_num_set_bits")
_bits: array.array
_num_set_bits: int
TYPE_CODE: ClassVar[str] = "Q"
ALL_SET_MASK: ClassVar[int] = (1 << 64) - 1
def __init__(self, capacity: int) -> None:
self._bits = array.array(self.TYPE_CODE, [0] * (div_up(capacity, 64)))
self._num_set_bits = 0
def set(self, index: int) -> None:
if not self.get(index):
self._bits[index // 64] |= 1 << (index % 64)
self._num_set_bits += 1
def get(self, index: int) -> bool:
return self._bits[index // 64] & (1 << (index % 64)) != 0
def clear(self, index: int) -> None:
if self.get(index):
self._bits[index // 64] &= ~(1 << (index % 64))
self._num_set_bits -= 1
@property
def num_set_bits(self) -> int:
return self._num_set_bits
def resize(self, new_capacity: int) -> None:
extra_elems = div_up(new_capacity, 64) - len(self._bits)
if extra_elems > 0:
self._bits.extend(array.array(self.TYPE_CODE, [0] * extra_elems))
elif extra_elems < 0:
self._bits = self._bits[:extra_elems]
if new_capacity % 64 != 0:
self._bits[-1] &= self.ALL_SET_MASK >> (64 - (new_capacity % 64))
# check if any bit in the range [start, end) is set
def any_set(self, start: int, end: int) -> bool:
if start >= end:
return False
start_word_mask = self.ALL_SET_MASK << (start % 64)
end_word_mask = self.ALL_SET_MASK >> (64 - (end % 64))
if start // 64 == end // 64:
if (start_word_mask & end_word_mask & self._bits[start // 64]) != 0:
return True
else:
if (start_word_mask & self._bits[start // 64]) != 0 or (
end % 64 != 0 and end_word_mask & self._bits[end // 64]
) != 0:
return True
return any(self._bits[i] != 0 for i in range(start // 64 + 1, end // 64))
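# Behavior sketch (hypothetical values, not part of the original file):
#   bits = DynamicBitset(130)
#   bits.set(3); bits.set(127)
#   bits.any_set(0, 4)      # True: bit 3 lies inside [0, 4)
#   bits.any_set(4, 127)    # False: the end of the range is exclusive
#   bits.any_set(64, 130)   # True: queries may span 64-bit word boundaries
#   bits.num_set_bits       # 2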
@functools.cache
def init_cuda_once() -> None:
(err,) = cudart.cudaFree(0)
assert int(err) == int(cudart.cudaError_t.cudaSuccess)
class SimplePool(Generic[T]):
__slots__ = (
"_create_func",
"_destroy_func",
"_init_size",
"_max_size",
"_outstanding_count",
"_items",
)
_create_func: Callable[[], T]
_destroy_func: Callable[[T], None]
_init_size: int
_max_size: int | None
_items: deque[T] | None
# Number of items currently handed out via get() but not yet returned via put().
_outstanding_count: int
def __init__(
self,
create_func: Callable[[], T],
destroy_func: Callable[[T], None],
init_size: int = 0,
max_size: int | None = None,
):
self._create_func = create_func
self._destroy_func = destroy_func
self._init_size = init_size
self._max_size = max_size
self._items = None
self._outstanding_count = 0
def clear(self) -> None:
if self._items is None:
return  # nothing was ever created; avoid lazily building the pool just to destroy it
while self._items:
self._destroy_func(self._items.popleft())
def __del__(self) -> None:
self.clear()
@property
def items(self) -> deque[T]:
if self._items is None:
self._items = deque[T](
(self._create_func() for _ in range(self._init_size)), maxlen=self._max_size
)
return self._items
def get(self) -> T:
ret = self.items.popleft() if self.items else self._create_func()
self._outstanding_count += 1
return ret
def put(self, item: T) -> None:
self._outstanding_count -= 1
if self._max_size is not None and len(self.items) >= self._max_size:
self._destroy_func(item)
else:
self.items.appendleft(item)
@property
def outstanding_count(self) -> int:
"number of items currently we get() but not put()"
return self._outstanding_count
@property
def cached_count(self) -> int:
"number of items currently in the pool"
return len(self.items)
@property
def total_count(self) -> int:
"total number of items created, including both outstanding and cached"
return self.outstanding_count + self.cached_count
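# Usage sketch (hypothetical factory functions, not part of the original file): SimplePool
# caches expensive-to-create items; get() reuses a cached item or creates a new one, and
# put() returns it to the cache or destroys it once the cache is full.
#   pool = SimplePool(create_func=list, destroy_func=lambda item: None, max_size=4)
#   item = pool.get()    # outstanding_count == 1, cached_count == 0
#   pool.put(item)       # outstanding_count == 0, cached_count == 1, total_count == 1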
class ItemHolderBase(Generic[T], ABC):
__slots__ = ("_item",)
_item: T | None
def __init__(self) -> None:
self._item = self.pool.get()
def close(self) -> None:
# Manually inlined for better performance.
item = self._item
if item is not None:
self.pool.put(item)
self._item = None
def __del__(self) -> None:
self.close()
def is_closed(self) -> bool:
return self._item is None
def get(self) -> T:
# Manually inlined for better performance.
item = self._item
assert item is not None
return item
@property
def handle(self) -> T:
# Manually inlined for better performance.
item = self._item
assert item is not None
return item
@property
@abstractmethod
def pool(self) -> SimplePool[T]: ...
class CachedCudaEvent(ItemHolderBase[drv.CUevent]):
"""
A cached CUDA event without support for timing. Recorded to a stream when created.
"""
__slots__ = ()
_pool: ClassVar[SimplePool[drv.CUevent] | None] = None
NULL: ClassVar["_NullCudaEvent"]
def __init__(self, stream: CudaStream) -> None:
super().__init__()
self._record(stream)
def query_complete(self) -> bool:
"""
Query the event. If complete, also close the event. Closed events are always considered complete.
"""
# Manually inlined for better performance.
ev = self._item
if ev is None:
return True
(err,) = drv.cuEventQuery(ev)
if int(err) == int(drv.CUresult.CUDA_SUCCESS):
self.close()
return True
elif int(err) == int(drv.CUresult.CUDA_ERROR_NOT_READY):
return False
else:
raise CuError(err)
def synchronize(self) -> None:
# Manually inlined for better performance.
ev = self._item
if ev is None:
return
_unwrap(drv.cuEventSynchronize(ev))
self.close()
def wait_in_stream(self, stream: CudaStream) -> None:
# Manually inlined for better performance.
ev = self._item
if ev is None:
return
_unwrap(drv.cuStreamWaitEvent(stream, ev, 0))
def _record(self, stream: CudaStream) -> None:
"""
Prefer new event instead of recording an existing event.
"""
# Manually inlined for better performance.
ev = self._item
assert ev is not None
_unwrap(drv.cuEventRecord(ev, stream))
@property
def pool(self) -> SimplePool[drv.CUevent]:
if CachedCudaEvent._pool is None:
CachedCudaEvent._pool = SimplePool[drv.CUevent](
lambda: _unwrap(drv.cuEventCreate(drv.CUevent_flags.CU_EVENT_DISABLE_TIMING)),
lambda ev: _unwrap(drv.cuEventDestroy(ev)), # pyright: ignore
init_size=1024,
)
return CachedCudaEvent._pool
class _NullCudaEvent(CachedCudaEvent):
"""
A null CUDA event that is closed (and always complete).
"""
__slots__ = ()
def __init__(self) -> None:
# do not call super().__init__(). We don't need an event here.
self._item = None
CachedCudaEvent.NULL = _NullCudaEvent()
# @TODO: consider do this in a single batch call to C++.
def stream_wait_events(stream: CudaStream, events: Iterable[CachedCudaEvent]) -> None:
"Batched wait for multiple events with deduplication first."
if not isinstance(events, Set):
events = set(events)
for ev in events:
ev.wait_in_stream(stream)
class CachedCudaStream(ItemHolderBase[CudaStream]):
"""
A cached non-blocking CUDA stream.
"""
__slots__ = ()
_pool: ClassVar[SimplePool[CudaStream] | None] = None
def __init__(self) -> None:
super().__init__()
def wait_event(self, event: drv.CUevent) -> None:
_unwrap(drv.cuStreamWaitEvent(self.get(), event, drv.CU_STREAM_WAIT_VALUE_COMPLETED))
def wait_events(self, events: Sequence[CachedCudaEvent] | set[CachedCudaEvent]) -> None:
"""
Wait for events with deduplication first.
"""
stream_wait_events(self.get(), events)
def record_event(self) -> CachedCudaEvent:
return CachedCudaEvent(self.get())
def __cuda_stream__(self) -> tuple[int, int]:
return 0, int(self.get())
@property
def pool(self) -> SimplePool[CudaStream]:
if CachedCudaStream._pool is None:
CachedCudaStream._pool = SimplePool[CudaStream](
lambda: CudaStream(
int(_unwrap(drv.cuStreamCreate(drv.CUstream_flags.CU_STREAM_NON_BLOCKING))) # pyright: ignore
),
lambda stream: _unwrap(drv.cuStreamDestroy(stream)), # pyright: ignore
init_size=128,
)
return CachedCudaStream._pool
class TemporaryCudaStream(CachedCudaStream):
"""
A cached non-blocking CUDA stream. Mainly used as temporary worker streams.
Requires a list of prior events to wait for dependencies. A finish event is recorded when exiting
normally. Call take_finish_event() to get the finish event.
"""
__slots__ = "_finish_event"
_finish_event: CachedCudaEvent | None
def __init__(self, prior_events: Sequence[CachedCudaEvent] | set[CachedCudaEvent]):
super().__init__()
self.wait_events(prior_events)
self._finish_event = None
def __del__(self) -> None:
if self._finish_event is not None:
warnings.warn("[KVCacheManager] finish event recorded but not taken")
super().__del__()
def take_finish_event(self) -> CachedCudaEvent:
ret = unwrap_optional(self._finish_event)
self._finish_event = None
return ret
def __enter__(self) -> "TemporaryCudaStream":
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
if not exc_type:
self._finish_event = self.record_event()
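# Usage sketch (not part of the original file), mirroring how the storage manager batches
# copies: the stream waits on all prior events on entry, records a finish event on normal
# exit, and the caller must take that event exactly once.
#   with TemporaryCudaStream(prior_events) as stream:
#       launch_async_copies(stream.get())            # hypothetical kernel/copy launches
#   finish_event = stream.take_finish_event()
#   finish_event.wait_in_stream(consumer_stream)     # or finish_event.synchronize()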
def merge_events(events: Sequence[CachedCudaEvent] | set[CachedCudaEvent]) -> CachedCudaEvent:
if len(events) == 0:
return CachedCudaEvent.NULL
if len(events) == 1:
ev = next(iter(events))
return ev if not ev.is_closed() else CachedCudaEvent.NULL
with TemporaryCudaStream(events) as stream:
pass
return stream.take_finish_event()
class MultiStreamExecutor:
__slots__ = ("_prior_event", "_streams", "_finish_event")
_prior_event: CachedCudaEvent
_streams: list[TemporaryCudaStream]
_finish_event: CachedCudaEvent | None
def __init__(self, prior_events: Sequence[CachedCudaEvent] | set[CachedCudaEvent]):
self._prior_event = merge_events(prior_events)
self._streams = []
self._finish_event = None
def __enter__(self) -> "MultiStreamExecutor":
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
events = [s.take_finish_event() for s in self._streams]
self._streams.clear()
self._finish_event = merge_events(events)
def __del__(self) -> None:
assert_critical(self._finish_event is None, "finish event not taken")
def new_stream(self) -> TemporaryCudaStream:
stream = TemporaryCudaStream((self._prior_event,))
self._streams.append(stream)
return stream
def take_finish_event(self) -> CachedCudaEvent:
ret = unwrap_optional(self._finish_event)
self._finish_event = None
return ret
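# Usage sketch (not part of the original file): MultiStreamExecutor fans work out to several
# temporary streams that all wait on the merged prior event; each per-stream finish event is
# collected on __exit__ and merged into the single event returned by take_finish_event().
#   with MultiStreamExecutor(prior_events) as executor:
#       for chunk in work_chunks:                    # hypothetical partitioned work
#           with executor.new_stream() as stream:
#               launch_copy(chunk, stream.get())     # hypothetical async launch
#   done_event = executor.take_finish_event()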
class SharedPoolProvider(Generic[T]):
_pool: SimplePool[T]
def __init__(self, pool: SimplePool[T]):
self._pool = pool
def pool(self) -> SimplePool[T]:
return self._pool
class ItemHolderWithSharedPool(ItemHolderBase[T]):
__slots__ = ("_pool",)
_pool: SimplePool[T]
def __init__(self, pool: SimplePool[T]) -> None:
self._pool = pool
super().__init__()
def __del__(self) -> None:
self.close()
@property
def pool(self) -> SimplePool[T]:
return self._pool
HolderT = TypeVar("HolderT", bound=ItemHolderWithSharedPool)
# For subclassing if holder needs to be customized
class PooledFactoryBase(Generic[T, HolderT]):
_Holder: Type[HolderT] # subclasses must initialize this static attribute
__slots__ = ("_pool",)
_pool: SimplePool[T]
def __init__(
self,
create_func: Callable[[], T],
destroy_func: Callable[[T], None],
init_size: int = 0,
max_cache_size: int | None = None,
):
self._pool = SimplePool[T](create_func, destroy_func, init_size, max_cache_size)
def create(self) -> HolderT:
return self._Holder(self._pool)
def clear(self) -> None:
self._pool.clear()
def query_total_gpu_memory() -> int:
_, total = _unwrap(drv.cuMemGetInfo()) # pyright: ignore
return total
def query_free_gpu_memory() -> int:
free, _ = _unwrap(drv.cuMemGetInfo()) # pyright: ignore
return free
class CudaStreamWrapper:
"Just a wrapper to make it compatible with IsStreamT protocol. Does not own the stream."
__slots__ = ("_stream",)
_stream: CudaStream
def __init__(self, stream: CudaStream) -> None:
self._stream = stream
def __cuda_stream__(self) -> tuple[int, int]:
return 0, int(self._stream)

View File

@ -0,0 +1,51 @@
; SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
; SPDX-License-Identifier: Apache-2.0
;
; Licensed under the Apache License, Version 2.0 (the "License");
; you may not use this file except in compliance with the License.
; You may obtain a copy of the License at
;
; http://www.apache.org/licenses/LICENSE-2.0
;
; Unless required by applicable law or agreed to in writing, software
; distributed under the License is distributed on an "AS IS" BASIS,
; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
; See the License for the specific language governing permissions and
; limitations under the License.
[mypy]
# Only check files explicitly listed - don't follow any imports
follow_imports = skip
follow_imports_for_stubs = False
# Ignore all missing imports
ignore_missing_imports = True
# Allow untyped code in dependencies
allow_untyped_calls = True
allow_untyped_defs = True
check_untyped_defs = False
# Disable various warnings to reduce noise
warn_return_any = False
warn_unused_ignores = False
warn_unreachable = False
no_implicit_optional = False
# Don't check .pyi files outside our target
exclude = (?x)(
^(?!tensorrt_llm/runtime/kv_cache_manager_v2/)
)
# Ignore errors in any imported modules
[mypy-tensorrt_llm.executor.*]
ignore_errors = True
follow_imports = skip
[mypy-tensorrt_llm.bindings.*]
ignore_errors = True
follow_imports = skip
[mypy-torch.*]
ignore_errors = True
follow_imports = skip

View File

@ -0,0 +1,140 @@
<!--
SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# rawref - Mutable Reference C Extension
A C extension that provides a mutable reference class similar to `weakref.ref` for holding weak-like references to Python objects.
## Features
- **`ref[T]`**: A generic reference class (like `weakref.ref`) that stores an object's ID
- **Singleton pattern**: `ref(obj)` returns the same reference if `obj.__rawref__` is valid
- **Dereferencing**: Call `r()` to get the object, or `None` if invalid
- **Invalidation**: Call `r.invalidate()` to mark the reference as invalid
- **NULL constant**: Use `NULL` to initialize `__rawref__` attributes
- **Type-safe**: Comes with `.pyi` stub file for proper type checking
- **API compatible with weakref**: Use `ref` for both object creation and type hints
## Building
From the `rawref` directory:
```bash
python setup.py build_ext --inplace
```
Or install it:
```bash
pip install -e .
```
## Usage
```python
from rawref import ref, NULL
class MyClass:
# Class attribute: default value for __rawref__
# Each instance will get its own __rawref__ instance attribute when ref() is called
__rawref__ = NULL
def __init__(self, value):
self.value = value
def __del__(self):
# self.__rawref__ is an instance attribute (set by ref())
# Invalidate the canonical reference when object is destroyed
if self.__rawref__.is_valid:
self.__rawref__.invalidate()
# Create an object and a reference to it (just like weakref.ref)
obj = MyClass(42)
r1 = ref(obj)
# The reference is automatically stored as an instance attribute obj.__rawref__
print(obj.__rawref__ is r1) # True
# Singleton pattern: creating another ref returns the same one
r2 = ref(obj)
print(r1 is r2) # True
# Dereference to get the object back
print(r1()) # <MyClass instance>
print(r1().value) # 42
# Check validity
print(r1.is_valid) # True
# After invalidation
r1.invalidate()
print(r1()) # None
print(r1.is_valid) # False
# Creating a new ref after invalidation creates a new reference
r3 = ref(obj)
print(r1 is r3) # False
print(r3.is_valid) # True
```
## Type Hints
Like `weakref.ref`, you can use `ref` for both object creation and type hints:
```python
from rawref import ref, NULL
class MyClass:
__rawref__ = NULL
# Create and type a reference
r: ref[MyClass] = ref(MyClass())
# Alternative: use ReferenceType directly
from rawref import ReferenceType
r: ReferenceType[MyClass] = ReferenceType(MyClass())
```
## Warning
This implementation uses raw object IDs (memory addresses) and attempts to dereference them. This is inherently unsafe and should be used with caution. The reference does not keep the object alive (unlike a strong reference), so care must be taken to ensure the object is not garbage collected while references exist.
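For illustration, a minimal sketch of the difference (the class name is hypothetical):
```python
from rawref import ref, NULL

class Node:
    __rawref__ = NULL

# Unsafe: the temporary Node has no remaining strong reference, so dereferencing r
# may touch freed memory and return a stale object.
r = ref(Node())

# Safe: keep a strong reference for as long as the rawref may be dereferenced.
node = Node()
r = ref(node)
assert r() is node
```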
## API
### Classes and Constants
- `ReferenceType`: The main reference class
- `ref`: Alias for `ReferenceType` (like `weakref.ref`)
- `NULL`: An invalid reference constant for initialization
### Creation
- `ref(obj)`: Create a reference to `obj`, or return existing valid reference from `obj.__rawref__`
### Properties
- `r.is_valid`: Check if the reference is still valid (read-only)
### Methods
- `r()`: Dereference to get the object, or `None` if invalid
- `r.invalidate()`: Mark the reference as invalid
## Singleton Pattern
The `ref()` function implements a singleton pattern:
1. When `ref(obj)` is called, it checks if `obj.__rawref__` (instance attribute) exists and is valid
2. If yes, it returns the existing reference
3. If no, it creates a new reference and sets `obj.__rawref__` as an instance attribute
**Note**: The class attribute `__rawref__ = NULL` is just a default value. When `ref(obj)` is called, it creates an **instance attribute** `obj.__rawref__` that shadows the class attribute. Each instance gets its own `__rawref__` instance attribute, ensuring each object has at most one canonical reference at a time.
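A short sketch of the class-versus-instance attribute behavior described in the note (names are illustrative):
```python
from rawref import ref, NULL

class Tracked:
    __rawref__ = NULL                  # class attribute: only a default value

a, b = Tracked(), Tracked()
ra, rb = ref(a), ref(b)                # each call installs an *instance* attribute
assert "__rawref__" in a.__dict__      # it shadows the class attribute
assert Tracked.__rawref__ is NULL      # the class default is untouched
assert ra is not rb                    # each object has its own canonical reference
```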

View File

@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""rawref - Mutable reference with singleton pattern.
This module provides a C extension for creating mutable references to Python
objects, similar to weakref.ref but with manual invalidation control and a
singleton pattern via __rawref__.
The main purpose is to work around the issue that mypyc does not support
weakref.
Main exports:
- ReferenceType: The reference class
- ref: Alias for ReferenceType (recommended, like weakref.ref)
- NULL: Invalid reference constant for initialization
"""
from ._rawref import NULL, ReferenceType, ref
__all__ = ["ReferenceType", "ref", "NULL"]
__version__ = "2.0.0"

View File

@ -0,0 +1,74 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Generic, Optional, TypeVar
T = TypeVar("T")
class ReferenceType(Generic[T]):
"""A mutable reference holder that stores an object ID.
This class holds a reference to an object via its ID and allows
dereferencing it. The reference can be invalidated.
Like weakref.ref, but stores raw object IDs instead of proper weak references.
Implements a singleton pattern: calling ref(obj) multiple times returns the
same reference if obj.__rawref__ exists and is valid.
"""
@property
def is_valid(self) -> bool:
"""Check if the reference is still valid (read-only)."""
...
def __init__(self, obj: T) -> None:
"""Initialize a ReferenceType with an object.
If obj.__rawref__ exists and is valid, returns that instead.
Otherwise creates a new reference and sets obj.__rawref__ to it.
Args:
obj: The object to reference.
"""
...
def __call__(self) -> Optional[T]:
"""Dereference the object.
Returns:
The referenced object, or None if the reference is invalid.
"""
...
def invalidate(self) -> None:
"""Invalidate the reference.
After calling this method, __call__() will return None.
This should be called from T.__del__ to invalidate the reference.
"""
...
# Alias 'ref' to 'ReferenceType' (like weakref.ref is an alias to weakref.ReferenceType)
ref = ReferenceType
# NULL is an invalid reference constant that can be used to initialize __rawref__
NULL: ReferenceType
# For type hints, you can use either:
# r: ref[MyClass] = ref(obj)
# or:
# r: ReferenceType[MyClass] = ReferenceType(obj)
# Both are equivalent.

View File

@ -0,0 +1,238 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <structmember.h>
/* ReferenceType object structure */
typedef struct
{
PyObject_HEAD Py_ssize_t object_id; /* ID of the referenced object */
int valid; /* 1 if valid, 0 if invalidated */
} ReferenceTypeObject;
/* Forward declarations */
static PyTypeObject ReferenceTypeType;
/* Cached attribute name for faster lookups */
static PyObject* rawref_attr_name = NULL;
/* ReferenceType.__new__ - implements singleton pattern via __rawref__ */
static PyObject* ReferenceType_new(PyTypeObject* type, PyObject* args, PyObject* kwds)
{
PyObject* obj = NULL;
static char* kwlist[] = {"obj", NULL};
/* Parse arguments to get the object */
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &obj))
{
return NULL;
}
/* Try to get existing __rawref__ using cached attribute name (faster) */
PyObject* existing_ref = PyObject_GetAttr(obj, rawref_attr_name);
if (existing_ref != NULL)
{
/* Check if it's a ReferenceType instance and is valid */
if (PyObject_TypeCheck(existing_ref, &ReferenceTypeType))
{
ReferenceTypeObject* ref_obj = (ReferenceTypeObject*) existing_ref;
if (ref_obj->valid)
{
/* Return existing valid reference */
return existing_ref;
}
}
Py_DECREF(existing_ref);
}
else
{
/* Clear the AttributeError if __rawref__ doesn't exist */
PyErr_Clear();
}
/* Create new reference */
ReferenceTypeObject* self;
self = (ReferenceTypeObject*) type->tp_alloc(type, 0);
if (self != NULL)
{
self->object_id = (Py_ssize_t) obj;
self->valid = 1;
/* Set obj.__rawref__ to this new reference using cached attr name */
if (PyObject_SetAttr(obj, rawref_attr_name, (PyObject*) self) < 0)
{
/* If we can't set __rawref__, just clear the error and continue */
PyErr_Clear();
}
}
return (PyObject*) self;
}
/* ReferenceType.__init__ */
static int ReferenceType_init(ReferenceTypeObject* self, PyObject* args, PyObject* kwds)
{
/* __new__ already did all the work, including setting object_id and valid */
/* Skip argument parsing since __new__ already validated them */
/* This saves ~5-10% overhead on object creation */
return 0;
}
/* ReferenceType.__call__() - dereference the object */
static PyObject* ReferenceType_call(ReferenceTypeObject* self, PyObject* args, PyObject* kwds)
{
PyObject* obj;
if (!self->valid)
{
Py_RETURN_NONE;
}
/* Convert the stored id (a raw address) back to a PyObject pointer. */
obj = (PyObject*) self->object_id;
/* Heuristic liveness check: the refcount of a live object is positive. */
/* This is inherently unsafe (see the README warning) but matches the intended behavior. */
if (Py_REFCNT(obj) > 0)
{
Py_INCREF(obj);
return obj;
}
/* Object no longer valid */
self->valid = 0;
Py_RETURN_NONE;
}
/* ReferenceType.invalidate() */
static PyObject* ReferenceType_invalidate(ReferenceTypeObject* self, PyObject* Py_UNUSED(ignored))
{
self->valid = 0;
Py_RETURN_NONE;
}
/* ReferenceType.is_valid property getter */
static PyObject* ReferenceType_is_valid(ReferenceTypeObject* self, void* closure)
{
return PyBool_FromLong(self->valid);
}
/* ReferenceType.__class_getitem__() - support for generic type subscripting */
static PyObject* ReferenceType_class_getitem(PyObject* cls, PyObject* item)
{
/* Just return the class itself, ignore the type parameter */
/* This allows rawref.ref[T] to work at runtime like weakref.ref[T] */
Py_INCREF(cls);
return cls;
}
/* Method definitions */
static PyMethodDef ReferenceType_methods[] = {
{"invalidate", (PyCFunction) ReferenceType_invalidate, METH_NOARGS,
"Invalidate the reference, making it return None on dereference."},
{"__class_getitem__", (PyCFunction) ReferenceType_class_getitem, METH_O | METH_CLASS,
"Support for generic type subscripting (e.g., ref[T])."},
{NULL} /* Sentinel */
};
/* Property definitions */
static PyGetSetDef ReferenceType_getsetters[] = {
{"is_valid", (getter) ReferenceType_is_valid, NULL, "Check if the reference is still valid (read-only).", NULL},
{NULL} /* Sentinel */
};
/* Type definition */
static PyTypeObject ReferenceTypeType = {
PyVarObject_HEAD_INIT(NULL, 0).tp_name = "_rawref.ReferenceType",
.tp_doc = "A mutable reference holder that stores an object ID.",
.tp_basicsize = sizeof(ReferenceTypeObject),
.tp_itemsize = 0,
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
.tp_new = ReferenceType_new,
.tp_init = (initproc) ReferenceType_init,
.tp_call = (ternaryfunc) ReferenceType_call,
.tp_methods = ReferenceType_methods,
.tp_getset = ReferenceType_getsetters,
};
/* Module definition */
static PyModuleDef rawrefmodule = {
PyModuleDef_HEAD_INIT,
.m_name = "_rawref",
.m_doc = "C extension providing mutable reference class ReferenceType (internal module).",
.m_size = -1,
};
/* Module initialization */
PyMODINIT_FUNC PyInit__rawref(void)
{
PyObject* m;
ReferenceTypeObject* null_ref;
if (PyType_Ready(&ReferenceTypeType) < 0)
return NULL;
m = PyModule_Create(&rawrefmodule);
if (m == NULL)
return NULL;
/* Cache the __rawref__ attribute name for faster lookups */
rawref_attr_name = PyUnicode_InternFromString("__rawref__");
if (rawref_attr_name == NULL)
{
Py_DECREF(m);
return NULL;
}
Py_INCREF(&ReferenceTypeType);
if (PyModule_AddObject(m, "ReferenceType", (PyObject*) &ReferenceTypeType) < 0)
{
Py_DECREF(&ReferenceTypeType);
Py_DECREF(m);
return NULL;
}
/* Add 'ref' as an alias for 'ReferenceType' (like weakref.ref) */
Py_INCREF(&ReferenceTypeType);
if (PyModule_AddObject(m, "ref", (PyObject*) &ReferenceTypeType) < 0)
{
Py_DECREF(&ReferenceTypeType);
Py_DECREF(m);
return NULL;
}
/* Create NULL constant - an invalid reference */
null_ref = (ReferenceTypeObject*) ReferenceTypeType.tp_alloc(&ReferenceTypeType, 0);
if (null_ref != NULL)
{
null_ref->object_id = 0;
null_ref->valid = 0;
if (PyModule_AddObject(m, "NULL", (PyObject*) null_ref) < 0)
{
Py_DECREF(null_ref);
Py_DECREF(m);
return NULL;
}
}
return m;
}

View File

@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup, Extension
rawref_module = Extension(
'_rawref',
sources=['rawrefmodule.c'],
)
setup(
name='rawref',
version='1.0',
description='C extension providing mutable reference class Ref[T]',
ext_modules=[rawref_module],
)

View File

@ -0,0 +1,269 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script for the rawref module."""
import os
import sys
# Add parent directory to path to import the rawref package
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
from rawref import NULL, ReferenceType, ref
except ImportError as e:
print(f"Error importing rawref: {e}")
print("Make sure to build the extension first with: python setup.py build_ext --inplace")
sys.exit(1)
class TestObject:
"""Test class with __del__ that invalidates references."""
def __init__(self, value):
self.value = value
self.refs = [] # Store references to invalidate
def __del__(self):
print(f"TestObject({self.value}).__del__ called, invalidating {len(self.refs)} references")
for ref in self.refs:
ref.invalidate()
def test_basic_reference():
"""Test basic reference creation and dereferencing."""
print("\n=== Test 1: Basic Reference ===")
obj = TestObject(42)
r = ref(obj)
obj.refs.append(r)
print(f"Created object with value: {obj.value}")
print(f"Reference is_valid: {r.is_valid}")
# Dereference
dereferenced = r()
print(f"Dereferenced object: {dereferenced}")
if dereferenced:
print(f"Dereferenced value: {dereferenced.value}")
assert r.is_valid, "Reference should be valid"
assert dereferenced is obj, "Dereferenced object should be the same as original"
def test_invalidation():
"""Test manual invalidation."""
print("\n=== Test 2: Manual Invalidation ===")
obj = TestObject(123)
r = ref(obj)
print(f"Before invalidation - is_valid: {r.is_valid}")
print(f"Dereferenced: {r()}")
r.invalidate()
print(f"After invalidation - is_valid: {r.is_valid}")
print(f"Dereferenced: {r()}")
assert not r.is_valid, "Reference should be invalid after invalidate()"
assert r() is None, "Dereferencing invalid reference should return None"
def test_del_invalidation():
"""Test invalidation via __del__."""
print("\n=== Test 3: Invalidation via __del__ ===")
r = None
# Create object in a scope
def create_and_ref():
obj = TestObject(999)
nonlocal r
r = ref(obj)
obj.refs.append(r)
print(f"Inside scope - is_valid: {r.is_valid}")
print(f"Inside scope - dereferenced: {r()}")
create_and_ref()
# Object should be deleted and reference invalidated
print(f"After scope - is_valid: {r.is_valid}")
print(f"After scope - dereferenced: {r()}")
assert not r.is_valid, "Reference should be invalid after object deletion"
assert r() is None, "Dereferencing should return None after object deletion"
def test_multiple_references():
"""Test that singleton pattern returns same reference."""
print("\n=== Test 4: Singleton Pattern - Same Reference ===")
obj = TestObject(555)
r1 = ref(obj)
r2 = ref(obj)
obj.refs.extend([r1, r2])
print(f"r1 is r2: {r1 is r2}")
assert r1 is r2, "With singleton pattern, should return the same reference"
# Invalidate r1 (which is the same as r2)
r1.invalidate()
print("After invalidating r1:")
print(f" r1 is_valid: {r1.is_valid}, dereferenced: {r1()}")
print(f" r2 is_valid: {r2.is_valid}, dereferenced: {r2()}")
# Since they're the same object, both are invalid
assert not r1.is_valid, "r1 should be invalid"
assert not r2.is_valid, "r2 should also be invalid (same object)"
def test_alias_equivalence():
"""Test that ref and ReferenceType are the same."""
print("\n=== Test 5: ref and ReferenceType are equivalent ===")
obj = TestObject(777)
r1 = ref(obj)
r2 = ReferenceType(obj)
print(f"ref is ReferenceType: {ref is ReferenceType}")
print(f"type(r1): {type(r1)}")
print(f"type(r2): {type(r2)}")
print(f"r1 is r2: {r1 is r2}")
assert ref is ReferenceType, "ref should be an alias for ReferenceType"
assert r1 is r2, "Should return the same reference object (singleton pattern)"
def test_singleton_pattern():
"""Test that ref(obj) returns the same reference if __rawref__ is valid."""
print("\n=== Test 6: Singleton Pattern ===")
obj = TestObject(888)
# First call creates a new reference
r1 = ref(obj)
print(f"First ref: {r1}, valid: {r1.is_valid}")
# Second call returns the same reference
r2 = ref(obj)
print(f"Second ref: {r2}, valid: {r2.is_valid}")
print(f"r1 is r2: {r1 is r2}")
assert r1 is r2, "Should return the same reference object"
# After invalidation, a new call creates a new reference
r1.invalidate()
print(f"After invalidation: r1.valid={r1.is_valid}, r2.valid={r2.is_valid}")
r3 = ref(obj)
print(f"Third ref (after invalidation): {r3}, valid: {r3.is_valid}")
print(f"r1 is r3: {r1 is r3}")
assert r1 is not r3, "Should create a new reference after invalidation"
assert r3.is_valid, "New reference should be valid"
def test_null_constant():
"""Test the NULL constant."""
print("\n=== Test 7: NULL Constant ===")
print(f"NULL: {NULL}")
print(f"NULL.is_valid: {NULL.is_valid}")
print(f"NULL(): {NULL()}")
print(f"type(NULL): {type(NULL)}")
assert not NULL.is_valid, "NULL should be invalid"
assert NULL() is None, "NULL() should return None"
# Test using NULL to initialize __rawref__
class MyClass:
__rawref__ = NULL
obj = MyClass()
r = ref(obj)
print("After creating ref for obj with __rawref__=NULL:")
print(f" obj.__rawref__ is r: {obj.__rawref__ is r}")
print(f" r.is_valid: {r.is_valid}")
assert obj.__rawref__ is r, "Should update __rawref__ to new reference"
assert r.is_valid, "New reference should be valid"
def test_hidden_object_id():
"""Test that object_id is hidden."""
print("\n=== Test 8: Hidden object_id ===")
obj = TestObject(123)
r = ref(obj)
# Try to access object_id - should raise AttributeError
try:
_ = r.object_id
assert False, "Should not be able to access object_id"
except AttributeError:
print("✓ object_id is hidden (AttributeError raised)")
print()
def test_is_valid_readonly():
"""Test that is_valid is read-only."""
print("\n=== Test 9: is_valid is read-only ===")
obj = TestObject(456)
r = ref(obj)
print(f"r.is_valid: {r.is_valid}")
# Try to set is_valid - should raise AttributeError
try:
r.is_valid = False
assert False, "Should not be able to set is_valid"
except AttributeError:
print("✓ is_valid is read-only (AttributeError raised)")
print()
if __name__ == "__main__":
print("Testing rawref module...")
try:
test_basic_reference()
test_invalidation()
test_del_invalidation()
test_multiple_references()
test_alias_equivalence()
test_singleton_pattern()
test_null_constant()
test_hidden_object_id()
test_is_valid_readonly()
print("\n" + "=" * 50)
print("All tests passed!")
print("=" * 50)
except AssertionError as e:
print(f"\n❌ Test failed: {e}")
sys.exit(1)
except Exception as e:
print(f"\n❌ Unexpected error: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@ -0,0 +1,138 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup script for compiling kv_cache_manager_v2 with mypyc.
Usage (from project root):
python tensorrt_llm/runtime/kv_cache_manager_v2/setup_mypyc.py build_ext --inplace
Or use the build script:
./tensorrt_llm/runtime/kv_cache_manager_v2/build_mypyc.sh
"""
import os
import sys
# Set environment variables BEFORE importing mypyc
os.environ["MYPY_FORCE_COLOR"] = "0"
from mypyc.build import mypycify
from setuptools import setup
# Write a strict mypy config that won't check external files
mypy_config_path = os.path.abspath("mypy_mypyc_build.ini")
with open(mypy_config_path, "w") as f:
f.write("""[mypy]
# Critical: Don't follow any imports outside the specified files
follow_imports = skip
follow_imports_for_stubs = False
# Ignore missing imports completely
ignore_missing_imports = True
# Allow all untyped code
allow_untyped_calls = True
allow_untyped_defs = True
allow_incomplete_defs = True
allow_untyped_globals = True
check_untyped_defs = False
# Disable all warnings that might cause errors
disallow_untyped_calls = False
disallow_untyped_defs = False
disallow_incomplete_defs = False
warn_return_any = False
warn_unused_ignores = False
# Disable type validation errors (for external types like drv.CUstream)
disable_error_code = valid-type
""")
# Point mypy to this config by adding to sys.argv before mypyc runs
sys.argv.extend(["--config-file", mypy_config_path])
# List all Python modules in kv_cache_manager_v2 to compile
#
# NOTE: every module below is currently compiled; nothing is excluded.
# (_exceptions.py inherits from builtin Exception classes, a known mypyc limitation,
# but it is still listed and compiled here.)
#
modules = [
# Main module files
"kv_cache_manager_v2/__init__.py",
"kv_cache_manager_v2/_block_radix_tree.py",
"kv_cache_manager_v2/_common.py",
"kv_cache_manager_v2/_config.py",
"kv_cache_manager_v2/_copy_engine.py",
"kv_cache_manager_v2/_cuda_virt_mem.py",
"kv_cache_manager_v2/_exceptions.py",
"kv_cache_manager_v2/_life_cycle_registry.py",
"kv_cache_manager_v2/_page.py",
"kv_cache_manager_v2/_storage_manager.py",
"kv_cache_manager_v2/_utils.py",
# _core submodule
"kv_cache_manager_v2/_core/__init__.py",
"kv_cache_manager_v2/_core/_kv_cache_manager.py",
"kv_cache_manager_v2/_core/_kv_cache.py",
# _eviction_controller submodule
"kv_cache_manager_v2/_eviction_controller/__init__.py",
"kv_cache_manager_v2/_eviction_controller/_eviction_controller.py",
# _storage submodule
"kv_cache_manager_v2/_storage/__init__.py",
"kv_cache_manager_v2/_storage/_config.py",
"kv_cache_manager_v2/_storage/_core.py",
]
print(f"Compiling {len(modules)} modules with mypyc...")
print("Excluded: None")
print("")
try:
ext_modules = mypycify(
modules,
opt_level="3", # Maximum optimization
multi_file=True, # Allow cross-module references (needed for inheritance)
verbose=True, # Show what's being compiled
separate=False, # Compile into single .so (required for cross-module inheritance)
strip_asserts=False, # Keep assertions for debugging
)
except Exception as e:
print(f"Error during mypyc compilation: {e}")
sys.exit(1)
finally:
# Cleanup temp config
if os.path.exists(mypy_config_path):
try:
os.remove(mypy_config_path)
except OSError:
pass
# Remove --config-file arguments from sys.argv before calling setup()
# This prevents setuptools from seeing arguments it doesn't understand
while "--config-file" in sys.argv:
idx = sys.argv.index("--config-file")
sys.argv.pop(idx) # Remove '--config-file'
if idx < len(sys.argv): # Remove the path that follows it
sys.argv.pop(idx)
setup(
name="kv_cache_manager_v2_compiled",
ext_modules=ext_modules,
packages=["kv_cache_manager_v2.rawref"],
package_data={
"kv_cache_manager_v2": ["*.pyi", "**/*.pyi"],
},
python_requires=">=3.8",
)

View File

@ -12,6 +12,7 @@ unittest/trt/attention/test_gpt_attention_IFB.py,NVIDIA A10,17,
unittest/trt/attention/test_gpt_attention_no_cache.py,NVIDIA A10,23,
unittest/trt/model/test_mamba.py,NVIDIA A10,12,
unittest/trt/model/test_llama.py,NVIDIA A10,3,
unittest/kv_cache_manager_v2_tests/,NVIDIA A10,8,
"unittest/trt/attention/test_gpt_attention.py -k ""partition0""",NVIDIA A10,14,
"unittest/trt/attention/test_gpt_attention.py -k ""partition1""",NVIDIA A10,10,
"unittest/trt/attention/test_gpt_attention.py -k ""partition2""",NVIDIA A10,3,
@ -42,6 +43,7 @@ unittest/trt/quantization/test_weight_only_groupwise_quant_matmul.py,NVIDIA H100
unittest/trt/attention/test_gpt_attention_IFB.py,NVIDIA H100 80GB HBM3,11,
unittest/trt/attention/test_gpt_attention_no_cache.py,NVIDIA H100 80GB HBM3,13,
unittest/trt/model/test_mamba.py,NVIDIA H100 80GB HBM3,10,
unittest/kv_cache_manager_v2_tests/,NVIDIA H100 80GB HBM3,8,
"unittest/trt/attention/test_gpt_attention.py -k ""partition0""",NVIDIA L40S,14,
"unittest/trt/attention/test_gpt_attention.py -k ""partition1""",NVIDIA L40S,10,
"unittest/trt/attention/test_gpt_attention.py -k ""partition2""",NVIDIA L40S,6,
@ -64,6 +66,7 @@ unittest/trt/quantization/test_weight_only_groupwise_quant_matmul.py,NVIDIA H100
unittest/trt/attention/test_gpt_attention_IFB.py,NVIDIA H100 PCIe,11,
unittest/trt/attention/test_gpt_attention_no_cache.py,NVIDIA H100 PCIe,13,
unittest/trt/model/test_mamba.py,NVIDIA H100 PCIe,10,
unittest/kv_cache_manager_v2_tests/,NVIDIA H100 PCIe,8,
llmapi-tp-2gpu,NVIDIA H100 NVL,1,
unittest/llmapi/test_llm_models_multi_gpu.py,NVIDIA H100 NVL,1,
unittest/trt/model/test_gptneox.py,NVIDIA H100 NVL,7,
@ -80,6 +83,7 @@ unittest/trt/quantization/test_weight_only_groupwise_quant_matmul.py,NVIDIA H100
unittest/trt/attention/test_gpt_attention_IFB.py,NVIDIA H100 NVL,11,
unittest/trt/attention/test_gpt_attention_no_cache.py,NVIDIA H100 NVL,13,
unittest/trt/model/test_mamba.py,NVIDIA H100 NVL,10,
unittest/kv_cache_manager_v2_tests/,NVIDIA H100 NVL,8,
llmapi-tp-2gpu,NVIDIA H100,1,
unittest/llmapi/test_llm_models_multi_gpu.py,NVIDIA H100,1,
unittest/trt/model/test_gptneox.py,NVIDIA H100,7,
@ -96,6 +100,7 @@ unittest/trt/quantization/test_weight_only_groupwise_quant_matmul.py,NVIDIA H100
unittest/trt/attention/test_gpt_attention_IFB.py,NVIDIA H100,11,
unittest/trt/attention/test_gpt_attention_no_cache.py,NVIDIA H100,13,
unittest/trt/model/test_mamba.py,NVIDIA H100,10,
unittest/kv_cache_manager_v2_tests/,NVIDIA H100,8,
"unittest/trt/attention/test_gpt_attention.py -k ""partition0""",NVIDIA L40,14,
"unittest/trt/attention/test_gpt_attention.py -k ""partition1""",NVIDIA L40,10,
"unittest/trt/attention/test_gpt_attention.py -k ""partition2""",NVIDIA L40,6,
@ -111,6 +116,7 @@ unittest/_torch/misc,NVIDIA B200,4,
unittest/_torch/speculative,NVIDIA B200,4,
unittest/_torch/thop/parallel,NVIDIA B200,16,
"unittest/_torch/auto_deploy/unit/singlegpu -k ""not test_trtllm_bench_backend_comparison""",NVIDIA B200,4,
unittest/kv_cache_manager_v2_tests/,NVIDIA B200,8,
unittest/_torch/attention,NVIDIA H100,4,
unittest/_torch/misc,NVIDIA H100,4,
unittest/_torch/thop/parallel,NVIDIA H100,16,


View File

@ -185,6 +185,7 @@ l0_a10:
- unittest/trt/attention/test_gpt_attention_IFB.py
- unittest/trt/attention/test_gpt_attention_no_cache.py
- examples/test_whisper.py::test_llm_whisper_general[large-v3-disable_gemm_plugin-enable_attention_plugin-disable_weight_only-float16-nb:1-use_cpp_runtime]
- unittest/kv_cache_manager_v2_tests/ # 4 min
- condition:
ranges:
system_gpu_count:

View File

@ -87,6 +87,7 @@ l0_b200:
- unittest/tools/test_layer_wise_benchmarks.py::test_nemotron_gen_dep[1]
- unittest/tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1]
- unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8
- unittest/kv_cache_manager_v2_tests/
- condition:
ranges:
system_gpu_count:

View File

@ -401,6 +401,7 @@ l0_h100:
- unittest/trt/model/test_gpt_e2e.py # 3 mins / 6 mins on H100
- unittest/trt/attention/test_gpt_attention_no_cache.py
- examples/test_gpt.py::test_gpt_oss_20b_lora_torch[gpt-oss-20b-lora-adapter_NIM_r8-gpt-oss-20b]
- unittest/kv_cache_manager_v2_tests/ # 4 min
- condition:
ranges:
system_gpu_count:

View File

@ -0,0 +1,206 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from collections.abc import Sequence
from functools import cached_property
from importlib.util import find_spec
from typing import TYPE_CHECKING, NamedTuple
if not TYPE_CHECKING and find_spec("kv_cache_manager_v2") is not None:
from kv_cache_manager_v2 import (
AttentionLayerConfig,
BeamIndex,
CudaStream,
DataRole,
KVCacheManagerConfig,
LayerId,
TokenIdExt,
_KVCache,
)
from kv_cache_manager_v2._common import BAD_PAGE_INDEX, NDEBUG, MemAddress
from kv_cache_manager_v2._utils import (
div_up,
exact_div,
get_uniform_attribute,
overlap,
typed_range,
value_or,
)
else:
from tensorrt_llm.runtime.kv_cache_manager_v2 import (
AttentionLayerConfig,
BeamIndex,
CudaStream,
DataRole,
KVCacheManagerConfig,
LayerId,
TokenIdExt,
_KVCache,
)
from tensorrt_llm.runtime.kv_cache_manager_v2._common import BAD_PAGE_INDEX, NDEBUG, MemAddress
from tensorrt_llm.runtime.kv_cache_manager_v2._utils import (
div_up,
exact_div,
get_uniform_attribute,
overlap,
typed_range,
value_or,
)
import os
from dynamic_path_manager import DynamicPathManager
with DynamicPathManager(os.path.dirname(os.path.abspath(__file__)), clear_cache=False):
from kernels import check_values, fill_values
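# One engine step for a single request: verify that `history` is already present in the KV cache,
# then write `input` (which may be empty for a check-only step).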
class Step(NamedTuple):
kv_cache: _KVCache
input: list[TokenIdExt] # when empty, just check history
history: list[TokenIdExt]
class Role:
"""Constants for data roles in KV cache management."""
KEY = DataRole("key")
VALUE = DataRole("value")
KEY_BLOCK_QUANT = DataRole("key_block_quant")
VALUE_BLOCK_QUANT = DataRole("value_block_quant")
roles = (Role.KEY, Role.VALUE, Role.KEY_BLOCK_QUANT, Role.VALUE_BLOCK_QUANT)
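# A stand-in for a real attention engine: instead of attention outputs it stamps each token's KV
# slots with a deterministic (token, layer, buffer, beam) pattern and re-checks that pattern on
# later steps, so page-mapping bugs show up as value mismatches.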
class FakeEngine:
cfg: KVCacheManagerConfig
def __init__(self, config: KVCacheManagerConfig) -> None:
super().__init__()
self.cfg = config
@property
def tokens_per_block(self) -> int:
return self.cfg.tokens_per_block
@cached_property
def layers(self) -> dict[LayerId, AttentionLayerConfig]:
return {
layer.layer_id: layer
for layer in sorted(self.cfg.layers, key=lambda layer: layer.layer_id)
}
def execute(self, batch: Sequence[Step], stream: CudaStream) -> None:
assert batch
manager = get_uniform_attribute(batch, lambda step: step.kv_cache.manager)
for kv_cache, input, history in batch:
for layer_id, layer_cfg in self.layers.items():
for buf_id, buf in enumerate(layer_cfg.buffers):
role = buf.role
assert NDEBUG or buf.size == manager.get_page_stride(layer_id, role)
for beam in typed_range(kv_cache.beam_width):
# check history
self._check_pages(kv_cache, layer_id, buf_id, beam, history, stream)
# write new token
if input:
self._write_new_tokens(
kv_cache, len(history), layer_id, buf_id, beam, input, stream
)
def _check_pages(
self,
kv_cache: _KVCache,
layer_id: LayerId,
buf_id: int,
beam: BeamIndex,
history: Sequence[TokenIdExt],
stream: CudaStream,
):
manager = kv_cache.manager
tokens_per_block = self.tokens_per_block
layer_cfg = self.layers[layer_id]
buf = layer_cfg.buffers[buf_id]
role = buf.role
token_bytes = exact_div(buf.size, tokens_per_block)
pool = manager.get_mem_pool_base_address(layer_id, role)
stride = manager.get_page_stride(layer_id, role)
lc_id = manager._storage._layer_to_life_cycle_ids[layer_id]
pages = kv_cache.get_page_indices(lc_id, beam)
capacity = kv_cache.capacity
history_len = len(history)
assert len(history) == history_len
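# Pages are mandatory only for the sink tokens and for tokens inside the attention window; for
# full-attention layers the window spans the whole capacity, while SWA layers may have dropped
# the pages in between.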
window = (
(0, capacity)
if layer_cfg.window_size is None
else (max(0, history_len + 1 - layer_cfg.window_size), capacity)
)
sink = value_or(layer_cfg.num_sink_tokens, 0)
# check history
for ordinal, page in enumerate(pages):
if page == BAD_PAGE_INDEX:
continue
page_range = (tokens_per_block * ordinal, tokens_per_block * (ordinal + 1))
need_page = overlap(page_range, (0, sink)) or overlap(page_range, window)
if need_page:
assert page != BAD_PAGE_INDEX
else:
assert kv_cache.history_length != history_len or page == BAD_PAGE_INDEX
addr = MemAddress(pool + stride * page)
tokens = history[tokens_per_block * ordinal : tokens_per_block * (ordinal + 1)]
check_values(addr, token_bytes, layer_id, buf_id, beam, tokens, stream)
def _write_new_tokens(
self,
kv_cache: _KVCache,
history_len: int,
layer_id: LayerId,
buf_id: int,
beam: BeamIndex,
input: Sequence[TokenIdExt],
stream: CudaStream,
):
manager = kv_cache.manager
tokens_per_block = self.tokens_per_block
layer_cfg = self.layers[layer_id]
buf = layer_cfg.buffers[buf_id]
role = buf.role
token_bytes = exact_div(buf.size, self.tokens_per_block)
pool = manager.get_mem_pool_base_address(layer_id, role)
stride = manager.get_page_stride(layer_id, role)
lc_id = manager._storage._layer_to_life_cycle_ids[layer_id]
pages = kv_cache.get_page_indices(lc_id, beam)[
: div_up(history_len + len(input), tokens_per_block)
]
capacity = kv_cache.capacity
input_range = (history_len, history_len + len(input))
assert input_range[1] <= capacity
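# Only the pages that overlap the new input range [history_len, history_len + len(input)) need to
# be written; earlier pages are skipped below.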
ordinal_beg = input_range[0] // tokens_per_block
pages = itertools.islice(pages, ordinal_beg, None)
ordinal = None
for i, page in enumerate(pages):
ordinal = ordinal_beg + i
assert page != BAD_PAGE_INDEX
page_range = (tokens_per_block * ordinal, tokens_per_block * (ordinal + 1))
batch_range = tuple(i for i in overlap(input_range, page_range))
assert batch_range
tokens = input[(batch_range[0] - history_len) : (batch_range[1] - history_len)]
addr = MemAddress(
pool + stride * page + token_bytes * (batch_range[0] % tokens_per_block)
)
# print('layer_id={}, buf_id={}, beam={}, i={}, addr={}, tokens={}'.format(
# layer_id, buf_id, beam, i, addr, tokens))
fill_values(addr, token_bytes, layer_id, buf_id, beam, tokens, stream)
assert ordinal is None or ordinal + 1 == div_up(input_range[1], tokens_per_block)

View File

@ -0,0 +1,336 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import ctypes
from collections.abc import Sequence
from functools import lru_cache
from importlib.util import find_spec
from typing import TYPE_CHECKING, Iterator
import cuda.bindings.driver as drv
try:
from cuda.core import Kernel, ObjectCode, Program, ProgramOptions
except ImportError:
from cuda.core.experimental import Kernel, Program, ProgramOptions
from cuda.core.experimental._module import ObjectCode
if not TYPE_CHECKING and find_spec("kv_cache_manager_v2") is not None:
from kv_cache_manager_v2._common import CudaStream, LayerId, MemAddress, TokenIdExt
from kv_cache_manager_v2._utils import _unwrap, div_up, exact_div
else:
from tensorrt_llm.runtime.kv_cache_manager_v2._common import (
CudaStream,
LayerId,
MemAddress,
TokenIdExt,
)
from tensorrt_llm.runtime.kv_cache_manager_v2._utils import _unwrap, div_up, exact_div
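# Optional artificial delay (in nanoseconds) compiled into the test kernels; enable_kernel_delay()
# turns it on to make kernels slower and expose missing stream synchronization.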
_SLEEP_TIME_NS: int = 0
@contextlib.contextmanager
def enable_kernel_delay() -> Iterator[None]:
global _SLEEP_TIME_NS
_SLEEP_TIME_NS = 30_000
try:
yield
finally:
_SLEEP_TIME_NS = 0
@lru_cache(maxsize=None)
def get_program(debug: bool, max_tokens: int, sleep_time: int) -> ObjectCode:
assert max_tokens > 0 and (max_tokens & (max_tokens - 1)) == 0, (
"max_tokens must be a power of 2"
)
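# CUDA C++ source compiled at runtime (NVRTC via cuda.core): fillValues stamps every 16-byte slot
# of a token with (token, layer, buf_id, beam); checkValues re-reads the slots and traps on any
# mismatch.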
code = r"""
#if !defined(__CUDACC_RTC__)
#include <cassert>
#include <cstdio>
#endif
#ifdef NDEBUG
__device__ inline void check(bool condition) {
if (!condition) {
asm volatile("trap;" ::: "memory");
}
}
#else
#define check assert
#endif
using uint32_t = unsigned int;
using uint16_t = unsigned short;
constexpr uint32_t sleepTime = SLEEP_TIME_NS;
struct alignas(16) Value {
uint32_t token;
uint32_t layer;
uint32_t role;
uint32_t beam;
__device__ inline bool operator==(Value const& other) const {
return token == other.token && layer == other.layer && role == other.role && beam == other.beam;
}
__device__ inline bool operator!=(Value const& other) const {
return !(*this == other);
}
};
constexpr uint32_t kMAX_TOKENS = MAX_TOKENS;
struct Tokens {
uint32_t tokens[kMAX_TOKENS];
};
extern "C" __global__ void fillValues(Value* data, uint32_t valuesPerToken, uint32_t layer,
uint32_t buf_id, uint32_t beam, __grid_constant__ const Tokens tokens, uint32_t numTokens) {
if (sleepTime > 0) {
__nanosleep(sleepTime);
}
check(numTokens <= kMAX_TOKENS);
auto const tid = (static_cast<uint32_t>(blockIdx.x) * blockDim.x) + threadIdx.x;
auto const idxToken = tid / valuesPerToken;
if (idxToken >= numTokens) {
return;
}
auto const token = tokens.tokens[idxToken];
auto const value = Value{token, layer, buf_id, beam};
data[tid] = value;
}
__device__ inline void assertEq(Value const& a, Value const& b) {
#ifndef NDEBUG
if (a != b) {
printf("(%d, %d, %d, %d) != (%d, %d, %d, %d)\n",
a.token, a.layer, a.role, a.beam,
b.token, b.layer, b.role, b.beam);
}
#endif
check(a == b);
}
extern "C" __global__ void checkValues(Value const* data, uint32_t valuesPerToken, uint32_t layer,
uint32_t buf_id, uint32_t beam, __grid_constant__ const Tokens tokens, uint32_t numTokens) {
if (sleepTime > 0) {
__nanosleep(sleepTime);
}
check(numTokens <= kMAX_TOKENS);
auto const tid = (static_cast<uint32_t>(blockIdx.x) * blockDim.x) + threadIdx.x;
auto const idxToken = tid / valuesPerToken;
if (idxToken >= numTokens) {
return;
}
auto const token = tokens.tokens[idxToken];
auto const value = Value{token, layer, buf_id, beam};
assertEq(data[tid], value);
}
"""
macros = [("MAX_TOKENS", str(max_tokens)), ("SLEEP_TIME_NS", str(sleep_time))]
program_options = ProgramOptions(std="c++17", lineinfo=True, debug=debug, define_macro=macros) # type: ignore[arg-type]
if not debug:
program_options.use_fast_math = True
prog = Program(code, code_type="c++", options=program_options)
mod = prog.compile("cubin", name_expressions=("fillValues", "checkValues"))
return mod
def get_kernel(name: str, num_tokens: int, sleep_time: int) -> tuple[Kernel, int]:
assert num_tokens > 0
@lru_cache(maxsize=None)
def impl(name: str, max_tokens: int, sleep_time: int) -> Kernel:
assert name in ("fillValues", "checkValues")
assert max_tokens != 0 and (max_tokens & (max_tokens - 1)) == 0, (
"max_tokens must be a power of 2"
)
debug = False
# debug = not NDEBUG
return get_program(debug, max_tokens, sleep_time).get_kernel(name)
# Round up to the next power of two
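# so that only a handful of MAX_TOKENS kernel variants ever get compiled and cached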
max_tokens = 2 ** ((num_tokens - 1).bit_length())
return impl(name, max_tokens, sleep_time), max_tokens
class Value(ctypes.Structure):
_fields_ = [
("token", ctypes.c_uint32),
("layer", ctypes.c_uint32),
("buf_id", ctypes.c_uint32),
("beam", ctypes.c_uint32),
]
def __eq__(self, other: object) -> bool:
if not isinstance(other, Value):
return NotImplemented
return (
self.token == other.token
and self.layer == other.layer
and self.buf_id == other.buf_id
and self.beam == other.beam
)
def __str__(self) -> str:
return (
f"Value(token={self.token}, layer={self.layer}, buf_id={self.buf_id}, beam={self.beam})"
)
@lru_cache(maxsize=None)
def _get_ctypes_struct(max_tokens: int) -> type[ctypes.Structure]:
class Tokens(ctypes.Structure):
_fields_ = [
("tokens", ctypes.c_uint32 * max_tokens),
]
Tokens.__name__ = f"Tokens_{max_tokens}"
return Tokens
def _make_tokens(tokens: Sequence[TokenIdExt], max_tokens: int) -> ctypes.Structure:
assert len(tokens) <= max_tokens
padded = list(tokens) + [0] * (max_tokens - len(tokens))
Tokens = _get_ctypes_struct(max_tokens)
return Tokens(
tokens=(ctypes.c_uint32 * max_tokens)(
*[
t if isinstance(t, int) else int.from_bytes(t[:4], "little", signed=False)
for t in padded
]
)
)
def fill_values(
address: MemAddress,
bytes_per_token: int,
layer: LayerId,
buf_id: int,
beam: int,
tokens: Sequence[TokenIdExt],
stream: CudaStream,
):
values_per_token = exact_div(bytes_per_token, ctypes.sizeof(Value))
num_tokens = len(tokens)
if num_tokens == 0:
return
kernel, max_tokens = get_kernel("fillValues", len(tokens), _SLEEP_TIME_NS)
args = (
address,
values_per_token,
layer,
buf_id,
beam,
_make_tokens(tokens, max_tokens),
num_tokens,
)
arg_types = (
ctypes.c_void_p,
ctypes.c_uint32,
ctypes.c_uint32,
ctypes.c_uint32,
ctypes.c_uint32,
None,
ctypes.c_uint32,
)
num_threads = values_per_token * num_tokens
cta_size = 256
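# One thread per 16-byte value slot; the kernel is launched asynchronously on the caller's stream
# through the driver API.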
_unwrap(
drv.cuLaunchKernel(
kernel._handle,
div_up(num_threads, cta_size),
1,
1,
cta_size,
1,
1,
0,
stream,
(args, arg_types),
0,
)
)
def check_values(
address: MemAddress,
bytes_per_token: int,
layer: LayerId,
buf_id: int,
beam: int,
tokens: Sequence[TokenIdExt],
stream: CudaStream,
):
values_per_token = exact_div(bytes_per_token, ctypes.sizeof(Value))
num_tokens = len(tokens)
if num_tokens == 0:
return
kernel, max_tokens = get_kernel("checkValues", len(tokens), _SLEEP_TIME_NS)
args = (
address,
values_per_token,
layer,
buf_id,
beam,
_make_tokens(tokens, max_tokens),
num_tokens,
)
arg_types = (
ctypes.c_void_p,
ctypes.c_uint32,
ctypes.c_uint32,
ctypes.c_uint32,
ctypes.c_uint32,
None,
ctypes.c_uint32,
)
num_threads = values_per_token * num_tokens
cta_size = 256
_unwrap(
drv.cuLaunchKernel(
kernel._handle,
div_up(num_threads, cta_size),
1,
1,
cta_size,
1,
1,
0,
stream,
(args, arg_types),
0,
)
)
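# Debug helper: copies the device buffer back to the host and yields one representative Value per
# token, asserting that all slots of a token carry the same value.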
def debug_dump_tokens(
addr: MemAddress, token_bytes: int, num_tokens: int, stream: CudaStream
) -> Iterator[Value]:
if num_tokens == 0:
return
val_size = ctypes.sizeof(Value)
values_per_token = exact_div(token_bytes, val_size)
host_buf = (Value * values_per_token * num_tokens)()
ptr = ctypes.addressof(host_buf)
_unwrap(drv.cuMemcpyDtoHAsync(ptr, addr, num_tokens * token_bytes, stream))
_unwrap(drv.cuStreamSynchronize(stream))
for i in range(num_tokens):
token = host_buf[i]
value = Value.from_buffer_copy(token[0])
for j in range(1, values_per_token):
assert token[j] == token[0]
yield value

View File

@ -0,0 +1,711 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import array
import functools
import gc
import itertools
import os
import random
import time
import unittest
from contextlib import contextmanager
from importlib.util import find_spec
from random import randbytes
from statistics import median
from typing import TYPE_CHECKING, Iterator, NamedTuple, cast
if not TYPE_CHECKING and find_spec("kv_cache_manager_v2") is not None:
from kv_cache_manager_v2 import (
AttentionLayerConfig,
BeamIndex,
BufferConfig,
CacheLevel,
CudaStream,
DiskCacheTierConfig,
GpuCacheTierConfig,
HostCacheTierConfig,
KVCacheManager,
KVCacheManagerConfig,
LayerGroupId,
LayerId,
TokenId,
TokenIdExt,
_KVCache,
)
from kv_cache_manager_v2._block_radix_tree import traverse_post_order
from kv_cache_manager_v2._common import GPU_LEVEL, PageStatus, SlidingWindowSize
from kv_cache_manager_v2._exceptions import OutOfPagesError
from kv_cache_manager_v2._utils import (
TemporaryCudaStream,
div_up,
init_cuda_once,
remove_if,
round_up,
typed_range,
unwrap_rawref,
)
else:
from tensorrt_llm.runtime.kv_cache_manager_v2 import (
AttentionLayerConfig,
BeamIndex,
BufferConfig,
CacheLevel,
CudaStream,
DiskCacheTierConfig,
GpuCacheTierConfig,
HostCacheTierConfig,
KVCacheManager,
KVCacheManagerConfig,
LayerGroupId,
LayerId,
TokenId,
TokenIdExt,
_KVCache,
)
from tensorrt_llm.runtime.kv_cache_manager_v2._block_radix_tree import traverse_post_order
from tensorrt_llm.runtime.kv_cache_manager_v2._common import (
GPU_LEVEL,
PageStatus,
SlidingWindowSize,
)
from tensorrt_llm.runtime.kv_cache_manager_v2._exceptions import OutOfPagesError
from tensorrt_llm.runtime.kv_cache_manager_v2._utils import (
TemporaryCudaStream,
div_up,
init_cuda_once,
remove_if,
round_up,
typed_range,
unwrap_rawref,
)
from dynamic_path_manager import DynamicPathManager
from parameterized import parameterized
with DynamicPathManager(os.path.dirname(os.path.abspath(__file__)), clear_cache=False):
from fake_engine import FakeEngine, Role, Step
from kernels import enable_kernel_delay
seed = int.from_bytes(os.urandom(8), "little")
print(f"seed: {seed}")
random.seed(seed)
DBG_PRINT = int(os.environ.get("DBG_PRINT", "0")) != 0
PRINT_TIME = int(os.environ.get("PRINT_TIME", "0")) != 0
@contextmanager
def ref_cycle_check_context():
"""Context manager for reference cycle check."""
import gc
gc.collect()
gc.garbage.clear()
gc.set_debug(gc.DEBUG_SAVEALL | gc.DEBUG_COLLECTABLE)
def on_gc_event(phase, info):
# phase is "start" or "stop"
# info contains keys like: "generation", "collected", "uncollectable", "duration"
if phase == "stop":
collected = info.get("collected", 0)
uncollectable = info.get("uncollectable", 0)
if collected != 0 or uncollectable != 0:
import pdb
pdb.set_trace()
assert collected == 0 and uncollectable == 0
gc.callbacks.append(on_gc_event)
try:
yield
finally:
gc.collect()
gc.callbacks.pop()
gc.set_debug(0)
def assert_no_ref_cycle(func):
"""Decorator to wrap test methods with GC debugging context."""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
with ref_cycle_check_context():
result = func(self, *args, **kwargs)
return result
return wrapper
class TestKVCacheManagerV2(unittest.TestCase):
engine: FakeEngine
cfg: KVCacheManagerConfig
manager: KVCacheManager
_token_id_gen: Iterator[int]
def setUp(self) -> None:
init_cuda_once()
self._token_id_gen = itertools.count()
gc.collect()
gc.disable()
def tearDown(self) -> None:
gc.enable()
if hasattr(self, "manager"):
del self.manager
def next_token(self) -> TokenIdExt:
token_id = next(self._token_id_gen)
if token_id % 100 == 99:
return randbytes(32)
else:
return TokenId(token_id)
def prepare(
self,
gpu_quota: int,
host_quota: int,
disk_quota: int,
num_layers: int,
window_size: SlidingWindowSize,
sink_tokens: int,
tokens_per_block: int = 32,
kv_buf_size: int = 8192,
):
self._init_cfg(
tokens_per_block,
gpu_quota,
host_quota,
disk_quota,
num_layers,
window_size,
sink_tokens,
kv_buf_size,
)
self.engine = FakeEngine(self.cfg)
self.manager = KVCacheManager(self.cfg)
def _init_cfg(
self,
tokens_per_block: int,
gpu_quota: int,
host_quota: int,
disk_quota: int,
num_layers: int,
window_size: SlidingWindowSize,
sink_tokens: int,
kv_buf_size: int = 8192,
block_quant_buf_size: int | None = None,
):
layer_buffers = [
BufferConfig(role=Role.KEY, size=kv_buf_size),
BufferConfig(role=Role.VALUE, size=kv_buf_size),
]
if block_quant_buf_size is not None:
layer_buffers.extend(
[
BufferConfig(role=Role.KEY_BLOCK_QUANT, size=block_quant_buf_size),
BufferConfig(role=Role.VALUE_BLOCK_QUANT, size=block_quant_buf_size),
]
)
disk_path_candidates = ["/workspace/", "/tmp/nvidia-mps/", "/tmp"]
disk_path = next(p for p in disk_path_candidates if os.path.exists(p))
cache_tiers = [
GpuCacheTierConfig(quota=gpu_quota),
HostCacheTierConfig(quota=host_quota),
DiskCacheTierConfig(quota=disk_quota, path=disk_path),
]
self.cfg = KVCacheManagerConfig(
tokens_per_block=tokens_per_block,
vocab_size=4096,
cache_tiers=[t for t in cache_tiers if t.quota > 0],
layers=[
AttentionLayerConfig(
layer_id=layer_id,
buffers=layer_buffers,
sliding_window_size=window_size if layer_id % 2 == 0 else None,
num_sink_tokens=sink_tokens if layer_id % 2 == 0 else None,
)
for layer_id in typed_range(LayerId(num_layers))
],
)
class TestNoBatching(TestKVCacheManagerV2):
class Request(NamedTuple):
id: int
kv_cache: _KVCache
prompt: list[TokenIdExt]
decode_len: int
def new_request(
self, req_id: int, lora_task_id: int | None, prompt_len: int, decode_len: int
) -> Request:
prompt = [self.next_token() for _ in range(prompt_len)]
return self.Request(
req_id, self.manager.create_kv_cache(lora_task_id, prompt), prompt, decode_len
)
def run_request(self, req: Request, interval: int, refcheck: bool) -> float:
req_id, kv_cache, prompt, decode_len = req
assert kv_cache.status == _KVCache.Status.ACTIVE
stream = kv_cache.cuda_stream
tic = time.perf_counter()
# prefill
num_reused = kv_cache.num_committed_tokens
# workaround a mypyc bug: exception in property setter is not propagated
# kv_cache.capacity = round_up(len(prompt), interval)
if not kv_cache.resize(round_up(len(prompt), interval)):
raise OutOfPagesError("Not enough pages in GPU memory")
capacity = kv_cache.capacity
history = prompt[:num_reused]
input = prompt[num_reused:]
if refcheck:
self.engine.execute([Step(kv_cache, input, history)], stream)
if input:
kv_cache.commit(input)
history.extend(input)
# decode
for _ in range(decode_len):
required_capacity = len(history) + 1
if required_capacity > capacity:
kv_cache.commit(history[kv_cache.history_length :])
# workaround a mypyc bug: exception in property setter is not propagated
# kv_cache.capacity = round_up(required_capacity, interval)
if not kv_cache.resize(round_up(required_capacity, interval)):
raise OutOfPagesError("Not enough pages in GPU memory")
capacity = kv_cache.capacity
input_token = self.next_token()
if refcheck:
self.engine.execute([Step(kv_cache, [input_token], history)], stream)
history.append(input_token)
kv_cache.commit(history[kv_cache.history_length :])
# last check
if refcheck:
self.engine.execute([Step(kv_cache, [], history)], stream)
toc = time.perf_counter()
time_taken = toc - tic
# print(f"Time taken: {time_taken} seconds")
return time_taken
def run_naive(
self,
seq_len: int,
interval: int = 1,
refcheck: bool = True,
use_external_page_index_buf: bool = False,
) -> float:
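# Single-request scenario: prefill a one-token prompt, then decode token by token, growing the KV
# cache capacity in steps of `interval`.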
prompt_len = 1
decode_len = seq_len - prompt_len
req_id = 0
lora_task_id = None
req0 = self.new_request(req_id, lora_task_id, prompt_len, decode_len)
if use_external_page_index_buf:
max_num_blocks = div_up(seq_len, self.cfg.tokens_per_block)
num_layer_groups = len(self.manager.layer_grouping)
page_indices = [
array.array("i", [-1]) * max_num_blocks for _ in range(num_layer_groups)
]
for id in range(num_layer_groups):
req0.kv_cache.set_page_index_buf(
BeamIndex(0), LayerGroupId(id), memoryview(page_indices[id])
)
with TemporaryCudaStream([]) as s:
stream = cast(CudaStream, s.handle)
kv_cache = req0.kv_cache
success = kv_cache.resume(stream)
assert success
time_taken = self.run_request(req0, interval, refcheck)
s.take_finish_event().synchronize()
kv_cache.close()
self.manager.clear_reusable_blocks()
return time_taken
@parameterized.expand([(False,), (True,)])
def test_shrink_capacity(self, use_external_page_index_buf: bool) -> None:
self.prepare(32 << 20, 32 << 20, 1 << 30, 36, 128, 1, kv_buf_size=32768)
seq_len = 32 * 10
req0 = self.new_request(0, None, 32, seq_len - 32)
if use_external_page_index_buf:
max_num_blocks = div_up(seq_len, self.cfg.tokens_per_block)
num_layer_groups = len(self.manager.layer_grouping)
page_indices = [
array.array("i", [-1]) * max_num_blocks for _ in range(num_layer_groups)
]
for id in range(num_layer_groups):
req0.kv_cache.set_page_index_buf(
BeamIndex(0), LayerGroupId(id), memoryview(page_indices[id])
)
with TemporaryCudaStream([]) as s:
stream = cast(CudaStream, s.handle)
kv_cache = req0.kv_cache
success = kv_cache.resume(stream)
assert success
success = kv_cache.resize(seq_len)
assert success
for capacity in range(seq_len, len(req0.prompt), -1):
success = kv_cache.resize(capacity)
assert success
s.take_finish_event()
kv_cache.close()
def test_small_quota(self) -> None:
self.prepare(5619712, 0, 0, 8, None, 0)
assert self.manager.get_quota(cast(CacheLevel, GPU_LEVEL)) >= 5619712
# @assert_no_ref_cycle
def test_sol_mem_utilization(self) -> None:
self.prepare(32 << 20, 32 << 20, 1 << 30, 36, 128, 1, kv_buf_size=32768)
# if we have n blocks, we need 32768*2*18*(1+5+n) bytes of memory. For the (1+5+n), 1 is for sink
# blocks, 5 is for SWA (window=128), n is for full attention.
max_seq_len = 32 * 22 # 23 blocks will require more than 32MB memory
seq_len = max_seq_len
# create a request and suspend it; it should not consume any GPU memory once suspended.
req0 = self.new_request(0, None, 256, seq_len - 256)
with TemporaryCudaStream([]) as s:
stream = cast(CudaStream, s.handle)
success = req0.kv_cache.resume(stream)
assert success
self.run_request(req0, 32, False)
s.take_finish_event()
req0.kv_cache.suspend()
# run another request that will take all the GPU memory
req1 = self.new_request(0, None, 256, seq_len - 256)
with TemporaryCudaStream([]) as s:
stream = cast(CudaStream, s.handle)
success = req1.kv_cache.resume(stream)
assert success
self.run_request(req1, 1, True)
s.take_finish_event()
req1.kv_cache.close()
req0.kv_cache.close()
# run another longer request and expect OutOfPagesError
# This also tests eviction to disk.
self.assertRaises(OutOfPagesError, lambda: self.run_naive(seq_len + 1, 1, False))
@parameterized.expand([(1,), (2,), (4,)])
# @assert_no_ref_cycle
def test_cache_reuse(self, num_reusable_requests: int) -> None:
self.prepare(32 << 20, 32 << 20, 1 << 30, 36, 128, 1, kv_buf_size=32768)
# if we have n blocks, we need 32768*2*18*(1+5+n) bytes of memory. For the (1+5+n), 1 is for sink
# blocks, 5 is for SWA (window=128), n is for full attention.
max_seq_len = 32 * 22 # 23 blocks will require more than 32MB memory
seq_len = max_seq_len
req_id_gen = itertools.count()
reusable_requests = []
with TemporaryCudaStream([]) as s:
stream = cast(CudaStream, s.handle)
for _ in range(num_reusable_requests):
req = self.new_request(next(req_id_gen), None, 256, seq_len - 256)
reusable_requests.append(req)
success = req.kv_cache.resume(stream)
assert success
self.run_request(req, 32, True)
req.kv_cache.close()
s.take_finish_event()
for root_block in self.manager._radix_tree.next.values():
for block0 in root_block.next.values():
for block in traverse_post_order(block0):
for page in block.storage:
if page is not None:
assert unwrap_rawref(page).status == PageStatus.DROPPABLE
req0 = reusable_requests[0]
prompt1 = req0.kv_cache._committed_tokens[: (seq_len // 2 - 7)]
# the request id must be the same as req0's because we wrote it into the kv cache.
req1 = self.Request(
next(req_id_gen),
self.manager.create_kv_cache(None, prompt1),
prompt1,
seq_len - len(prompt1),
)
assert req1.kv_cache.num_committed_tokens == len(prompt1)
with TemporaryCudaStream([]) as s:
stream = cast(CudaStream, s.handle)
success = req1.kv_cache.resume(stream)
assert success
self.run_request(req1, 32, True)
s.take_finish_event()
req1.kv_cache.close()
self.manager.clear_reusable_blocks()
@parameterized.expand([(False,), (True,)])
# @assert_no_ref_cycle
def test_naive(self, use_external_page_index_buf: bool) -> None:
self.prepare(256 << 20, 256 << 20, 1 << 30, 36, 128, 48)
self.run_naive(512, 1, True, use_external_page_index_buf)
@parameterized.expand([(2**i, False) for i in range(12)])
# @parameterized.expand([(32, True)])
# @assert_no_ref_cycle
def test_naive_perf(self, interval, profile: bool) -> None:
if not PRINT_TIME:
self.skipTest("Skipping perf test")
self.prepare(256 << 20, 256 << 20, 1 << 30, 36, 128, 48)
seq_len = 10240
self.run_naive(seq_len, interval, False) # warm up for numba jit
profiler = None
if profile:
import cProfile
profiler = cProfile.Profile()
profiler.enable()
time_taken = [
self.run_naive(seq_len, interval, False) for _ in range(11 if profiler is None else 1)
]
median_time_taken = median(time_taken)
if PRINT_TIME:
print(
f"Throughput: {round(seq_len / median_time_taken)} tokens/sec for interval {interval}"
)
if profiler is not None:
profiler.disable()
profiler.print_stats(sort="cumtime")
profiler.dump_stats("profiler.prof")
class TestBatching(TestKVCacheManagerV2):
num_requests: int
avg_length: int
past_sequences: list[list[TokenIdExt]]
seq_len_dict: dict[_KVCache, int]
batch: list[Step]
suspended: list[Step]
num_created: int
num_finished: int
req_id_gen: Iterator[int]
acc_num_prompt_tokens: int
acc_num_decode_tokens: int
interval: int
enable_reuse: bool
def setUp(self) -> None:
super().setUp()
self.past_sequences = list[list[TokenIdExt]]()
self.seq_len_dict = dict[_KVCache, int]()
self.batch = list[Step]()
self.suspended = list[Step]()
self.num_finished = 0
self.num_created = 0
self.req_id_gen = itertools.count()
self.acc_num_prompt_tokens = 0
self.acc_num_decode_tokens = 0
self.enable_reuse = False
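# Builds the next request; when reuse is enabled, some prompts extend or share a prefix with
# previously finished sequences so that block-reuse paths get exercised.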
def gen_request(self) -> Step:
if self.num_created >= self.num_requests:
raise ValueError("Too many requests created")
token_id_gen = cast(Iterator[TokenId], self._token_id_gen)
def gen_length() -> int:
return random.randint(int(self.avg_length * 0.6), int(self.avg_length * 1.4))
if self.enable_reuse:
if len(self.past_sequences) >= 32 and random.random() < 0.2:
# continued multi-round dialog
prompt = random.choice(self.past_sequences) + [
next(token_id_gen) for _ in range(gen_length())
]
else:
# new dialog
if len(self.past_sequences) < 32 or random.random() < 0.5:
# completely new prompt
prompt = [next(token_id_gen) for _ in range(gen_length())]
else:
# with reused tokens
reused = random.choice(self.past_sequences)
prompt = reused[: random.randint(0, min(gen_length(), len(reused)))] + [
next(token_id_gen) for _ in range(gen_length())
]
else:
prompt = [next(token_id_gen) for _ in range(gen_length())]
decode_len = gen_length()
lora_task_id = None
kv_cache = self.manager.create_kv_cache(
lora_task_id, prompt[:-1] if self.enable_reuse else None, id=next(self.req_id_gen)
)
DBG_PRINT and print( # type: ignore[arg-type]
f"created {kv_cache.id} with {kv_cache.num_committed_tokens} tokens reused"
)
history = prompt[: kv_cache.num_committed_tokens]
input = prompt[kv_cache.num_committed_tokens :]
seq_len = len(prompt) + decode_len
self.seq_len_dict[kv_cache] = seq_len
self.num_created += 1
assert input
self.acc_num_prompt_tokens += len(prompt)
self.acc_num_decode_tokens += decode_len
return Step(kv_cache, input, history)
def update_batch(self, stream: CudaStream) -> None:
for s in self.batch:
assert s.input
if self.enable_reuse:
s.kv_cache.commit(s.input)
else:
s.kv_cache.history_length += len(s.input)
s.history.extend(s.input)
s.input.clear()
# remove finished requests first
removed = remove_if(
self.batch,
lambda step: len(step.history) >= self.seq_len_dict[step.kv_cache],
)
for kv_cache, _, _ in removed:
seq_len = self.seq_len_dict[kv_cache]
if seq_len < self.avg_length * 3:
self.past_sequences.append(kv_cache._committed_tokens[:seq_len])
kv_cache.close()
self.seq_len_dict.pop(kv_cache)
self.num_finished += 1
# fill input for remaining requests and increase capacity for them
token_id_gen = cast(Iterator[TokenId], self._token_id_gen)
for s in self.batch:
assert not s.input
length = min(self.interval, self.seq_len_dict[s.kv_cache] - len(s.history))
s.input.extend(next(token_id_gen) for _ in range(length))
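# Grow capacity for the remaining requests; when GPU pages run out, suspend requests from the
# tail of the batch until the resize succeeds.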
for i in itertools.count():
if i >= len(self.batch):
break
s = self.batch[i]
while i < len(self.batch) and not s.kv_cache.resize(
len(s.history) + len(s.input), None
):
last = self.batch.pop()
DBG_PRINT and print(f"suspending {last.kv_cache.id}") # type: ignore[arg-type]
last.kv_cache.suspend()
self.suspended.append(last)
# try to add new requests
suspended = self.suspended
while suspended or self.num_created < self.num_requests:
if not suspended:
assert self.num_created < self.num_requests
suspended.append(self.gen_request())
if suspended:
step = suspended[-1]
kv_cache = step.kv_cache
ok = kv_cache.resume(stream)
if (
ok
and not self.enable_reuse
and kv_cache._commit_state == _KVCache.CommitState.ALLOWED
):
kv_cache.stop_committing()
ok = ok and kv_cache.resize(len(step.history) + len(step.input), None)
if ok:
DBG_PRINT and print(f"activating {step.kv_cache.id}") # type: ignore[arg-type]
self.batch.append(suspended.pop())
else:
if kv_cache.status == _KVCache.Status.ACTIVE:
kv_cache.suspend()
break
DBG_PRINT and print( # type: ignore[arg-type]
f"update_batch: found {len(removed)} finished requests, now with {len(self.batch)} requests"
)
@parameterized.expand(
[
(1000, 1000, 1024, True, 32, 32),
(1000, 1000, 1024, True, 1, 32),
(10000, 1000, 1024, True, 32, 32),
(100, 100, 128, False, 1, 128),
(100, 100, 128, False, 4, 64),
]
)
# @assert_no_ref_cycle
def test_inflight_batching(
self,
num_requests: int,
avg_length: int,
gpu_quota_mb: int,
skip_execution: bool,
interval: int,
tokens_per_block: int,
):
self.prepare(
gpu_quota_mb << 20, 4 << 30, 0 << 30, 36, 128, 0, tokens_per_block=tokens_per_block
)
self.num_requests = num_requests
self.avg_length = avg_length
self.interval = interval
profile = False
profiler = None
if profile:
import cProfile
profiler = cProfile.Profile()
profiler.enable()
tic = time.perf_counter()
with TemporaryCudaStream([]) as s, enable_kernel_delay():
stream = cast(CudaStream, s.handle)
i = itertools.count()
self.update_batch(stream)
while self.num_finished < self.num_requests:
DBG_PRINT and print( # type: ignore[arg-type]
f"Executing batch {next(i)} with size {len(self.batch)}"
)
assert self.batch
if not skip_execution:
self.engine.execute(self.batch, stream)
self.update_batch(stream)
toc = time.perf_counter()
if profiler is not None:
profiler.disable()
profiler.print_stats(sort="cumtime")
profiler.dump_stats("profiler.prof")
if DBG_PRINT or PRINT_TIME:
print(
f"Time taken: {toc - tic} seconds (num_prompt_tokens: {self.acc_num_prompt_tokens}, "
f"num_decode_tokens: {self.acc_num_decode_tokens})"
)
s.take_finish_event().synchronize()
class TestDisagg(TestKVCacheManagerV2):
@parameterized.expand([512])
# @assert_no_ref_cycle
def test_disagg(self, prompt_len: int) -> None:
self.prepare(128 << 20, 128 << 20, 1 << 30, 36, 128, 0)
lora_task_id = None
prompt = [self.next_token() for _ in range(prompt_len)]
kv_cache = self.manager.create_kv_cache(lora_task_id, prompt)
assert kv_cache.num_committed_tokens == 0
with TemporaryCudaStream([]) as stream:
success = kv_cache.resume(cast(CudaStream, stream.handle))
assert success
success = kv_cache.resize(prompt_len, prompt_len)
assert success
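# Placeholder for the prefill-to-decode KV transfer that a real disaggregated deployment would
# perform at this point; the test only exercises the cache-manager bookkeeping around it.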
def transfer() -> None:
return None
transfer()
kv_cache.commit(prompt)
kv_cache.close()
stream.take_finish_event().synchronize()
if __name__ == "__main__":
unittest.main()