TensorRT-LLMs/tests/unittest/kv_cache_manager_v2_tests/kernels.py
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
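
"""Test helpers for KVCacheManagerV2: NVRTC-compiled CUDA kernels that stamp KV-cache
buffers with per-token sentinel values and later verify them.

``fill_values`` writes one 16-byte ``Value(token, layer, buf_id, beam)`` record into every
element of a token's cache slot; ``check_values`` traps (or asserts, in debug builds) on the
device if any element no longer matches.
"""
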
import contextlib
import ctypes
from collections.abc import Iterator, Sequence
from functools import lru_cache
from importlib.util import find_spec
from typing import TYPE_CHECKING
import cuda.bindings.driver as drv
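# Prefer the public cuda.core namespace when it exposes these classes; fall back to
# cuda.core.experimental on older cuda-python releases.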
try:
from cuda.core import Kernel, ObjectCode, Program, ProgramOptions
except ImportError:
from cuda.core.experimental import Kernel, Program, ProgramOptions
from cuda.core.experimental._module import ObjectCode
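# Import from a standalone ``kv_cache_manager_v2`` package when one is installed;
# otherwise use the copy shipped under ``tensorrt_llm.runtime``.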
if not TYPE_CHECKING and find_spec("kv_cache_manager_v2") is not None:
from kv_cache_manager_v2._common import CudaStream, LayerId, MemAddress, TokenIdExt
from kv_cache_manager_v2._utils import _unwrap, div_up, exact_div
else:
from tensorrt_llm.runtime.kv_cache_manager_v2._common import (
CudaStream,
LayerId,
MemAddress,
TokenIdExt,
)
from tensorrt_llm.runtime.kv_cache_manager_v2._utils import _unwrap, div_up, exact_div
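
# Device-side sleep (in nanoseconds) compiled into the kernels; 0 disables the delay.
# Set to a non-zero value by ``enable_kernel_delay``.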
_SLEEP_TIME_NS: int = 0
@contextlib.contextmanager
def enable_kernel_delay() -> Iterator[None]:
    """Make every kernel launched while active begin with a 30 µs device-side ``__nanosleep``."""
    global _SLEEP_TIME_NS
    _SLEEP_TIME_NS = 30_000
    try:
        yield
    finally:
        _SLEEP_TIME_NS = 0
@lru_cache(maxsize=None)
def get_program(debug: bool, max_tokens: int, sleep_time: int) -> ObjectCode:
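    """Compile ``fillValues``/``checkValues`` with NVRTC, caching the cubin per option set."""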
assert max_tokens > 0 and (max_tokens & (max_tokens - 1)) == 0, (
"max_tokens must be a power of 2"
)
code = r"""
#if !defined(__CUDACC_RTC__)
#include <cassert>
#include <cstdio>
#endif
#ifdef NDEBUG
__device__ inline void check(bool condition) {
if (!condition) {
asm volatile("trap;" ::: "memory");
}
}
#else
#define check assert
#endif
using uint32_t = unsigned int;
using uint16_t = unsigned short;
constexpr uint32_t sleepTime = SLEEP_TIME_NS;
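// 16-byte record stamped into every element of a token's KV-cache slot; checkValues later
// verifies that each element still carries the expected (token, layer, role, beam) tuple.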
struct alignas(16) Value {
uint32_t token;
uint32_t layer;
uint32_t role;
uint32_t beam;
__device__ inline bool operator==(Value const& other) const {
return token == other.token && layer == other.layer && role == other.role && beam == other.beam;
}
__device__ inline bool operator!=(Value const& other) const {
return !(*this == other);
}
};
constexpr uint32_t kMAX_TOKENS = MAX_TOKENS;
struct Tokens {
uint32_t tokens[kMAX_TOKENS];
};
extern "C" __global__ void fillValues(Value* data, uint32_t valuesPerToken, uint32_t layer,
uint32_t buf_id, uint32_t beam, __grid_constant__ const Tokens tokens, uint32_t numTokens) {
if (sleepTime > 0) {
__nanosleep(sleepTime);
}
check(numTokens <= kMAX_TOKENS);
auto const tid = (static_cast<uint32_t>(blockIdx.x) * blockDim.x) + threadIdx.x;
auto const idxToken = tid / valuesPerToken;
if (idxToken >= numTokens) {
return;
}
auto const token = tokens.tokens[idxToken];
auto const value = Value{token, layer, buf_id, beam};
data[tid] = value;
}
__device__ inline void assertEq(Value const& a, Value const& b) {
#ifndef NDEBUG
if (a != b) {
printf("(%d, %d, %d, %d) != (%d, %d, %d, %d)\n",
a.token, a.layer, a.role, a.beam,
b.token, b.layer, b.role, b.beam);
}
#endif
check(a == b);
}
extern "C" __global__ void checkValues(Value const* data, uint32_t valuesPerToken, uint32_t layer,
uint32_t buf_id, uint32_t beam, __grid_constant__ const Tokens tokens, uint32_t numTokens) {
if (sleepTime > 0) {
__nanosleep(sleepTime);
}
check(numTokens <= kMAX_TOKENS);
auto const tid = (static_cast<uint32_t>(blockIdx.x) * blockDim.x) + threadIdx.x;
auto const idxToken = tid / valuesPerToken;
if (idxToken >= numTokens) {
return;
}
auto const token = tokens.tokens[idxToken];
auto const value = Value{token, layer, buf_id, beam};
assertEq(data[tid], value);
}
"""
    macros = [("MAX_TOKENS", str(max_tokens)), ("SLEEP_TIME_NS", str(sleep_time))]
    # Enable fast math for release builds; set it at construction so it is reflected in the
    # option list that ProgramOptions builds.
    program_options = ProgramOptions(
        std="c++17", lineinfo=True, debug=debug, use_fast_math=not debug, define_macro=macros
    )  # type: ignore[arg-type]
prog = Program(code, code_type="c++", options=program_options)
mod = prog.compile("cubin", name_expressions=("fillValues", "checkValues"))
return mod
def get_kernel(name: str, num_tokens: int, sleep_time: int) -> tuple[Kernel, int]:
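    """Return the ``name`` kernel compiled for the smallest power-of-two ``max_tokens`` that
    is >= ``num_tokens``, together with that ``max_tokens``."""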
assert num_tokens > 0
@lru_cache(maxsize=None)
def impl(name: str, max_tokens: int, sleep_time: int) -> Kernel:
assert name in ("fillValues", "checkValues")
assert max_tokens != 0 and (max_tokens & (max_tokens - 1)) == 0, (
"max_tokens must be a power of 2"
)
        # Flip to True to compile with device-side asserts and full debug info.
        debug = False
return get_program(debug, max_tokens, sleep_time).get_kernel(name)
# Round up to the next power of two
max_tokens = 2 ** ((num_tokens - 1).bit_length())
return impl(name, max_tokens, sleep_time), max_tokens
class Value(ctypes.Structure):
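    """Host-side mirror of the device ``Value`` struct (the device field ``role`` maps to ``buf_id`` here)."""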
_fields_ = [
("token", ctypes.c_uint32),
("layer", ctypes.c_uint32),
("buf_id", ctypes.c_uint32),
("beam", ctypes.c_uint32),
]
def __eq__(self, other: object) -> bool:
if not isinstance(other, Value):
return NotImplemented
return (
self.token == other.token
and self.layer == other.layer
and self.buf_id == other.buf_id
and self.beam == other.beam
)
def __str__(self) -> str:
return (
f"Value(token={self.token}, layer={self.layer}, buf_id={self.buf_id}, beam={self.beam})"
)
@lru_cache(maxsize=None)
def _get_ctypes_struct(max_tokens: int) -> type[ctypes.Structure]:
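    """Build (and cache) a ctypes struct matching the device ``Tokens`` array of ``max_tokens`` uint32 entries."""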
class Tokens(ctypes.Structure):
_fields_ = [
("tokens", ctypes.c_uint32 * max_tokens),
]
Tokens.__name__ = f"Tokens_{max_tokens}"
return Tokens
def _make_tokens(tokens: Sequence[TokenIdExt], max_tokens: int) -> ctypes.Structure:
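    """Pad ``tokens`` to ``max_tokens`` entries and pack them into the ctypes ``Tokens`` struct.

    Byte-string token IDs are reduced to their first four bytes, read as a little-endian uint32.
    """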
assert len(tokens) <= max_tokens
padded = list(tokens) + [0] * (max_tokens - len(tokens))
Tokens = _get_ctypes_struct(max_tokens)
return Tokens(
tokens=(ctypes.c_uint32 * max_tokens)(
*[
t if isinstance(t, int) else int.from_bytes(t[:4], "little", signed=False)
for t in padded
]
)
)
def fill_values(
address: MemAddress,
bytes_per_token: int,
layer: LayerId,
buf_id: int,
beam: int,
tokens: Sequence[TokenIdExt],
stream: CudaStream,
):
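    """Launch ``fillValues`` on ``stream`` to stamp the buffer at ``address`` with one ``Value`` record per 16 bytes."""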
values_per_token = exact_div(bytes_per_token, ctypes.sizeof(Value))
num_tokens = len(tokens)
if num_tokens == 0:
return
kernel, max_tokens = get_kernel("fillValues", len(tokens), _SLEEP_TIME_NS)
args = (
address,
values_per_token,
layer,
buf_id,
beam,
_make_tokens(tokens, max_tokens),
num_tokens,
)
arg_types = (
ctypes.c_void_p,
ctypes.c_uint32,
ctypes.c_uint32,
ctypes.c_uint32,
ctypes.c_uint32,
None,
ctypes.c_uint32,
)
num_threads = values_per_token * num_tokens
cta_size = 256
_unwrap(
drv.cuLaunchKernel(
kernel._handle,
div_up(num_threads, cta_size),
1,
1,
cta_size,
1,
1,
0,
stream,
(args, arg_types),
0,
)
)
def check_values(
address: MemAddress,
bytes_per_token: int,
layer: LayerId,
buf_id: int,
beam: int,
tokens: Sequence[TokenIdExt],
stream: CudaStream,
):
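    """Launch ``checkValues`` on ``stream``; the kernel traps (asserts in debug builds) if the
    buffer at ``address`` no longer matches the pattern written by ``fill_values``."""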
values_per_token = exact_div(bytes_per_token, ctypes.sizeof(Value))
num_tokens = len(tokens)
if num_tokens == 0:
return
kernel, max_tokens = get_kernel("checkValues", len(tokens), _SLEEP_TIME_NS)
args = (
address,
values_per_token,
layer,
buf_id,
beam,
_make_tokens(tokens, max_tokens),
num_tokens,
)
arg_types = (
ctypes.c_void_p,
ctypes.c_uint32,
ctypes.c_uint32,
ctypes.c_uint32,
ctypes.c_uint32,
None,
ctypes.c_uint32,
)
num_threads = values_per_token * num_tokens
cta_size = 256
_unwrap(
drv.cuLaunchKernel(
kernel._handle,
div_up(num_threads, cta_size),
1,
1,
cta_size,
1,
1,
0,
stream,
(args, arg_types),
0,
)
)
def debug_dump_tokens(
addr: MemAddress, token_bytes: int, num_tokens: int, stream: CudaStream
) -> Iterator[Value]:
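    """Copy ``num_tokens`` tokens back to the host (synchronizing ``stream``) and yield one
    ``Value`` per token, asserting that all values within a token agree."""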
if num_tokens == 0:
return
val_size = ctypes.sizeof(Value)
values_per_token = exact_div(token_bytes, val_size)
host_buf = (Value * values_per_token * num_tokens)()
ptr = ctypes.addressof(host_buf)
_unwrap(drv.cuMemcpyDtoHAsync(ptr, addr, num_tokens * token_bytes, stream))
_unwrap(drv.cuStreamSynchronize(stream))
for i in range(num_tokens):
token = host_buf[i]
value = Value.from_buffer_copy(token[0])
for j in range(1, values_per_token):
assert token[j] == token[0]
yield value
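
# Illustrative round-trip, e.g. from a KVCacheManagerV2 unit test. The names ``buf_addr``,
# ``bytes_per_token``, ``layer``, ``buf_id``, ``beam``, ``tokens`` and ``stream`` are
# placeholders supplied by the test, not part of this module:
#
#     with enable_kernel_delay():
#         fill_values(buf_addr, bytes_per_token, layer, buf_id, beam, tokens, stream)
#         check_values(buf_addr, bytes_per_token, layer, buf_id, beam, tokens, stream)
#     for value in debug_dump_tokens(buf_addr, bytes_per_token, len(tokens), stream):
#         print(value)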