# TensorRT-LLM/tests/unittest/llmapi/test_llm_kv_cache_events.py

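"""Unit tests for KV cache event reporting in the LLM API.

Covers serialization and deserialization of created/stored/removed events,
the per-block ``mm_keys`` field that exposes multimodal content hashes to
external consumers (e.g. the dynamo integration), the async event API, and
per-rank events under attention data parallelism.
"""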

import asyncio
import time

import pytest
from utils.util import skip_single_gpu

import tensorrt_llm
from tensorrt_llm import LLM
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._utils import KVCacheEventSerializer
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.sampling_params import SamplingParams
from tensorrt_llm.scheduling_params import SchedulingParams

from .test_llm import get_model_path

default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
llama_model_path = get_model_path(default_model_name)

global_kvcache_config = KvCacheConfig(free_gpu_memory_fraction=0.4,
                                      event_buffer_max_size=1024,
                                      enable_block_reuse=True,
                                      onboard_blocks=True,
                                      max_tokens=256)


def create_kv_cache_manager():
    num_layers = 2
    num_kv_heads = 2
    head_dim = 128
    tokens_per_block = 64
    max_seq_len = 1024
    max_batch_size = 1
    mapping = Mapping()

    return KVCacheManager(
        kv_cache_config=global_kvcache_config,
        kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.CacheType.
        SELF,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        mapping=mapping,
    )


def create_llm(tensor_parallel_size=1):
    return LLM(model=llama_model_path,
               tensor_parallel_size=tensor_parallel_size,
               kv_cache_config=global_kvcache_config,
               enable_autotuner=False)


def create_llm_request(id, input_tokens, new_tokens=1):
    sampling_params = SamplingParams()
    req = LlmRequest(request_id=id,
                     max_new_tokens=new_tokens,
                     input_tokens=input_tokens,
                     sampling_config=tensorrt_llm.bindings.SamplingConfig(
                         sampling_params._get_sampling_config()),
                     is_streaming=False)
    return req


def flush_events(kv_cache_manager):
    # Flush the current iteration's events and give the event queue a moment
    # to drain before the test reads them back.
    kv_cache_manager.flush_iteration_events()
    time.sleep(0.001)


def test_kv_cache_event_data_serialization():
    kv_cache_manager = create_kv_cache_manager()
    flush_events(kv_cache_manager)
    events = kv_cache_manager.get_latest_events(10)
    serialized_event = KVCacheEventSerializer.serialize(events)

    assert len(serialized_event) == 1 and serialized_event[0][
        "event_id"] == 0 and serialized_event[0]["window_size"] == 256
    assert serialized_event[0]["data"]["type"] == "created"
    assert len(serialized_event[0]["data"]["num_blocks_per_cache_level"]) == 2

    req = create_llm_request(0, [1, 2, 3, 4, 5])
    kv_cache_manager.impl.add_sequence(req.py_request_id, req.prompt_len, 1,
                                       req)
    kv_cache_manager.free_resources(req)

    flush_events(kv_cache_manager)
    events = kv_cache_manager.get_latest_events(10)
    serialized_event = KVCacheEventSerializer.serialize(events)

    assert serialized_event[0]["data"]["type"] == "stored"
    assert serialized_event[0]["data"]["parent_hash"] is None
    assert len(serialized_event[0]["data"]["blocks"]) == 1
    assert len(serialized_event[0]["data"]["blocks"][0]["tokens"]) == 4
    # Verify mm_keys field exists (empty for text-only requests)
    assert "mm_keys" in serialized_event[0]["data"]["blocks"][0]
    assert serialized_event[0]["data"]["blocks"][0]["mm_keys"] == []

    # A second, identical request: serialize whatever events it produces to
    # make sure serialization does not raise.
    req2 = create_llm_request(1, [1, 2, 3, 4, 5])
    kv_cache_manager.impl.add_sequence(req2.py_request_id, req2.prompt_len, 1,
                                       req2)
    kv_cache_manager.free_resources(req2)

    flush_events(kv_cache_manager)
    events = kv_cache_manager.get_latest_events(10)
    serialized_event = KVCacheEventSerializer.serialize(events)


def test_mm_keys_serialization():
    """Test serialization of multimodal keys (mm_keys) in KV cache events."""
    # Test _mm_key_to_json with a mock mm_key tuple (bytes, int).
    # MmKey from C++ is converted to a (bytes, int) tuple by pybind11.
    mock_hash = b'\x01\x02\x03\x04\x05\x06\x07\x08' + b'\x00' * 24  # 32 bytes
    mock_offset = 42
    mock_mm_key = (mock_hash, mock_offset)

    result = KVCacheEventSerializer._mm_key_to_json(mock_mm_key)
    assert result["type"] == "mm_key"
    assert result["start_offset"] == 42
    # Hash should be converted to a hex string.
    assert result["hash"] == "0102030405060708" + "00" * 24
    assert len(result["hash"]) == 64  # 32 bytes = 64 hex chars

    # Test with different hash values.
    mock_hash2 = bytes(range(32))  # 0x00 to 0x1f
    mock_mm_key2 = (mock_hash2, 100)
    result2 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key2)
    assert result2["type"] == "mm_key"
    assert result2["start_offset"] == 100
    expected_hash = ''.join(f'{i:02x}' for i in range(32))
    assert result2["hash"] == expected_hash


def test_mm_keys_deserialization():
    """Test deserialization of mm_keys JSON back to the 32-byte hash."""
    # Test case 1: Simple hash pattern
    mock_hash = b'\x01\x02\x03\x04\x05\x06\x07\x08' + b'\x00' * 24  # 32 bytes
    mock_offset = 42
    mock_mm_key = (mock_hash, mock_offset)

    # Serialize to JSON
    json_result = KVCacheEventSerializer._mm_key_to_json(mock_mm_key)

    # Deserialize the hex string back to bytes
    recovered_hash = bytes.fromhex(json_result["hash"])

    # Verify the recovered hash matches the original
    assert recovered_hash == mock_hash
    assert len(recovered_hash) == 32
    assert json_result["start_offset"] == mock_offset

    # Test case 2: Sequential bytes 0x00 to 0x1f
    mock_hash2 = bytes(range(32))
    mock_offset2 = 100
    mock_mm_key2 = (mock_hash2, mock_offset2)

    json_result2 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key2)
    recovered_hash2 = bytes.fromhex(json_result2["hash"])

    assert recovered_hash2 == mock_hash2
    assert len(recovered_hash2) == 32
    assert json_result2["start_offset"] == mock_offset2

    # Test case 3: All 0xFF bytes
    mock_hash3 = b'\xff' * 32
    mock_offset3 = 255
    mock_mm_key3 = (mock_hash3, mock_offset3)

    json_result3 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key3)
    recovered_hash3 = bytes.fromhex(json_result3["hash"])

    assert recovered_hash3 == mock_hash3
    assert len(recovered_hash3) == 32
    assert json_result3["hash"] == "ff" * 32

    # Test case 4: Random-like pattern
    mock_hash4 = bytes([0xde, 0xad, 0xbe, 0xef] + [0xca, 0xfe] * 14)
    mock_offset4 = 1024
    mock_mm_key4 = (mock_hash4, mock_offset4)

    json_result4 = KVCacheEventSerializer._mm_key_to_json(mock_mm_key4)
    recovered_hash4 = bytes.fromhex(json_result4["hash"])

    assert recovered_hash4 == mock_hash4
    assert len(recovered_hash4) == 32
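

# The round-trip above mirrors what a downstream consumer of the serialized
# events (for example, the dynamo integration this feature targets) would do.
# The helper below is an illustrative sketch only -- it is not part of
# KVCacheEventSerializer or of this test suite's API -- showing how the
# "hash"/"start_offset" fields asserted above map back to the original
# (bytes, int) mm_key tuple.
def _example_recover_mm_keys(block):
    """Recover (hash_bytes, start_offset) tuples from a serialized block dict.

    Illustrative only; the field names come from the assertions in the tests
    above.
    """
    return [(bytes.fromhex(k["hash"]), k["start_offset"])
            for k in block.get("mm_keys", [])]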


def test_mm_keys_in_stored_events():
    """Test that the mm_keys field is present in stored block events."""
    llm = create_llm()
    sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
    prompt = "Hello, my name is"

    _ = llm.generate(prompt, sampling_params=sampling_params)
    events = llm.get_kv_cache_events(5)

    # Find stored events and verify the mm_keys field.
    for event in events:
        if event and event["data"]["type"] == "stored":
            blocks = event["data"]["blocks"]
            for block in blocks:
                # mm_keys should always be present (empty list for text-only)
                assert "mm_keys" in block
                assert isinstance(block["mm_keys"], list)
                # For text-only requests, mm_keys should be empty
                assert block["mm_keys"] == []


def test_expected_kv_cache_events():
    llm = create_llm()
    sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
    prompt = "Hello, my name is"

    _ = llm.generate(prompt, sampling_params=sampling_params)
    events = llm.get_kv_cache_events(5)

    # Expect at least a created event followed by a stored event.
    assert events and len(events) >= 2
    for event in events:
        if event:
            if event["event_id"] == 0:
                assert event["data"]["type"] == "created"
            elif event["event_id"] == 1:
                assert event["data"]["type"] == "stored"


def test_kv_cache_event_async_api():
    llm = create_llm()
    sampling_params = SamplingParams(max_tokens=6, temperature=0.01)
    prompt = "Hello, my name is"

    async def generate():
        async for output in llm.generate_async(
                prompt, streaming=True, sampling_params=sampling_params):
            pass

    events = []

    async def get_events():
        async for event in llm.get_kv_cache_events_async():
            events.append(event)
        assert events

    async def main():
        await generate()
        await asyncio.gather(generate(), get_events())
        await asyncio.gather(generate(), get_events())

    asyncio.run(main())


def check_events(llm,
                 requests,
                 sampling_params,
                 scheduling_params=None,
                 attention_dp_rank=None):
    """Run three requests and validate the created/stored/removed events."""
    _ = llm.generate(requests[0],
                     sampling_params=sampling_params,
                     scheduling_params=scheduling_params)
    time.sleep(1)
    events = llm.get_kv_cache_events(5)

    # Created or stored events
    total_stored_blocks = 0
    if attention_dp_rank is None:
        event = events.pop(0)  # created event
        assert event["event_id"] == 0
        assert event["data"]["type"] == "created"
        while events:
            event = events.pop(0)
            if event:
                assert event["data"]["type"] == "stored"
                assert event["event_id"] > 0
                total_stored_blocks += len(event["data"]["blocks"])
    else:
        while events:
            event = events.pop(0)
            if not event:
                continue
            assert "attention_dp_rank" in event
            if event["attention_dp_rank"] == attention_dp_rank:
                assert event["data"]["type"] in ["created", "stored"]
                if event["data"]["type"] == "created":
                    assert event["event_id"] == 0
                if event["data"]["type"] == "stored":
                    assert event["event_id"] > 0
                    total_stored_blocks += len(event["data"]["blocks"])
    assert total_stored_blocks == 5  # Should have 5 blocks in total

    _ = llm.generate(requests[1],
                     sampling_params=sampling_params,
                     scheduling_params=scheduling_params)
    time.sleep(1)
    events2 = llm.get_kv_cache_events(5)

    total_stored_blocks = 0
    has_removed_event = False
    while events2:
        event = events2.pop(0)
        if event and (attention_dp_rank is None
                      or event.get("attention_dp_rank") == attention_dp_rank):
            if event["data"]["type"] == "removed":
                has_removed_event = True
                assert event["data"]["block_hashes"]
            # stored events
            elif event["data"]["type"] == "stored":
                total_stored_blocks += len(event["data"]["blocks"])
    assert total_stored_blocks == 5  # Should have 5 blocks in total
    assert has_removed_event

    _ = llm.generate(requests[2],
                     sampling_params=sampling_params,
                     scheduling_params=scheduling_params)
    time.sleep(1)
    events3 = llm.get_kv_cache_events(5)

    total_stored_blocks = 0
    has_removed_event = False
    while events3:
        event = events3.pop(0)
        if event and (attention_dp_rank is None
                      or event.get("attention_dp_rank") == attention_dp_rank):
            if event["data"]["type"] == "removed":
                has_removed_event = True
                assert event["data"]["block_hashes"]
            elif event["data"]["type"] == "stored":
                total_stored_blocks += len(event["data"]["blocks"])
    assert total_stored_blocks == 5  # Should have 5 blocks in total
    assert has_removed_event

    # No more events after the request is finished.
    assert not llm.get_kv_cache_events(5)


def test_llm_kv_events_api():
    llm = create_llm()
    sampling_params = SamplingParams(max_tokens=6,
                                     temperature=0.01,
                                     ignore_eos=True)
    requests = []
    for i in range(3):
        input_tokens = list(range(127 + i))[i:]
        requests.append(input_tokens)

    check_events(llm, requests, sampling_params)


@skip_single_gpu
@pytest.mark.threadleak(enabled=False)
def test_llm_api_attention_dp_kv_events():
    llm = LLM(model=llama_model_path,
              tensor_parallel_size=2,
              enable_attention_dp=True,
              kv_cache_config=global_kvcache_config,
              enable_autotuner=False)
    sampling_params = SamplingParams(max_tokens=6,
                                     temperature=0.01,
                                     ignore_eos=True)
    for attention_dp_rank in range(2):
        requests = []
        for i in range(3):
            input_tokens = list(range(127 + i))[i:]
            requests.append(input_tokens)
        scheduling_params = SchedulingParams(
            attention_dp_rank=attention_dp_rank, attention_dp_relax=False)
        check_events(llm, requests, sampling_params, scheduling_params,
                     attention_dp_rank)