# tests/unittest/disaggregated/test_extractor.py
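"""Unit test for KVRegionExtractorV1.

Builds a small single-layer KVCacheManager and checks that extracting a set of
block (region) ids yields a SpecRegion whose memory pointers correspond to
block-sized offsets into the manager's primary pool.
"""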
import pytest

from tensorrt_llm._torch.disaggregation.base.region import MemRegionGroup, SpecRegion
from tensorrt_llm._torch.disaggregation.resource.kv_extractor import KVRegionExtractorV1
from tensorrt_llm._torch.pyexecutor.resource_manager import (
    CacheTypeCpp,
    DataType,
    KvCacheConfig,
    KVCacheManager,
    Mapping,
)
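

# Minimal stand-in for the per-rank topology and KV-cache geometry info
# consumed by the disaggregation components.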
class DummyRankInfo:
    instance_name = "dummy"
    instance_rank = 0
    tp_size = 1
    tp_rank = 0
    pp_size = 1
    pp_rank = 0
    dp_size = 1
    dp_rank = 0
    cp_size = 1
    cp_rank = 0
    device_id = 0
    kv_heads_per_rank = 8
    tokens_per_block = 32
    dims_per_head = 16
    element_bytes = 2
    enable_attention_dp = False
    is_mla = False
    layer_num_per_pp = [1]

    @property
    def kv_factor(self) -> int:
        # Standard attention stores separate K and V tensors per block; MLA keeps
        # a single compressed latent, so only one factor is needed.
        return 1 if self.is_mla else 2


@pytest.mark.cuda
def test_extract():
    num_layers = 1
    num_kv_heads = 8
    head_dim = 16
    tokens_per_block = 32
    max_seq_len = 128
    max_batch_size = 2
    dtype = DataType.HALF

    mapping = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1, gpus_per_node=1)
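
    # Cache config for the test: block reuse, partial reuse, and host offloading
    # are all disabled.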
    kv_cache_config = KvCacheConfig(
        max_tokens=512,
        free_gpu_memory_fraction=0.1,
        max_attention_window=None,
        enable_block_reuse=False,
        event_buffer_max_size=0,
        onboard_blocks=0,
        host_cache_size=0,
        enable_partial_reuse=False,
        copy_on_partial_reuse=False,
        sink_token_length=0,
        max_util_for_resume=1,
    )
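
    # Build a single-layer KV cache manager; its primary pool backs the regions
    # that the extractor reports.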
    kv_cache_type = CacheTypeCpp.SELF
    manager = KVCacheManager(
        kv_cache_config=kv_cache_config,
        kv_cache_type=kv_cache_type,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        mapping=mapping,
        dtype=dtype,
    )
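
    # Extract memory descriptors for two block ids and validate their structure.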
    extractor = KVRegionExtractorV1(manager)
    region_ids = [0, 1]
    spec_region = extractor.extract(region_ids)

    assert isinstance(spec_region, SpecRegion)
    memory = spec_region.memory
    assert isinstance(memory, MemRegionGroup)
    assert len(memory.ptrs) == len(region_ids)
    assert memory.bytes_per_region > 0
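
    # get_unique_primary_pool() may return a tensor, a raw address, or a sequence
    # of either; normalize it to an integer base pointer before comparing.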
    pool_ptrs = manager.get_unique_primary_pool()
    if hasattr(pool_ptrs, "__getitem__"):
        if hasattr(pool_ptrs[0], "data_ptr"):
            pool_base_ptr = int(pool_ptrs[0].data_ptr())
        else:
            pool_base_ptr = int(pool_ptrs[0])
    else:
        pool_base_ptr = (
            int(pool_ptrs.data_ptr()) if hasattr(pool_ptrs, "data_ptr") else int(pool_ptrs)
        )
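
    # Each extracted region should start at pool_base_ptr + block_id * bytes_per_region.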
    expected_block_bytes = memory.bytes_per_region
    expected_ptrs = [pool_base_ptr + block_id * expected_block_bytes for block_id in region_ids]
    assert list(memory.ptrs) == expected_ptrs

    manager.shutdown()