### :title KV Cache Connector
### :order 6
### :section Customization

import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory

import click
import torch

from tensorrt_llm import LLM, SamplingParams, logger
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs

# This is a simple example of the use of the KV cache connector.
# It persists KV cache contents into a folder, and can load them back on subsequent runs.
# See tensorrt_llm/_torch/pyexecutor/connector.py for details about the KV cache connector interface.
# NOTE: This example connector implementation is NOT suitable for production use.

CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"
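
# The connector is made up of two pieces:
#  - The scheduler (leader) decides which blocks should be loaded from / saved to the
#    cache folder, and packs those decisions into the metadata object below.
#  - The worker has access to the device KV cache pool and performs the actual copies
#    between device blocks and the .pt files referenced by the metadata.
# Each (path, block_id) pair maps one file in the cache folder to one device block.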


@dataclass
class PersistentKvCacheConnectorMetadata:
    load: list[tuple[str, int]] = field(default_factory=list)
    save: list[tuple[str, int]] = field(default_factory=list)


class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)

        self.kv_cache_tensor = None

    def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
        assert self.kv_cache_tensor is None, "KV cache tensor already registered"
        # Keep a reference to the device KV cache pool; it is indexed by block id
        # when loading and saving below.
        self.kv_cache_tensor = kv_cache_tensor

    def start_load_kv(self, stream: torch.cuda.Stream):
        # Do all loads synchronously, and blockwise.
        for path, block_id in self._metadata.load:
            cpu_tensor = torch.load(path, map_location="cpu")

            # Copy into the device block.
            self.kv_cache_tensor[block_id].copy_(cpu_tensor, non_blocking=False)

    def wait_for_layer_load(self, layer_idx: int, stream: torch.cuda.Stream):
        # No-op: loads are done synchronously in start_load_kv, so there is nothing
        # to wait for on a per-layer basis.
        pass

    def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
        # No-op: saves are deferred until wait_for_save, after the full forward pass.
        pass

    def wait_for_save(self, stream: torch.cuda.Stream):
        # Make sure the forward pass is complete before beginning our save.
        stream.synchronize()

        for path, block_id in self._metadata.save:
            cpu_tensor = self.kv_cache_tensor[block_id].cpu()

            # Don't write anything if this specific block already exists.
            if Path(path).exists():
                continue

            # Do a blocking save to the file. This way, we only return once all saves are complete.
            torch.save(cpu_tensor, path)

    def get_finished(
            self, finished_gen_req_ids: list[int],
            started_loading_req_ids: list[int]) -> tuple[list[int], list[int]]:
        # Nothing is saved or loaded asynchronously in this example, so no request
        # ids are ever reported as finished here.
        return [], []


class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)
        self.block_size = self._llm_args.kv_cache_config.tokens_per_block
        self.pending_loads = {}

        self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
                                           "./connector_cache")

        os.makedirs(self.cache_folder, exist_ok=True)

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        # NOTE: This is a simplified implementation, and does not work with chunked prefill.
        metadata = PersistentKvCacheConnectorMetadata()

        for req in scheduler_output.new_requests:
            # If we don't have any pending loads for this request, we can skip it.
            if req.request_id not in self.pending_loads:
                continue

            num_computed_blocks = req.computed_position // self.block_size
            block_ids = req.new_block_ids

            pending_load = self.pending_loads[req.request_id]

            for file_path, block_pos in zip(
                    pending_load, range(num_computed_blocks, len(block_ids))):
                metadata.load.append((file_path, block_ids[block_pos]))

            # Break up the remainder of the token sequence into chunks.
            chunks = self._chunk_tokens(req.new_tokens)

            # For each chunk that isn't already on device, and isn't in our connector cache, we need to save it.
            for block_pos in range(num_computed_blocks + len(pending_load),
                                   len(block_ids)):
                if len(chunks[block_pos]) == self.block_size:
                    hashed_tokens = self._hash_tokens(chunks[block_pos])
                    file_path = self._file_path(hashed_tokens)
                    metadata.save.append((file_path, block_ids[block_pos]))

        self.pending_loads = {}

        return metadata

    def _hash_tokens(self, tokens: list[int]) -> int:
        # Use Python's built-in hash of the token tuple as a simple content key for a block.
        return abs(hash(tuple(tokens)))

    def _file_path(self, hash_value: int) -> Path:
        return Path(self.cache_folder) / f"{hash_value}.pt"

    def _chunk_tokens(self, tokens: list[int]) -> list[list[int]]:
        return [
            tokens[i:i + self.block_size]
            for i in range(0, len(tokens), self.block_size)
        ]
    
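
    # For illustration: with block_size = 4, a token sequence [1, 2, 3, 4, 5, 6] is
    # chunked into [1, 2, 3, 4] and [5, 6]. Only the full first chunk is hashed and
    # persisted, to a file named "<cache_folder>/<abs(hash((1, 2, 3, 4)))>.pt".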

    def get_num_new_matched_tokens(
            self, request: LlmRequest,
            num_computed_tokens: int) -> tuple[int, bool]:
        self.pending_loads[request.request_id] = []

        # Don't bother with sequences with partial matches.
        if (num_computed_tokens % self.block_size) != 0:
            return 0, False

        computed_blocks = num_computed_tokens // self.block_size

        # Get all the tokens that don't have a cache hit on device.
        remaining_tokens = request.get_tokens(0)[computed_blocks *
                                                 self.block_size:]

        remaining_chunks = self._chunk_tokens(remaining_tokens)

        # For each chunk, check if it exists in our cache.
        for chunk in remaining_chunks:
            # Only do full blocks.
            if len(chunk) == self.block_size:
                hashed_tokens = self._hash_tokens(chunk)
                file_path = self._file_path(hashed_tokens)

                # If we get a cache hit, we want to load it into device.
                # Otherwise, we can stop looking.
                if file_path.exists():
                    self.pending_loads[request.request_id].append(file_path)
                else:
                    break

        logger.info(
            f"KV CONNECTOR: Matched {len(self.pending_loads[request.request_id])} blocks for request {request.request_id}"
        )

        # Everything in this example is loaded synchronously, so the second element
        # of the return value (the asynchronous-load flag) is always False.
        return len(
            self.pending_loads[request.request_id]) * self.block_size, False
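
    # NOTE: The file paths matched above are only remembered in self.pending_loads;
    # the concrete (file, device block) pairs are emitted later in build_connector_meta,
    # once block ids have been allocated for the request.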

    def request_finished(self, request: LlmRequest,
                         cache_block_ids: list[int]) -> bool:
        # We don't do any asynchronous saving, so always return False.
        return False

    def update_state_after_alloc(self, request: LlmRequest,
                                 block_ids: list[int]):
        # Nothing to track here; block ids are read from the scheduler output in
        # build_connector_meta instead.
        pass


@click.command()
@click.argument("model", type=str)
def main(model: str):
    # The connector classes are imported dynamically by module and class name
    # (see KvCacheConnectorConfig below), so extend the import path accordingly.
    sys.path.append(os.path.join(
        os.path.dirname(__file__),
        "..",
    ))

    # Derive this file's module name (its filename without the .py suffix).
    this_module = __file__[__file__.rfind("/") + 1:__file__.rfind(".py")]

    kv_connector_config = KvCacheConnectorConfig(
        connector_module=this_module,
        connector_scheduler_class="PersistentKvCacheConnectorLeader",
        connector_worker_class="PersistentKvCacheConnectorWorker",
    )

    # Persist the connector cache into a temporary directory for this run.
    connector_cache_dir = TemporaryDirectory()
    os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir.name

    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    test_text = (
        "Nvidia Corporation is an American technology company headquartered in Santa Clara, California. "
        "Founded in 1993 by Jensen Huang, Chris Malachowsky, and Curtis Priem, it develops graphics processing units (GPUs), "
        "systems on a chip (SoCs), and application programming interfaces (APIs) for data science, high-performance computing, "
        "and mobile and automotive applications. Tell me about the company.")

    sampling_params = SamplingParams(max_tokens=32)

    output = llm.generate([test_text], sampling_params)
    text0 = output[0].outputs[0].text
    print("First output: ", text0)

    print("Loading new LLM instance...")
    del llm

    llm = LLM(model=model,
              backend="pytorch",
              cuda_graph_config=None,
              kv_connector_config=kv_connector_config)

    output = llm.generate([test_text], sampling_params)
    text1 = output[0].outputs[0].text
    print("Second output (using connector cache): ", text1)

    # With the same prompt and sampling parameters, the second run should reproduce
    # the first output.
    assert text0 == text1

    connector_cache_dir.cleanup()


if __name__ == "__main__":
    main()
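

# To try this example, pass any model supported by the LLM API (a Hugging Face model id
# or a local checkpoint path). The model name below is only an illustration:
#
#   python llm_kv_cache_connector.py TinyLlama/TinyLlama-1.1B-Chat-v1.0
#
# On the second generation, the scheduler should find the prompt's full blocks in the
# connector cache folder and log a "KV CONNECTOR: Matched N blocks" message
# (visible depending on the configured log level).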