import pytest

from tensorrt_llm._torch.disaggregation.base.region import MemRegionGroup, SpecRegion
from tensorrt_llm._torch.disaggregation.resource.kv_extractor import KVRegionExtractorV1
from tensorrt_llm._torch.pyexecutor.resource_manager import (
    CacheTypeCpp,
    DataType,
    KvCacheConfig,
    KVCacheManager,
    Mapping,
)


class DummyRankInfo:
    """Minimal rank descriptor for a single-rank, non-MLA KV-cache layout."""

    instance_name = "dummy"
    instance_rank = 0
    tp_size = 1
    tp_rank = 0
    pp_size = 1
    pp_rank = 0
    dp_size = 1
    dp_rank = 0
    cp_size = 1
    cp_rank = 0
    device_id = 0
    kv_heads_per_rank = 8
    tokens_per_block = 32
    dims_per_head = 16
    element_bytes = 2
    enable_attention_dp = False
    is_mla = False
    layer_num_per_pp = [1]

    @property
    def kv_factor(self) -> int:
        # Standard attention caches both K and V (factor 2); MLA keeps a
        # single compressed latent cache (factor 1).
        return 2 if not self.is_mla else 1
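

# Illustrative helper (our own sketch, not part of the TensorRT-LLM API):
# assuming the conventional pool layout of
# [kv_factor, kv_heads, tokens_per_block, head_dim] elements per layer, a rank
# described by DummyRankInfo would have the per-block byte footprint computed
# below. For the geometry above (which test_extract mirrors) that is
# 2 * 8 * 32 * 16 * 2 = 16384 bytes.
def _expected_block_bytes(info: DummyRankInfo) -> int:
    layers = info.layer_num_per_pp[info.pp_rank]  # layers resident on this PP rank
    elems = info.kv_factor * info.kv_heads_per_rank * info.tokens_per_block * info.dims_per_head
    return layers * elems * info.element_bytes
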
@pytest.mark.cuda
def test_extract():
    # Geometry for a tiny single-layer cache; values mirror DummyRankInfo.
    num_layers = 1
    num_kv_heads = 8
    head_dim = 16
    tokens_per_block = 32
    max_seq_len = 128
    max_batch_size = 2
    dtype = DataType.HALF

    mapping = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1, gpus_per_node=1)
    kv_cache_config = KvCacheConfig(
        max_tokens=512,
        free_gpu_memory_fraction=0.1,
        max_attention_window=None,
        enable_block_reuse=False,
        event_buffer_max_size=0,
        onboard_blocks=0,
        host_cache_size=0,
        enable_partial_reuse=False,
        copy_on_partial_reuse=False,
        sink_token_length=0,
        max_util_for_resume=1,
    )
    kv_cache_type = CacheTypeCpp.SELF

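    # Capacity note (our reading of the config, not asserted by the test): with
    # max_tokens=512 and tokens_per_block=32 the manager holds at most
    # 512 // 32 = 16 primary blocks, assuming max_tokens rather than
    # free_gpu_memory_fraction is the binding limit here.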
    manager = KVCacheManager(
        kv_cache_config=kv_cache_config,
        kv_cache_type=kv_cache_type,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        mapping=mapping,
        dtype=dtype,
    )

    # Extract region descriptors for two cache blocks and sanity-check them.
    extractor = KVRegionExtractorV1(manager)
    region_ids = [0, 1]
    spec_region = extractor.extract(region_ids)

    assert isinstance(spec_region, SpecRegion)
    memory = spec_region.memory
    assert isinstance(memory, MemRegionGroup)
    assert len(memory.ptrs) == len(region_ids)
    assert memory.bytes_per_region > 0

    # The extracted pointers must address blocks inside the manager's primary
    # pool. get_unique_primary_pool() may return a tensor, a raw pointer, or a
    # sequence of either, so duck-type our way to the base address.
    pool_ptrs = manager.get_unique_primary_pool()
    if hasattr(pool_ptrs, "__getitem__"):
        first = pool_ptrs[0]
        pool_base_ptr = int(first.data_ptr()) if hasattr(first, "data_ptr") else int(first)
    else:
        pool_base_ptr = (
            int(pool_ptrs.data_ptr()) if hasattr(pool_ptrs, "data_ptr") else int(pool_ptrs)
        )

    # Blocks are laid out contiguously, so block i starts at
    # base + i * bytes_per_region.
    expected_block_bytes = memory.bytes_per_region
    expected_ptrs = [pool_base_ptr + block_id * expected_block_bytes for block_id in region_ids]
    assert list(memory.ptrs) == expected_ptrs

    manager.shutdown()
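

if __name__ == "__main__":
    # Convenience entry point (ours, not upstream): run this file directly
    # without the pytest CLI. The cuda-marked test still needs a CUDA device.
    pytest.main([__file__, "-v"])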