From b1268e1b37ae20b6119242617558c7dea1fa5830 Mon Sep 17 00:00:00 2001 From: Shi Xiaowei <39303645+Shixiaowei02@users.noreply.github.com> Date: Fri, 6 Feb 2026 20:15:18 +0800 Subject: [PATCH] [TRTLLM-9527][feat] Modularization of the transceiver for KV manager v2 (step 4) (#11225) Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> --- .../_torch/disaggregation/base/region.py | 124 ++++++++ .../_torch/disaggregation/native/peer.py | 286 +++++++++++++++++ .../_torch/disaggregation/native/rank_info.py | 75 +++++ .../disaggregation/native/region/__init__.py | 0 .../disaggregation/native/region/aux.py | 183 +++++++++++ .../disaggregation/native/region/block.py | 216 +++++++++++++ .../disaggregation/native/region/page.py | 126 ++++++++ .../disaggregation/resource/__init__.py | 0 .../disaggregation/resource/kv_extractor.py | 89 ++++++ .../resource/kv_extractor_v2.py | 77 +++++ .../integration/test_lists/test-db/l0_a10.yml | 4 + .../test_lists/test-db/l0_h100.yml | 4 + .../disaggregated/region/test_block.py | 146 +++++++++ .../unittest/disaggregated/test_extractor.py | 101 ++++++ .../disaggregated/test_extractor_v2.py | 78 +++++ tests/unittest/disaggregated/test_peer.py | 297 ++++++++++++++++++ 16 files changed, 1806 insertions(+) create mode 100644 tensorrt_llm/_torch/disaggregation/base/region.py create mode 100644 tensorrt_llm/_torch/disaggregation/native/peer.py create mode 100644 tensorrt_llm/_torch/disaggregation/native/rank_info.py create mode 100644 tensorrt_llm/_torch/disaggregation/native/region/__init__.py create mode 100644 tensorrt_llm/_torch/disaggregation/native/region/aux.py create mode 100644 tensorrt_llm/_torch/disaggregation/native/region/block.py create mode 100644 tensorrt_llm/_torch/disaggregation/native/region/page.py create mode 100644 tensorrt_llm/_torch/disaggregation/resource/__init__.py create mode 100644 tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py create mode 100644 tensorrt_llm/_torch/disaggregation/resource/kv_extractor_v2.py create mode 100644 tests/unittest/disaggregated/region/test_block.py create mode 100644 tests/unittest/disaggregated/test_extractor.py create mode 100644 tests/unittest/disaggregated/test_extractor_v2.py create mode 100644 tests/unittest/disaggregated/test_peer.py diff --git a/tensorrt_llm/_torch/disaggregation/base/region.py b/tensorrt_llm/_torch/disaggregation/base/region.py new file mode 100644 index 0000000000..0e638c4a1a --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/base/region.py @@ -0,0 +1,124 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import IntFlag, auto +from typing import List, NamedTuple, Optional + + +@dataclass(frozen=True) +class IndexRange: + """ + Represents a closed interval [start, end], with both bounds >= 0. + Commonly used for indexing layers, heads, or tokens. 
+ """ + + start: int + end: int + + def __post_init__(self): + if not (isinstance(self.start, int) and isinstance(self.end, int)): + raise TypeError("start and end must be integers") + if self.start < 0 or self.end < 0: + raise ValueError("start and end must be >= 0") + if self.end < self.start: + raise ValueError("end must be >= start") + + +class MemRegion(NamedTuple): + """Describes a block of memory by starting pointer and size in bytes.""" + + ptr: int + bytes: int + + +class MemRegionGroup(NamedTuple): + """Describes a block of memory by starting pointer and size in bytes.""" + + ptrs: List[int] + bytes_per_region: int + + +class DataRole(IntFlag): + """Logical role(s) a memory region plays. Supports combinations.""" + + KEY = auto() + VALUE = auto() + + +class DataLayout(IntFlag): + """Possible orders for storing data in memory.""" + + HND = auto() # (head, seq_len, dim) + NHD = auto() # (seq_len, head, dim) + + +@dataclass(frozen=True) +class RegionSpec: + """ + Specifies a (potentially partial) region of the cache. + Extend this base class for additional axis or specialization. + """ + + layers: Optional[IndexRange] = None + + +@dataclass(frozen=True) +class KVRegionSpec(RegionSpec): + """ + Specifies a region within the Key/Value cache, with optional axes. + """ + + role: DataRole = DataRole.KEY | DataRole.VALUE + heads: Optional[IndexRange] = None + tokens: Optional[IndexRange] = None + + +class SpecRegion(NamedTuple): + """ + Associates a memory region with its semantic specifier. + """ + + memory: MemRegion | MemRegionGroup + spec: RegionSpec = None + + +class RegionExtractorBase(ABC): + """ + Interface for extracting region descriptors from some backing store. + """ + + @abstractmethod + def extract(self, region_ids: Optional[List[int]] = None) -> List[SpecRegion]: + """ + Args: + region_ids: (Optional) List of integer region identifiers to extract. + Returns: + List of Regions for corresponding regions. + """ + ... + + +class SpecRegionPair(NamedTuple): + """ + Maps a source descriptor to a destination descriptor + (e.g., when copying or reindexing regions). + """ + + src: SpecRegion + dst: SpecRegion + + +class RegionMapperBase(ABC): + """ + Maps a batch of region descriptors to corresponding destination(s). + """ + + @abstractmethod + def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair: + """ + Args: + src_regions: List of source Regions. + dst_regions: List of destination Regions. + Returns: + List of RegionPairs mapping source to destination. + """ + ... 
diff --git a/tensorrt_llm/_torch/disaggregation/native/peer.py b/tensorrt_llm/_torch/disaggregation/native/peer.py
new file mode 100644
index 0000000000..fcbfc7da4e
--- /dev/null
+++ b/tensorrt_llm/_torch/disaggregation/native/peer.py
@@ -0,0 +1,286 @@
+from dataclasses import dataclass, field
+from typing import Dict, List
+
+from tensorrt_llm import logger
+from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo
+from tensorrt_llm._torch.disaggregation.native.region.block import (
+    HeadMatchMapper,
+    HeadMismatchMapper,
+    IdentityMapper,
+    RegionMapperBase,
+)
+from tensorrt_llm._torch.disaggregation.resource.kv_extractor import (
+    KVPoolAttrs,
+    KVRegionExtractorV1,
+)
+
+
+@dataclass
+class PeerOverlap:
+    overlap_pp_size: int = 0
+    overlap_tp_size: int = 0
+    overlap_cp_size: int = 0
+    duplicate_head_factor: int = 1
+    peer_duplicate_head_factor: int = 1
+    target_peer_pp_layer_num: List[int] = field(default_factory=list)
+    ranks: List[int] = field(default_factory=list)
+
+
+class PeerRegistrar:
+    def __init__(self, self_rank_info: RankInfo, self_extractor: KVRegionExtractorV1):
+        self._ri = self_rank_info
+        self._peer_ri_cache: Dict[str, RankInfo] = {}
+        self._kv_map_cache: Dict[str, RegionMapperBase] = {}
+        self._self_ext_cache = self_extractor
+        self._peer_ext_cache: Dict[str, KVRegionExtractorV1] = {}
+        self._overlap_cache: Dict[str, PeerOverlap] = {}
+
+    def _block_size(self, layer_num: int, ri: RankInfo) -> int:
+        return (
+            layer_num
+            * ri.kv_factor
+            * ri.kv_heads_per_rank
+            * ri.tokens_per_block
+            * ri.dims_per_head
+            * ri.element_bytes
+        )
+
+    def register(self, peer_name: str, peer_rank: int, peer_ri: RankInfo):
+        # TODO: extend validation beyond the compatibility check below
+        assert self._self_ext_cache is not None
+        if not self._check_peer_compatible(peer_ri):
+            raise ValueError(
+                f"PeerRegistrar.register: peer {peer_name} (rank={peer_rank}) is incompatible with local rank."
+            )
+        key = self._unique_key(peer_name, peer_rank)
+        self._peer_ri_cache[key] = peer_ri
+        layer_num = peer_ri.layer_num_per_pp[peer_ri.pp_rank]
+        block_size = self._block_size(layer_num, peer_ri)
+        extractor = KVRegionExtractorV1(
+            KVPoolAttrs(pool_ptrs=peer_ri.kv_ptrs, block_bytes=[block_size])
+        )
+        self._peer_ext_cache[key] = extractor
+
+    def peer_extractor(self, peer_name: str, peer_rank: int) -> KVRegionExtractorV1:
+        return self._peer_ext_cache[self._unique_key(peer_name, peer_rank)]
+
+    @property
+    def self_extractor(self) -> KVRegionExtractorV1:
+        assert self._self_ext_cache is not None
+        return self._self_ext_cache
+
+    def unregister(self, peer_name: str, peer_rank: int):
+        key = self._unique_key(peer_name, peer_rank)
+        if key in self._peer_ri_cache:
+            del self._peer_ri_cache[key]
+        if key in self._peer_ext_cache:
+            del self._peer_ext_cache[key]
+        if key in self._kv_map_cache:
+            del self._kv_map_cache[key]
+
+    def get_peer_rank_info(self, peer_name: str, peer_rank: int) -> RankInfo:
+        return self._peer_ri_cache[self._unique_key(peer_name, peer_rank)]
+
+    @property
+    def self_rank_info(self) -> RankInfo:
+        return self._ri
+
+    def _unique_key(self, name: str, rank: int) -> str:
+        # Delimit name and rank so ("a", 11) and ("a1", 1) cannot collide.
+        return f"{name}:{rank}"
+
+    def _check_peer_compatible(self, peer_ri: RankInfo) -> bool:
+        if self._ri.is_mla != peer_ri.is_mla:
+            logger.warning(
+                "PeerRegistrar: compatibility check failed: 'is_mla' differs "
+                f"(local={self._ri.is_mla}, peer={peer_ri.is_mla})."
+ ) + return False + if self._ri.cp_size != 1 or peer_ri.cp_size != 1: + logger.warning( + "PeerRegistrar: unsupported configuration: context parallelism (cp_size) " + f"must be 1 for both local and peer ranks (local={self._ri.cp_size}, peer={peer_ri.cp_size})." + ) + return False + if self._ri.element_bytes != peer_ri.element_bytes: + logger.warning( + "PeerRegistrar: element size mismatch " + f"(local={self._ri.element_bytes} bytes, peer={peer_ri.element_bytes} bytes)." + ) + return False + if self._ri.tokens_per_block != peer_ri.tokens_per_block: + logger.warning( + "PeerRegistrar: tokens_per_block mismatch " + f"(local={self._ri.tokens_per_block}, peer={peer_ri.tokens_per_block})." + ) + return False + if self._ri.dims_per_head != peer_ri.dims_per_head: + logger.warning( + "PeerRegistrar: dims_per_head mismatch " + f"(local={self._ri.dims_per_head}, peer={peer_ri.dims_per_head})." + ) + return False + + self_layers = sum(self._ri.layer_num_per_pp) + peer_layers = sum(peer_ri.layer_num_per_pp) + if self_layers != peer_layers: + logger.warning( + "PeerRegistrar: total layer count mismatch " + f"(local={self_layers}, peer={peer_layers})." + ) + return False + + if self._ri.is_mla: + if peer_ri.kv_heads_per_rank != 1 or self._ri.kv_heads_per_rank != 1: + logger.warning( + "PeerRegistrar: MLA mode requires exactly 1 KV head per rank for both local and peer." + f" (local={self._ri.kv_heads_per_rank}, peer={peer_ri.kv_heads_per_rank})" + ) + return False + return True + + def _tp_per_dp(self, info: RankInfo) -> int: + return ( + info.tp_size // info.dp_size + if getattr(info, "enable_attention_dp", False) + else info.tp_size + ) + + def get_kv_map(self, peer_ri: RankInfo): + key = self._unique_key(peer_ri.instance_name, peer_ri.instance_rank) + if key in self._kv_map_cache: + return self._kv_map_cache[key] + + self_tp_per_dp = self._tp_per_dp(self._ri) + peer_tp_per_dp = self._tp_per_dp(peer_ri) + + is_dup_head = ( + self._ri.kv_heads_per_rank * self_tp_per_dp + != peer_ri.kv_heads_per_rank * peer_tp_per_dp + ) + head_match = is_dup_head or self._ri.is_mla or self_tp_per_dp == peer_tp_per_dp + logger.debug( + "KVMapperFactory.get_kv_map: " + f"head_match={head_match}, is_dup_head={is_dup_head}, self_is_mla={self._ri.is_mla}, " + f"self_tp_per_dp={self_tp_per_dp}, peer_tp_per_dp={peer_tp_per_dp}" + ) + # fast identity when write_all and same pp_size + if head_match and self._ri.pp_size == peer_ri.pp_size: + mapper = IdentityMapper() + self._kv_map_cache[key] = mapper + return mapper + + # compute overlapping layers + self_start_layer = sum(self._ri.layer_num_per_pp[: self._ri.pp_rank]) + self_end_layer = self_start_layer + self._ri.layer_num_per_pp[self._ri.pp_rank] + peer_start_layer = sum(peer_ri.layer_num_per_pp[: peer_ri.pp_rank]) + peer_end_layer = peer_start_layer + peer_ri.layer_num_per_pp[peer_ri.pp_rank] + start = max(self_start_layer, peer_start_layer) + end = min(self_end_layer, peer_end_layer) + transfer_layers = end - start + self_layer_offset = start - self_start_layer + peer_layer_offset = start - peer_start_layer + + if head_match: + mapper = HeadMatchMapper( + transfer_layers=transfer_layers, + src_layer_off=self_layer_offset, # local layer offset + dst_layer_off=peer_layer_offset, # peer layer offset + self_ri=self._ri, + peer_ri=peer_ri, + ) + self._kv_map_cache[key] = mapper + return mapper + + # head mismatch case + mapper = HeadMismatchMapper( + transfer_layers=transfer_layers, + src_layer_off=self_layer_offset, + peer_layer_off=peer_layer_offset, + self_ri=self._ri, + 
peer_ri=peer_ri, + ) + self._kv_map_cache[key] = mapper + return mapper + + @staticmethod + def _find_overlap(self_val, peer_val, self_rank, peer_rank=None): + if self_val <= peer_val: + overlap = peer_val // self_val + start = self_rank * overlap + (peer_rank * peer_val if peer_rank is not None else 0) + end = start + overlap + else: + ratio = self_val // peer_val + start = (self_rank // ratio) + (peer_rank * peer_val if peer_rank is not None else 0) + overlap = 1 + end = start + overlap + + return overlap, start, end + + def get_peer_overlap(self, peer_rank_info: RankInfo, peer_dp_rank: int) -> PeerOverlap: + peer_ri = peer_rank_info + key = self._unique_key(peer_ri.instance_name, peer_dp_rank) + if key in self._overlap_cache: + return self._overlap_cache[key] + + # compute pp overlap and target layers + self_start_layer = sum(self._ri.layer_num_per_pp[: self._ri.pp_rank]) + self_end_layer = self_start_layer + self._ri.layer_num_per_pp[self._ri.pp_rank] + + pre = 0 + tgt_pp_ranks: List[int] = [] + tgt_pp_layer_num: List[int] = [] + for p in range(peer_ri.pp_size): + peer_start_layer = pre + peer_end_layer = peer_start_layer + peer_ri.layer_num_per_pp[p] + if self_start_layer < peer_end_layer and self_end_layer > peer_start_layer: + tgt_pp_ranks.append(p) + tgt_pp_layer_num.append( + min(peer_end_layer, self_end_layer) - max(peer_start_layer, self_start_layer) + ) + pre += peer_ri.layer_num_per_pp[p] + + if tgt_pp_ranks == []: + # no overlap found + targets = PeerOverlap() + self._overlap_cache[key] = targets + return targets + + peer_start_pp = tgt_pp_ranks[0] + overlap_pp_size = len(tgt_pp_ranks) + peer_end_pp = peer_start_pp + overlap_pp_size + + # tp per dp-group + self_tp_per_dp = self._tp_per_dp(self._ri) + peer_tp_per_dp = self._tp_per_dp(peer_ri) + self_tp_rank_in_dp = self._ri.tp_rank % self_tp_per_dp + + overlap_tp_size, peer_start_tp, peer_end_tp = self._find_overlap( + self_tp_per_dp, peer_tp_per_dp, self_tp_rank_in_dp, peer_dp_rank + ) + overlap_cp_size, peer_start_cp, peer_end_cp = self._find_overlap( + self._ri.cp_size, peer_ri.cp_size, self._ri.cp_rank + ) + + ranks: List[int] = [] + for pp in range(peer_start_pp, peer_end_pp): + for cp in range(peer_start_cp, peer_end_cp): + for tp in range(peer_start_tp, peer_end_tp): + ranks.append(pp * peer_ri.tp_size * peer_ri.cp_size + cp * peer_ri.tp_size + tp) + + factor_self = self._ri.kv_heads_per_rank * self_tp_per_dp + factor_peer = peer_ri.kv_heads_per_rank * peer_tp_per_dp + dup_head = max(1, factor_self // factor_peer) + peer_dup_head = max(1, factor_peer // factor_self) + + targets = PeerOverlap( + overlap_pp_size=overlap_pp_size, + overlap_tp_size=overlap_tp_size, + overlap_cp_size=overlap_cp_size, + duplicate_head_factor=dup_head, + peer_duplicate_head_factor=peer_dup_head, + target_peer_pp_layer_num=tgt_pp_layer_num, + ranks=ranks, + ) + self._overlap_cache[key] = targets + return targets diff --git a/tensorrt_llm/_torch/disaggregation/native/rank_info.py b/tensorrt_llm/_torch/disaggregation/native/rank_info.py new file mode 100644 index 0000000000..007e819b8f --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/native/rank_info.py @@ -0,0 +1,75 @@ +from dataclasses import asdict, dataclass +from typing import List, Optional + +import msgpack + +from tensorrt_llm._torch.disaggregation.native.region.aux import AuxBufferMeta + + +@dataclass +class InstanceInfo: + instance_name: str + tp_size: int + pp_size: int + dp_size: int + cp_size: int + kv_heads_per_rank: int + tokens_per_block: int + dims_per_head: int + 
element_bytes: int + enable_attention_dp: bool + is_mla: bool + layer_num_per_pp: List[int] + sender_endpoints: List[str] + + def to_bytes(self) -> bytes: + return msgpack.packb(asdict(self)) + + @classmethod + def from_bytes(cls, data: bytes) -> "InstanceInfo": + return cls(**msgpack.unpackb(data)) + + +@dataclass +class RankInfo: + instance_name: str + instance_rank: int + tp_size: int + tp_rank: int + pp_size: int + pp_rank: int + dp_size: int + dp_rank: int + cp_size: int + cp_rank: int + device_id: int + kv_heads_per_rank: int + # [numLayers, kv_factor, heads, tokens, dims_per_head] + tokens_per_block: int + dims_per_head: int + element_bytes: int + enable_attention_dp: bool + is_mla: bool + layer_num_per_pp: List[int] + kv_ptrs: List[int] + aux_ptrs: List[int] + server_endpoint: str + self_endpoint: str + transfer_engine_info: bytes + aux_meta: Optional[AuxBufferMeta] + + @property + def kv_factor(self) -> int: + return 2 if not self.is_mla else 1 + + def to_bytes(self) -> bytes: + data = asdict(self) + data["aux_meta"] = self.aux_meta.to_dict() if self.aux_meta is not None else None + return msgpack.packb(data) + + @classmethod + def from_bytes(cls, data: bytes) -> "RankInfo": + unpacked = msgpack.unpackb(data) + if unpacked.get("aux_meta") is not None: + unpacked["aux_meta"] = AuxBufferMeta.from_dict(unpacked["aux_meta"]) + return cls(**unpacked) diff --git a/tensorrt_llm/_torch/disaggregation/native/region/__init__.py b/tensorrt_llm/_torch/disaggregation/native/region/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorrt_llm/_torch/disaggregation/native/region/aux.py b/tensorrt_llm/_torch/disaggregation/native/region/aux.py new file mode 100644 index 0000000000..8fb4a1d209 --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/native/region/aux.py @@ -0,0 +1,183 @@ +from abc import ABC, abstractmethod +from collections import deque +from dataclasses import dataclass, field +from typing import Any + +import torch + +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest + + +@dataclass +class AuxBufferMeta: + ptrs: list[int] + size: list[int] + item_sizes: list[int] = field(default_factory=list) + device: str = "cpu" + + def to_dict(self) -> dict[str, Any]: + return { + "ptrs": self.ptrs, + "size": self.size, + "item_sizes": self.item_sizes, + "device": self.device, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "AuxBufferMeta": + return cls( + ptrs=data["ptrs"], + size=data["size"], + item_sizes=data.get("item_sizes", []), + device=data.get("device", "cpu"), + ) + + +class AuxBufferBase(ABC): + """ + Abstract base class defining the interface for auxiliary buffer management. + """ + + @abstractmethod + def alloc_slot(self) -> int: + """ + Allocate a free slot and return its index. + """ + ... + + @abstractmethod + def free_slot(self, slot: int) -> None: + """ + Release the specified slot. + """ + ... + + @property + @abstractmethod + def meta(self) -> AuxBufferMeta: + """ + Retrieve meta-information about the underlying buffer(s). + Returns buffer info (e.g., pointers, sizes, device). + """ + ... + + @abstractmethod + def fill_slot(self, slot: int, request: LlmRequest) -> None: + """ + Fill/overwrite the contents of the given slot with data from the request. + """ + ... + + @abstractmethod + def get_slot_tokens(self, slot: int) -> tuple[list[int], list[int]]: + """ + Get the token data (e.g., first/draft tokens) from the specified slot. + """ + ... 
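Before the concrete implementation that follows, a short usage sketch of the slot lifecycle this interface defines (assuming the `AuxBuffer` class below; the slot counts and widths are invented):

```python
from tensorrt_llm._torch.disaggregation.native.region.aux import AuxBuffer

# 8 slots, beam width 1, up to 4 draft tokens per slot, backed by CPU tensors.
buf = AuxBuffer(max_slot_num=8, beam_width=1, max_draft_len=4, device="cpu")

slot = buf.alloc_slot()  # reserve a slot index from the free list
# buf.fill_slot(slot, request) would copy an LlmRequest's first/draft tokens in.

# meta exposes raw pointers and per-slot item sizes so a transfer engine can
# address individual slots directly.
meta = buf.meta
assert meta.item_sizes[0] == 1 * 4  # first-tokens item: beam_width x int32
assert meta.item_sizes[1] == 4 * 4  # draft-tokens item: max_draft_len x int32

buf.free_slot(slot)  # return the slot to the free list
```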
+
+
+class AuxBuffer(AuxBufferBase):
+    def __init__(self, max_slot_num: int, beam_width: int, max_draft_len: int, device: str = "cpu"):
+        # public constructor args remain the same, internals are private
+        self._max_slot_num = int(max_slot_num)
+        self._beam_width = int(beam_width)
+        self._max_draft_len = int(max_draft_len)
+        self._device = device
+
+        self._free_slots = deque(list(range(self._max_slot_num)))
+        self._occupied_slots: set[int] = set()
+        self._slot_token_counts: dict[
+            int, tuple[int, int]
+        ] = {}  # slot -> (first_tokens_len, draft_tokens_len)
+
+        data_type = torch.int32
+        self._first_tokens_buffer = torch.empty(
+            self._max_slot_num, self._beam_width, dtype=data_type, device=self._device
+        )
+
+        self._draft_tokens_buffer = torch.empty(
+            self._max_slot_num, self._max_draft_len, dtype=data_type, device=self._device
+        )
+
+        self._meta = AuxBufferMeta(
+            ptrs=[self._first_tokens_buffer.data_ptr(), self._draft_tokens_buffer.data_ptr()],
+            size=[
+                self._first_tokens_buffer.numel() * self._first_tokens_buffer.element_size(),
+                self._draft_tokens_buffer.numel() * self._draft_tokens_buffer.element_size(),
+            ],
+            item_sizes=[
+                self._first_tokens_buffer[0].numel() * self._first_tokens_buffer.element_size(),
+                self._draft_tokens_buffer[0].numel() * self._draft_tokens_buffer.element_size(),
+            ],
+            device=self._device,
+        )
+
+    def alloc_slot(self) -> int:
+        if not self._free_slots:
+            raise ValueError(
+                f"No free auxiliary buffer slots available (max slots = {self._max_slot_num}). "
+                "All slots are currently occupied."
+            )
+        slot_id = self._free_slots.popleft()
+        if slot_id in self._occupied_slots:
+            # This should not happen; defensive invariant check.
+            raise RuntimeError(
+                f"Invariant error: selected slot {slot_id} is already marked as occupied. "
+                "This indicates a bug in slot management."
+            )
+        self._occupied_slots.add(slot_id)
+        return slot_id
+
+    def free_slot(self, slot: int) -> None:
+        if slot < 0 or slot >= self._max_slot_num:
+            raise ValueError(
+                f"Invalid slot id {slot}. Valid slot indices are in the range 0..{self._max_slot_num - 1}."
+            )
+        if slot not in self._occupied_slots:
+            raise ValueError(
+                f"Attempted to free slot {slot}, but that slot is not currently allocated. "
+                "Ensure `alloc_slot` was called and the slot wasn't freed already."
+            )
+        self._occupied_slots.remove(slot)
+        self._free_slots.append(slot)
+
+    @property
+    def meta(self) -> AuxBufferMeta:
+        return self._meta
+
+    def fill_slot(self, slot: int, request: LlmRequest) -> None:
+        if slot not in self._occupied_slots:
+            raise ValueError(
+                f"Cannot fill slot {slot}: slot is not currently allocated. "
+                "Call `alloc_slot` first."
+            )
+        first_gen_tokens = request.get_last_tokens()
+        draft_tokens = request.py_draft_tokens
+
+        if len(first_gen_tokens) > self._beam_width:
+            raise ValueError(
+                f"`first_gen_tokens` length ({len(first_gen_tokens)}) exceeds `beam_width` ({self._beam_width}). "
+                "Consider truncating the token list or increasing the beam_width when creating the `AuxBuffer`."
+            )
+        if len(draft_tokens) > self._max_draft_len:
+            raise ValueError(
+                f"`draft_tokens` length ({len(draft_tokens)}) exceeds `max_draft_len` ({self._max_draft_len}). "
+                "Consider truncating draft tokens or increasing `max_draft_len` when creating the `AuxBuffer`."
+ ) + + self._first_tokens_buffer[slot][: len(first_gen_tokens)].copy_( + torch.tensor(first_gen_tokens, dtype=torch.int32, device=self._device) + ) + self._draft_tokens_buffer[slot][: len(draft_tokens)].copy_( + torch.tensor(draft_tokens, dtype=torch.int32, device=self._device) + ) + self._slot_token_counts[slot] = (len(first_gen_tokens), len(draft_tokens)) + + def get_slot_tokens(self, slot: int) -> tuple[list[int], list[int]]: + first_len, draft_len = self._slot_token_counts.get( + slot, (self._beam_width, self._max_draft_len) + ) + first_gen_tokens = self._first_tokens_buffer[slot][:first_len].tolist() + draft_tokens = self._draft_tokens_buffer[slot][:draft_len].tolist() + + return first_gen_tokens, draft_tokens diff --git a/tensorrt_llm/_torch/disaggregation/native/region/block.py b/tensorrt_llm/_torch/disaggregation/native/region/block.py new file mode 100644 index 0000000000..e08f01a1ad --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/native/region/block.py @@ -0,0 +1,216 @@ +import numpy as np + +from tensorrt_llm._torch.disaggregation.base.region import ( + MemRegionGroup, + RegionMapperBase, + SpecRegion, + SpecRegionPair, +) +from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo + + +class IdentityMapper(RegionMapperBase): + """ + ---- mapper_identity ---- + + Pass-through mapping. Do not change pointers or sizes. + + src_ptrs: [ S0 ] [ S1 ] [ S2 ] ... + | | | + v v v + dst_ptrs: [ D0 ] [ D1 ] [ D2 ] ... + """ + + def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair: + src_group = src_regions.memory + dst_group = dst_regions.memory + assert len(src_group.ptrs) == len(dst_group.ptrs), ( + f"Number of regions of src({len(src_group.ptrs)}) and dst({len(dst_group.ptrs)}) must match" + ) + new_src = MemRegionGroup( + ptrs=list(src_group.ptrs), bytes_per_region=src_group.bytes_per_region + ) + new_dst = MemRegionGroup( + ptrs=list(dst_group.ptrs), bytes_per_region=dst_group.bytes_per_region + ) + return SpecRegionPair( + src=SpecRegion(memory=new_src, spec=src_regions.spec), + dst=SpecRegion(memory=new_dst, spec=dst_regions.spec), + ) + + +class HeadMatchMapper(RegionMapperBase): + """ + ---- mapper_head_match ---- + + Move/copy entire contiguous block(s) (multi-layer fragment) as a single chunk. + Align by whole fragment size (frag_size) and apply a constant source/destination block offset. + + src_ptrs: [ S0 ] [ S1 ] ... 
+ | | + + src_off + src_off + | | + [ S0 + src_off ] [ S1 + src_off ] -> (each points to a frag of size frag_size) + copy whole frag + | | + v v + [ D0 + dst_off ] [ D1 + dst_off ] -> (destination frags) + """ + + def __init__( + self, + transfer_layers: int, + src_layer_off: int, + dst_layer_off: int, + self_ri: RankInfo, + peer_ri: RankInfo, + ): + self._kv_factor = self_ri.kv_factor + self._frag_size = self._block_size(transfer_layers, self_ri) + self._src_block_off = self._block_size(src_layer_off, self_ri) + self._dst_block_off = self._block_size(dst_layer_off, peer_ri) + + def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair: + src_group = src_regions.memory + dst_group = dst_regions.memory + assert len(src_group.ptrs) == len(dst_group.ptrs), ( + f"Number of regions of src({len(src_group.ptrs)}) and dst({len(dst_group.ptrs)}) must match" + ) + new_src_ptrs = [src_ptr + self._src_block_off for src_ptr in src_group.ptrs] + new_dst_ptrs = [dst_ptr + self._dst_block_off for dst_ptr in dst_group.ptrs] + new_src = MemRegionGroup(ptrs=new_src_ptrs, bytes_per_region=self._frag_size) + new_dst = MemRegionGroup(ptrs=new_dst_ptrs, bytes_per_region=self._frag_size) + return SpecRegionPair( + src=SpecRegion(memory=new_src, spec=src_regions.spec), + dst=SpecRegion(memory=new_dst, spec=dst_regions.spec), + ) + + def _block_size(self, layer_num: int, ri: RankInfo) -> int: + return ( + layer_num + * ri.kv_factor + * ri.kv_heads_per_rank + * ri.tokens_per_block + * ri.dims_per_head + * ri.element_bytes + ) + + +class HeadMismatchMapper(RegionMapperBase): + """ + ---- mapper_head_mismatch ---- + + Fine-grained mapping when head counts or TP/DP partitioning differ. + Split layers into per-head (or contiguous-heads) fragments and map them individually. + Handles kv_factor (e.g., key+value duplication) and TP/DP head offsets. + + Source (layers x heads): + L0: [S00 S01] [S02 S03] ... + L1: [S10 S11] [S12 S13] ... + + Destination (layers x heads, different layout possible): + L0': [D00] [D01] [D02] ... + L1': [D10] [D11] ... 
+ + Mapping (each arrow = copy cont_heads_frag): + [S00 S01] -> [D00] + [S02 S03] -> [D01] + [S10 S11] -> [D02] + """ + + def __init__( + self, + transfer_layers: int, + src_layer_off: int, + peer_layer_off: int, + self_ri: RankInfo, + peer_ri: RankInfo, + ): + self._ri = self_ri + self._peer_ri = peer_ri + self._src_layer_off = src_layer_off + + kv_factor = self_ri.kv_factor + self_tp_per_dp = self_ri.tp_size // self_ri.dp_size + peer_tp_per_dp = peer_ri.tp_size // peer_ri.dp_size + self_tp_rank = self_ri.tp_rank + peer_tp_rank = peer_ri.tp_rank + + bytes_per_head = self._ri.tokens_per_block * self._ri.dims_per_head * self._ri.element_bytes + self._bytes_cont_heads = ( + min(self._ri.kv_heads_per_rank, peer_ri.kv_heads_per_rank) * bytes_per_head + ) + + self._src_head_off, self._dst_head_off = self._compute_head_offsets( + self_tp_per_dp, + peer_tp_per_dp, + self_tp_rank, + peer_tp_rank, + self._bytes_cont_heads, + ) + self._layer_indices = np.arange(transfer_layers, dtype=np.int64) + self._kv_indices = np.arange(kv_factor, dtype=np.int64) + self._peer_layer_off = peer_layer_off + + def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair: + src_group = src_regions.memory + dst_group = dst_regions.memory + assert len(src_group.ptrs) == len(dst_group.ptrs), ( + f"Number of regions of src({len(src_group.ptrs)}) and dst({len(dst_group.ptrs)}) must match" + ) + src_bases = np.array(src_group.ptrs, dtype=np.int64) + dst_bases = np.array(dst_group.ptrs, dtype=np.int64) + src_frags = self._get_frags( + bases=src_bases, + layer_indices=self._src_layer_off + self._layer_indices, + layer_kv_num=self._get_layer_kv_num(self._ri), + kv_indices=self._kv_indices, + head_off=self._src_head_off, + kv_factor=self._kv_indices.size, + ) + dst_frags = self._get_frags( + bases=dst_bases, + layer_indices=self._peer_layer_off + self._layer_indices, + layer_kv_num=self._get_layer_kv_num(self._peer_ri), + kv_indices=self._kv_indices, + head_off=self._dst_head_off, + kv_factor=self._kv_indices.size, + ) + all_src_ptrs = [int(x) for x in src_frags.flatten()] + all_dst_ptrs = [int(x) for x in dst_frags.flatten()] + new_src = MemRegionGroup(ptrs=all_src_ptrs, bytes_per_region=self._bytes_cont_heads) + new_dst = MemRegionGroup(ptrs=all_dst_ptrs, bytes_per_region=self._bytes_cont_heads) + return SpecRegionPair( + src=SpecRegion(memory=new_src, spec=src_regions.spec), + dst=SpecRegion(memory=new_dst, spec=dst_regions.spec), + ) + + @staticmethod + def _compute_head_offsets( + self_tp_per_dp: int, + peer_tp_per_dp: int, + self_tp_rank: int, + peer_tp_rank: int, + bytes_cont_heads: int, + ) -> tuple[int, int]: + if self_tp_per_dp == peer_tp_per_dp: + return 0, 0 + ratio = max(self_tp_per_dp, peer_tp_per_dp) // min(self_tp_per_dp, peer_tp_per_dp) + if self_tp_per_dp < peer_tp_per_dp: + return (peer_tp_rank % ratio) * bytes_cont_heads, 0 + else: + return 0, (self_tp_rank % ratio) * bytes_cont_heads + + @staticmethod + def _get_layer_kv_num(ri: RankInfo) -> int: + return ri.kv_heads_per_rank * ri.tokens_per_block * ri.dims_per_head * ri.element_bytes + + @staticmethod + def _get_frags(bases, layer_indices, layer_kv_num, kv_indices, head_off, kv_factor): + layer_num = layer_kv_num * kv_factor + return ( + bases[:, None, None] + + layer_num * layer_indices[None, :, None] + + layer_kv_num * kv_indices[None, None, :] + + head_off + ) diff --git a/tensorrt_llm/_torch/disaggregation/native/region/page.py b/tensorrt_llm/_torch/disaggregation/native/region/page.py new file mode 100644 index 
0000000000..217871e621
--- /dev/null
+++ b/tensorrt_llm/_torch/disaggregation/native/region/page.py
@@ -0,0 +1,126 @@
+from dataclasses import dataclass
+from typing import List, Set
+
+import numpy as np
+
+BUFFER_ENTRY_DTYPE = np.dtype(
+    [
+        ("layer_id", np.uint32),
+        ("role", np.uint32),
+        ("offset", np.uint32),
+        ("size", np.uint32),
+    ]
+)
+
+
+@dataclass
+class PoolDescriptor:
+    """
+    Pool descriptor containing memory layout and buffer information.
+
+    One pool contains multiple buffer entries, each representing a (layer_id, role) combination.
+    """
+
+    base_address: int  # (uint64)
+    slot_bytes: int
+    num_slots: int
+
+    # Buffer entries: flattened array of all (layer_id, role, offset, size) in this pool
+    buffer_entries: np.ndarray  # dtype=BUFFER_ENTRY_DTYPE
+
+    @property
+    def pool_bytes(self) -> int:
+        return self.slot_bytes * self.num_slots
+
+    @property
+    def unique_layers(self) -> Set[int]:
+        return set(int(entry["layer_id"]) for entry in self.buffer_entries)
+
+    @property
+    def unique_roles(self) -> Set[int]:
+        return set(int(entry["role"]) for entry in self.buffer_entries)
+
+    def get_slot_address(self, slot_id: int) -> int:
+        if slot_id >= self.num_slots:
+            raise ValueError(f"slot_id {slot_id} >= num_slots {self.num_slots}")
+        return self.base_address + slot_id * self.slot_bytes
+
+    def get_device_pointer(self, slot_id: int, layer_id: int, role_enum: int) -> int:
+        if slot_id >= self.num_slots:
+            raise ValueError(f"slot_id {slot_id} >= num_slots {self.num_slots}")
+
+        for entry in self.buffer_entries:
+            if entry["layer_id"] == layer_id and entry["role"] == role_enum:
+                slot_base = self.base_address + slot_id * self.slot_bytes
+                return slot_base + int(entry["offset"])
+
+        raise ValueError(f"Buffer not found: layer_id={layer_id}, role_enum={role_enum}")
+
+    def __repr__(self) -> str:
+        return (
+            f"PoolDescriptor(base=0x{self.base_address:x}, "
+            f"slot_bytes={self.slot_bytes}, num_slots={self.num_slots}, "
+            f"layers={len(self.unique_layers)}, roles={len(self.unique_roles)})"
+        )
+
+
+@dataclass
+class KVCachePageTable:
+    """
+    Multi-dimensional KV cache page table.
+
+    Structure:
+        KVCachePageTable
+        └── PoolGroups (List[List[PoolDescriptor]])
+            ├── PoolGroup 0: List of PoolDescriptors
+            │   ├── Pool 0: PoolDescriptor
+            │   └── Pool 1: PoolDescriptor
+            │
+            └── PoolGroup 1: List of PoolDescriptors
+                ├── Pool 0: PoolDescriptor
+                └── Pool 1: PoolDescriptor
+
+    Relationships:
+        - pools[pg_idx] = List[PoolDescriptor] (all pools in same PoolGroup)
+        - All pools in pools[pg_idx] share the same lifecycle
+    """
+
+    tokens_per_block: int
+    num_layers: int
+    pools: List[List[PoolDescriptor]]  # pools[pg_idx][pool_idx] → PoolDescriptor
+
+    @property
+    def num_pool_groups(self) -> int:
+        return len(self.pools)
+
+    @property
+    def total_pools(self) -> int:
+        return sum(len(pg_pools) for pg_pools in self.pools)
+
+    @property
+    def total_buffer_entries(self) -> int:
+        return sum(len(pool.buffer_entries) for pg_pools in self.pools for pool in pg_pools)
+
+    @property
+    def total_pool_bytes(self) -> int:
+        return sum(pool.pool_bytes for pg_pools in self.pools for pool in pg_pools)
+
+    @property
+    def total_slots(self) -> int:
+        return sum(pool.num_slots for pg_pools in self.pools for pool in pg_pools)
+
+    def get_pool(self, pg_idx: int, pool_idx: int) -> PoolDescriptor:
+        return self.pools[pg_idx][pool_idx]
+
+    def get_device_pointer(
+        self, pg_idx: int, pool_idx: int, slot_id: int, layer_id: int, role_enum: int
+    ) -> int:
+        pool = self.pools[pg_idx][pool_idx]
+        # role_enum is the integer role id stored in BUFFER_ENTRY_DTYPE entries.
+
return pool.get_device_pointer(slot_id, layer_id, role_enum) + + def __repr__(self) -> str: + return ( + f"KVCachePageTable(poolgroups={self.num_pool_groups}, " + f"pools={self.total_pools}, layers={self.num_layers})" + ) diff --git a/tensorrt_llm/_torch/disaggregation/resource/__init__.py b/tensorrt_llm/_torch/disaggregation/resource/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py b/tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py new file mode 100644 index 0000000000..a08bcc12dd --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass +from typing import List + +from tensorrt_llm._torch.disaggregation.base.region import ( + DataLayout, + MemRegionGroup, + RegionExtractorBase, + SpecRegion, +) +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm._utils import get_size_in_bytes + + +@dataclass +class KVPoolAttrs: + """Attributes for a single (primary) KV memory pool.""" + + pool_ptrs: List[int] + block_bytes: List[int] + + +class KVRegionExtractorV1(RegionExtractorBase): + """ + Descriptor and region extractor for KV cache pool managed by KVCacheManager. + Provides region descriptors for adapting block-wise view. + """ + + def __init__(self, kv_arg: KVCacheManager | KVPoolAttrs): + if isinstance(kv_arg, KVPoolAttrs): + self._kv_pool_attrs = kv_arg + elif isinstance(kv_arg, KVCacheManager): + self._kv_pool_attrs = self._attrs_from_manager(kv_arg) + else: + raise TypeError( + f"kv_cache_manager must be KVCacheManager or KVPoolAttrs, got {type(kv_arg)}" + ) + self._data_layout = DataLayout.HND + + @staticmethod + def _attrs_from_manager(manager: KVCacheManager) -> KVPoolAttrs: + try: + pools = manager.get_unique_primary_pool() + except Exception as ex: + raise ValueError(f"Failed to get pool(s): {ex}") + + pool_list = list(pools) if isinstance(pools, (list, tuple)) else [pools] + elem_bytes = get_size_in_bytes(1, manager.dtype) + ptrs, block_sizes = [], [] + + for p in pool_list: + if hasattr(p, "data_ptr") and callable(p.data_ptr): + try: + ptr = int(p.data_ptr()) + except Exception as ex: + raise ValueError(f"Fail to call data_ptr(): {ex}") + elif isinstance(p, int): + ptr = int(p) + else: + raise ValueError(f"Pool object lacks 'data_ptr' and is not int: {p!r}") + ptrs.append(ptr) + + try: + if hasattr(p, "__getitem__") and hasattr(p[0], "numel"): + n = int(p[0].numel()) + elif hasattr(p, "numel") and callable(p.numel): + n = int(p.numel()) + else: + raise RuntimeError("Cannot determine element count") + except Exception as ex: + raise ValueError(f"Failed to get block size from {p!r}: {ex}") + + block_sizes.append(n * elem_bytes) + + return KVPoolAttrs(pool_ptrs=ptrs, block_bytes=block_sizes) + + def extract(self, region_ids: List[int]) -> SpecRegion: + """ + Given a list of region_ids, returns a single SpecRegion, + whose memory is a MemRegionGroup containing all blocks described + by region_ids. 
+ """ + assert len(self._kv_pool_attrs.pool_ptrs) == 1 + pool_idx = 0 + attrs = self._kv_pool_attrs + ptrs = [ + attrs.pool_ptrs[pool_idx] + block_id * attrs.block_bytes[0] for block_id in region_ids + ] + memory = MemRegionGroup(ptrs=ptrs, bytes_per_region=attrs.block_bytes[0]) + return SpecRegion(memory=memory) diff --git a/tensorrt_llm/_torch/disaggregation/resource/kv_extractor_v2.py b/tensorrt_llm/_torch/disaggregation/resource/kv_extractor_v2.py new file mode 100644 index 0000000000..1ffcd1d4cc --- /dev/null +++ b/tensorrt_llm/_torch/disaggregation/resource/kv_extractor_v2.py @@ -0,0 +1,77 @@ +from collections import defaultdict + +import numpy as np + +from tensorrt_llm._torch.disaggregation.native.region.page import ( + BUFFER_ENTRY_DTYPE, + KVCachePageTable, + PoolDescriptor, +) +from tensorrt_llm.runtime.kv_cache_manager_v2 import CacheTier, KVCacheManager + + +def build_page_table(manager: KVCacheManager) -> KVCachePageTable: + storage = manager._storage + config = manager._init_config + + gpu_level = 0 + for level_idx, cache_tier_config in enumerate(config.cache_tiers): + if cache_tier_config.tier == CacheTier.GPU_MEM: + gpu_level = level_idx + break + + buffer_by_pool = defaultdict(list) + + lc_to_pg_cache = {} + + for buffer_id, attr in storage._buffer_attr.items(): + layer_id, role = buffer_id + + lc_id = attr.life_cycle_id + if lc_id not in lc_to_pg_cache: + lc_to_pg_cache[lc_id] = storage.get_pool_group_index(lc_id) + pg_idx = lc_to_pg_cache[lc_id] + + pool_idx = attr.pool_index + pool_key = (pg_idx, pool_idx) + + buffer_by_pool[pool_key].append((layer_id, role, attr.offset, attr.size)) + + pools = [] + num_pool_groups = storage.num_pool_groups + pool_group_storage = storage._levels[gpu_level].storage._pool_groups + + for pg_idx in range(num_pool_groups): + pool_group = pool_group_storage[pg_idx] + num_pools = pool_group.num_pools + pg_pools = [] + + for pool_idx in range(num_pools): + pool = pool_group._pools[pool_idx] + + base_address = int(pool.slot_address(0)) + slot_bytes = int(pool.slot_size) + num_slots = int(pool.num_slots) + + pool_key = (pg_idx, pool_idx) + buffers_info = buffer_by_pool.get(pool_key, []) + + if buffers_info: + buffer_entries = np.array(buffers_info, dtype=BUFFER_ENTRY_DTYPE) + else: + buffer_entries = np.array([], dtype=BUFFER_ENTRY_DTYPE) + + pool_desc = PoolDescriptor( + base_address=base_address, + slot_bytes=slot_bytes, + num_slots=num_slots, + buffer_entries=buffer_entries, + ) + pg_pools.append(pool_desc) + pools.append(pg_pools) + + return KVCachePageTable( + tokens_per_block=config.tokens_per_block, + num_layers=len(config.layers), + pools=pools, + ) diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 34ad12632e..ef20a8e586 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -38,6 +38,10 @@ l0_a10: - unittest/disaggregated/test_messenger.py - unittest/disaggregated/test_disagg_cluster_manager_worker.py - unittest/disaggregated/test_cluster_storage.py + - unittest/disaggregated/test_extractor.py + - unittest/disaggregated/test_extractor_v2.py + - unittest/disaggregated/test_peer.py + - unittest/disaggregated/region/test_block.py - disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0] - 
disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 2831bc95bc..4a5b4b4219 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -39,6 +39,10 @@ l0_h100: - unittest/disaggregated/test_remoteDictionary.py - unittest/disaggregated/test_agent_multi_backends.py - unittest/disaggregated/test_messenger.py + - unittest/disaggregated/test_extractor.py + - unittest/disaggregated/test_extractor_v2.py + - unittest/disaggregated/test_peer.py + - unittest/disaggregated/region/test_block.py - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse diff --git a/tests/unittest/disaggregated/region/test_block.py b/tests/unittest/disaggregated/region/test_block.py new file mode 100644 index 0000000000..7573612e5c --- /dev/null +++ b/tests/unittest/disaggregated/region/test_block.py @@ -0,0 +1,146 @@ +from tensorrt_llm._torch.disaggregation.base.region import ( + MemRegionGroup, + SpecRegion, + SpecRegionPair, +) +from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo +from tensorrt_llm._torch.disaggregation.native.region.block import ( + HeadMatchMapper, + HeadMismatchMapper, + IdentityMapper, +) + + +def make_rankinfo( + kv_heads_per_rank=2, + tokens_per_block=4, + dims_per_head=2, + element_bytes=1, + tp_size=2, + tp_rank=0, + dp_size=1, + dp_rank=0, + pp_size=1, + pp_rank=0, + cp_size=1, + cp_rank=0, + is_mla=False, +): + return RankInfo( + instance_name="rank", + instance_rank=0, + tp_size=tp_size, + tp_rank=tp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + pp_size=pp_size, + pp_rank=pp_rank, + cp_size=cp_size, + cp_rank=cp_rank, + device_id=0, + kv_heads_per_rank=kv_heads_per_rank, + tokens_per_block=tokens_per_block, + dims_per_head=dims_per_head, + element_bytes=element_bytes, + enable_attention_dp=False, + is_mla=is_mla, + layer_num_per_pp=[1], + kv_ptrs=[], + aux_ptrs=[], + server_endpoint="", + self_endpoint="", + transfer_engine_info=b"", + aux_meta=None, + ) + + +def test_mem_region_group(): + ptrs = [11, 22, 33] + bytes_per_region = 16 + region = MemRegionGroup(ptrs=ptrs, bytes_per_region=bytes_per_region) + assert list(region.ptrs) == ptrs + assert region.bytes_per_region == bytes_per_region + + +def test_spec_region_and_spec_region_pair(): + group_src = MemRegionGroup(ptrs=[101, 202], bytes_per_region=8) + group_dst = MemRegionGroup(ptrs=[303, 404], bytes_per_region=8) + spec_src = SpecRegion(memory=group_src, spec="spec_src") + spec_dst = SpecRegion(memory=group_dst, spec="spec_dst") + assert isinstance(spec_src, SpecRegion) + assert isinstance(spec_dst, SpecRegion) + pair = SpecRegionPair(src=spec_src, dst=spec_dst) + assert isinstance(pair, SpecRegionPair) + assert pair.src.memory.ptrs == [101, 202] + assert pair.dst.memory.ptrs == [303, 404] + assert pair.src.spec == "spec_src" + assert pair.dst.spec == "spec_dst" + + +def test_identity_mapper(): + src_group = MemRegionGroup(ptrs=[100, 200], bytes_per_region=32) + dst_group = MemRegionGroup(ptrs=[300, 400], bytes_per_region=32) + src_spec = SpecRegion(memory=src_group, spec="a") + dst_spec = SpecRegion(memory=dst_group, spec="b") + mapper = IdentityMapper() + result = mapper.map(src_spec, dst_spec) + assert 
isinstance(result, SpecRegionPair) + assert list(result.src.memory.ptrs) == [100, 200] + assert list(result.dst.memory.ptrs) == [300, 400] + assert result.src.memory.bytes_per_region == 32 + assert result.dst.memory.bytes_per_region == 32 + + +def test_head_match_mapper(): + self_ri = make_rankinfo(kv_heads_per_rank=2) + peer_ri = make_rankinfo(kv_heads_per_rank=2) + transfer_layers = 2 + src_layer_off = 1 + dst_layer_off = 1 + src_group = MemRegionGroup(ptrs=[10, 20], bytes_per_region=1) + dst_group = MemRegionGroup(ptrs=[30, 40], bytes_per_region=1) + src_spec = SpecRegion(memory=src_group, spec="srcspec") + dst_spec = SpecRegion(memory=dst_group, spec="dstspec") + mapper = HeadMatchMapper(transfer_layers, src_layer_off, dst_layer_off, self_ri, peer_ri) + result = mapper.map(src_spec, dst_spec) + expected_off = ( + transfer_layers + * mapper._kv_factor + * self_ri.kv_heads_per_rank + * self_ri.tokens_per_block + * self_ri.dims_per_head + * self_ri.element_bytes + ) + assert list(result.src.memory.ptrs) == [10 + mapper._src_block_off, 20 + mapper._src_block_off] + assert list(result.dst.memory.ptrs) == [30 + mapper._dst_block_off, 40 + mapper._dst_block_off] + assert result.src.memory.bytes_per_region == expected_off + assert result.dst.memory.bytes_per_region == expected_off + + +def test_head_mismatch_mapper(): + self_ri = make_rankinfo(kv_heads_per_rank=2, tp_size=2, tp_rank=1) + peer_ri = make_rankinfo(kv_heads_per_rank=4, tp_size=4, tp_rank=2) + transfer_layers = 1 + src_layer_off = 0 + peer_layer_off = 1 + src_group = MemRegionGroup(ptrs=[111], bytes_per_region=32) + dst_group = MemRegionGroup(ptrs=[222], bytes_per_region=32) + src_spec = SpecRegion(memory=src_group, spec="srcspec") + dst_spec = SpecRegion(memory=dst_group, spec="dstspec") + mapper = HeadMismatchMapper(transfer_layers, src_layer_off, peer_layer_off, self_ri, peer_ri) + result = mapper.map(src_spec, dst_spec) + expected_frag_count = self_ri.kv_factor * transfer_layers + assert isinstance(result, SpecRegionPair) + assert len(result.src.memory.ptrs) == expected_frag_count + assert len(result.dst.memory.ptrs) == expected_frag_count + assert all(isinstance(x, int) for x in result.src.memory.ptrs) + assert all(isinstance(x, int) for x in result.dst.memory.ptrs) + assert result.src.memory.bytes_per_region == mapper._bytes_cont_heads + assert result.dst.memory.bytes_per_region == mapper._bytes_cont_heads + + +def test_rankinfo_kv_factor(): + ri1 = make_rankinfo(is_mla=False) + ri2 = make_rankinfo(is_mla=True) + assert ri1.kv_factor == 2 + assert ri2.kv_factor == 1 diff --git a/tests/unittest/disaggregated/test_extractor.py b/tests/unittest/disaggregated/test_extractor.py new file mode 100644 index 0000000000..d3b70e3f1b --- /dev/null +++ b/tests/unittest/disaggregated/test_extractor.py @@ -0,0 +1,101 @@ +import pytest + +from tensorrt_llm._torch.disaggregation.base.region import MemRegionGroup, SpecRegion +from tensorrt_llm._torch.disaggregation.resource.kv_extractor import KVRegionExtractorV1 +from tensorrt_llm._torch.pyexecutor.resource_manager import ( + CacheTypeCpp, + DataType, + KvCacheConfig, + KVCacheManager, + Mapping, +) + + +class DummyRankInfo: + instance_name = "dummy" + instance_rank = 0 + tp_size = 1 + tp_rank = 0 + pp_size = 1 + pp_rank = 0 + dp_size = 1 + dp_rank = 0 + cp_size = 1 + cp_rank = 0 + device_id = 0 + kv_heads_per_rank = 8 + tokens_per_block = 32 + dims_per_head = 16 + element_bytes = 2 + enable_attention_dp = False + is_mla = False + layer_num_per_pp = [1] + + @property + def kv_factor(self) -> 
int: + return 2 if not self.is_mla else 1 + + +@pytest.mark.cuda +def test_extract(): + num_layers = 1 + num_kv_heads = 8 + head_dim = 16 + tokens_per_block = 32 + max_seq_len = 128 + max_batch_size = 2 + dtype = DataType.HALF + mapping = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1, gpus_per_node=1) + kv_cache_config = KvCacheConfig( + max_tokens=512, + free_gpu_memory_fraction=0.1, + max_attention_window=None, + enable_block_reuse=False, + event_buffer_max_size=0, + onboard_blocks=0, + host_cache_size=0, + enable_partial_reuse=False, + copy_on_partial_reuse=False, + sink_token_length=0, + max_util_for_resume=1, + ) + kv_cache_type = CacheTypeCpp.SELF + + manager = KVCacheManager( + kv_cache_config=kv_cache_config, + kv_cache_type=kv_cache_type, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + tokens_per_block=tokens_per_block, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + mapping=mapping, + dtype=dtype, + ) + + extractor = KVRegionExtractorV1(manager) + region_ids = [0, 1] + spec_region = extractor.extract(region_ids) + + assert isinstance(spec_region, SpecRegion) + memory = spec_region.memory + assert isinstance(memory, MemRegionGroup) + assert len(memory.ptrs) == len(region_ids) + assert memory.bytes_per_region > 0 + + pool_ptrs = manager.get_unique_primary_pool() + if hasattr(pool_ptrs, "__getitem__"): + if hasattr(pool_ptrs[0], "data_ptr"): + pool_base_ptr = int(pool_ptrs[0].data_ptr()) + else: + pool_base_ptr = int(pool_ptrs[0]) + else: + pool_base_ptr = ( + int(pool_ptrs.data_ptr()) if hasattr(pool_ptrs, "data_ptr") else int(pool_ptrs) + ) + expected_block_bytes = memory.bytes_per_region + expected_ptrs = [pool_base_ptr + block_id * expected_block_bytes for block_id in region_ids] + assert list(memory.ptrs) == expected_ptrs + + manager.shutdown() diff --git a/tests/unittest/disaggregated/test_extractor_v2.py b/tests/unittest/disaggregated/test_extractor_v2.py new file mode 100644 index 0000000000..06c243c78d --- /dev/null +++ b/tests/unittest/disaggregated/test_extractor_v2.py @@ -0,0 +1,78 @@ +import pytest + +from tensorrt_llm._torch.disaggregation.resource.kv_extractor_v2 import build_page_table +from tensorrt_llm.runtime.kv_cache_manager_v2 import ( + AttentionLayerConfig, + BufferConfig, + GpuCacheTierConfig, + KVCacheManager, + KVCacheManagerConfig, +) + + +@pytest.fixture +def simple_manager(): + layers = [ + AttentionLayerConfig( + layer_id=0, + sliding_window_size=None, + num_sink_tokens=0, + buffers=[ + BufferConfig(role=0, size=8192), + BufferConfig(role=1, size=8192), + ], + ), + AttentionLayerConfig( + layer_id=1, + sliding_window_size=None, + num_sink_tokens=0, + buffers=[ + BufferConfig(role=0, size=8192), + BufferConfig(role=1, size=8192), + ], + ), + ] + + cache_tiers = [ + GpuCacheTierConfig( + quota=100 * 1024 * 1024, # 100MB + ), + ] + + config = KVCacheManagerConfig( + tokens_per_block=64, + layers=layers, + vocab_size=50257, + cache_tiers=cache_tiers, + ) + + return KVCacheManager(config) + + +def test_build_page_table(simple_manager): + page_table = build_page_table(simple_manager) + + # Check basic properties + assert page_table.tokens_per_block == 64 + assert page_table.num_layers == 2 + assert page_table.num_pool_groups >= 1 + assert page_table.total_pools > 0 + + # Check pools are created + assert len(page_table.pools) > 0 + assert all(len(pg_pools) > 0 for pg_pools in page_table.pools) + + # Check first pool has valid properties + pool = page_table.pools[0][0] + assert pool.base_address > 0 + assert 
pool.slot_bytes > 0 + assert pool.num_slots > 0 + assert pool.pool_bytes == pool.slot_bytes * pool.num_slots + + # Check buffer entries exist + assert len(pool.buffer_entries) > 0 + + print(f"\n Page table: {page_table}") + print(f" Total pools: {page_table.total_pools}") + print(f" Pools: {page_table.pools}") + print(f" Total size: {page_table.total_pool_bytes / (1024**2):.2f} MB") diff --git a/tests/unittest/disaggregated/test_peer.py b/tests/unittest/disaggregated/test_peer.py new file mode 100644 index 0000000000..03e1a99d02 --- /dev/null +++ b/tests/unittest/disaggregated/test_peer.py @@ -0,0 +1,297 @@ +import pytest + +from tensorrt_llm._torch.disaggregation.native.peer import PeerOverlap, PeerRegistrar, RankInfo +from tensorrt_llm._torch.disaggregation.native.region.block import ( + HeadMatchMapper, + HeadMismatchMapper, + IdentityMapper, +) +from tensorrt_llm._torch.disaggregation.resource.kv_extractor import ( + KVPoolAttrs, + KVRegionExtractorV1, +) + + +def make_rankinfo( + instance_name="self", + instance_rank=0, + tp_size=2, + tp_rank=0, + pp_size=1, + pp_rank=0, + dp_size=1, + dp_rank=0, + cp_size=1, + cp_rank=0, + kv_heads_per_rank=2, + tokens_per_block=16, + dims_per_head=8, + element_bytes=2, + is_mla=False, + enable_attention_dp=False, + layer_num_per_pp=None, +): + if layer_num_per_pp is None: + layer_num_per_pp = [2] * pp_size + return RankInfo( + instance_name=instance_name, + instance_rank=instance_rank, + tp_size=tp_size, + tp_rank=tp_rank, + pp_size=pp_size, + pp_rank=pp_rank, + dp_size=dp_size, + dp_rank=dp_rank, + cp_size=cp_size, + cp_rank=cp_rank, + device_id=0, + kv_heads_per_rank=kv_heads_per_rank, + tokens_per_block=tokens_per_block, + dims_per_head=dims_per_head, + element_bytes=element_bytes, + enable_attention_dp=enable_attention_dp, + is_mla=is_mla, + layer_num_per_pp=layer_num_per_pp, + kv_ptrs=[], + aux_ptrs=[], + server_endpoint="", + self_endpoint="", + transfer_engine_info=b"", + aux_meta=None, + ) + + +def _make_peer_registrar_and_peer_ri(self_ri, peer_ri): + pool_attrs = KVPoolAttrs(pool_ptrs=[1234], block_bytes=[1024]) + real_kv_extractor = KVRegionExtractorV1(pool_attrs) + reg = PeerRegistrar(self_ri, real_kv_extractor) + return reg, peer_ri + + +def test_basic_overlap(): + self_ri = make_rankinfo( + "self", + pp_size=1, + pp_rank=0, + tp_size=1, + tp_rank=0, + cp_size=1, + cp_rank=0, + layer_num_per_pp=[2], + ) + peer_ri = make_rankinfo( + "peer", + pp_size=2, + pp_rank=0, + tp_size=1, + tp_rank=0, + cp_size=1, + cp_rank=0, + layer_num_per_pp=[1, 1], + ) + reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri) + overlap = reg.get_peer_overlap(peer_ri, peer_dp_rank=0) + assert overlap.overlap_pp_size == 2 + assert overlap.target_peer_pp_layer_num == [1, 1] + assert overlap.overlap_tp_size == 1 + assert overlap.ranks == [0, 1] + + +def test_no_overlap(): + self_ri = make_rankinfo( + "self", + pp_size=1, + pp_rank=0, + tp_size=1, + tp_rank=0, + cp_size=1, + cp_rank=0, + layer_num_per_pp=[0], + ) + peer_ri = make_rankinfo( + "peer", + pp_size=1, + pp_rank=0, + tp_size=1, + tp_rank=0, + cp_size=1, + cp_rank=0, + layer_num_per_pp=[2], + ) + reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri) + overlap = reg.get_peer_overlap(peer_ri, 0) + assert overlap.overlap_pp_size == 0 + assert overlap.target_peer_pp_layer_num == [] + assert overlap.ranks == [] + + +def test_pp_ratio_peer_smaller(): + self_ri = make_rankinfo( + "self", + pp_size=2, + pp_rank=1, + tp_size=1, + tp_rank=0, + cp_size=1, + cp_rank=0, + layer_num_per_pp=[1, 2], 
+ ) + peer_ri = make_rankinfo( + "peer", + pp_size=1, + pp_rank=0, + tp_size=1, + tp_rank=0, + cp_size=1, + cp_rank=0, + layer_num_per_pp=[3], + ) + reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri) + overlap = reg.get_peer_overlap(peer_ri, 0) + assert overlap.overlap_pp_size > 0 + assert sum(overlap.target_peer_pp_layer_num) > 0 + assert all(r >= 0 for r in overlap.ranks) + + +def test_tp_overlap(): + self_ri = make_rankinfo( + "self", tp_size=2, tp_rank=1, pp_size=1, pp_rank=0, cp_size=1, cp_rank=0 + ) + peer_ri = make_rankinfo( + "peer", tp_size=4, tp_rank=0, pp_size=1, pp_rank=0, cp_size=1, cp_rank=0 + ) + reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri) + overlap = reg.get_peer_overlap(peer_ri, 0) + assert overlap.overlap_tp_size in [1, 2] + assert all(isinstance(r, int) for r in overlap.ranks) + + +def test_cp_overlap(): + self_ri = make_rankinfo( + "self", cp_size=2, cp_rank=1, pp_size=1, pp_rank=0, tp_size=1, tp_rank=0 + ) + peer_ri = make_rankinfo( + "peer", cp_size=4, cp_rank=0, pp_size=1, pp_rank=0, tp_size=1, tp_rank=0 + ) + reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri) + overlap = reg.get_peer_overlap(peer_ri, 0) + assert overlap.overlap_cp_size in [1, 2] + assert all(isinstance(r, int) for r in overlap.ranks) + + +def test_multiple_overlap(): + self_ri = make_rankinfo( + "self", + pp_size=2, + pp_rank=1, + tp_size=2, + tp_rank=1, + cp_size=2, + cp_rank=1, + layer_num_per_pp=[1, 2], + ) + peer_ri = make_rankinfo( + "peer", + pp_size=4, + pp_rank=2, + tp_size=4, + tp_rank=3, + cp_size=4, + cp_rank=0, + layer_num_per_pp=[1, 1, 1, 1], + ) + reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri) + overlap = reg.get_peer_overlap(peer_ri, peer_dp_rank=0) + expected_overlap = PeerOverlap( + overlap_pp_size=2, + overlap_tp_size=2, + overlap_cp_size=2, + duplicate_head_factor=1, + peer_duplicate_head_factor=2, + target_peer_pp_layer_num=[1, 1], + ranks=[26, 27, 30, 31, 42, 43, 46, 47], + ) + assert overlap == expected_overlap + + +def _make_peer_registrar(self_rankinfo): + pool_attrs = KVPoolAttrs(pool_ptrs=[1234], block_bytes=[1024]) + real_kv_extractor = KVRegionExtractorV1(pool_attrs) + reg = PeerRegistrar(self_rankinfo, real_kv_extractor) + return reg + + +def test_peer_registrar_register_and_get(): + self_rankinfo = make_rankinfo(instance_name="local") + reg = _make_peer_registrar(self_rankinfo) + peer_ri = make_rankinfo(instance_name="peer", instance_rank=1, layer_num_per_pp=[2]) + reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri) + assert reg.get_peer_rank_info("peer", 1) == peer_ri + + +def test_peer_registrar_unregister(): + self_rankinfo = make_rankinfo(instance_name="local") + reg = _make_peer_registrar(self_rankinfo) + peer_ri = make_rankinfo(instance_name="peer", instance_rank=1, layer_num_per_pp=[2]) + reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri) + reg.unregister(peer_ri.instance_name, peer_ri.instance_rank) + with pytest.raises(KeyError): + reg.get_peer_rank_info("peer", 1) + + +def test_peer_registrar_incompatible_peer_raises(): + self_rankinfo = make_rankinfo(instance_name="local") + reg = _make_peer_registrar(self_rankinfo) + peer_ri = make_rankinfo(instance_name="peer", instance_rank=3, is_mla=True) + with pytest.raises(ValueError): + reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri) + + +def test_peer_registrar_self_rank_info_property(): + self_rankinfo = make_rankinfo(instance_name="local") + reg = _make_peer_registrar(self_rankinfo) + assert 
reg.self_rank_info == self_rankinfo + + +def test_peer_registrar_get_kv_map_identity(): + self_rankinfo = make_rankinfo(instance_name="local") + reg = _make_peer_registrar(self_rankinfo) + peer_ri = make_rankinfo(instance_name="peer", instance_rank=1, layer_num_per_pp=[2]) + reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri) + mapper = reg.get_kv_map(peer_ri) + assert isinstance(mapper, IdentityMapper) + + +def test_peer_registrar_get_kv_map_head_match(): + self_rankinfo = make_rankinfo(instance_name="local") + reg = _make_peer_registrar(self_rankinfo) + peer_ri = make_rankinfo( + instance_name="peer", + instance_rank=2, + pp_size=2, + pp_rank=1, + layer_num_per_pp=[1, 1], + tokens_per_block=16, + dims_per_head=8, + ) + mapper = reg.get_kv_map(peer_ri) + assert isinstance(mapper, HeadMatchMapper) + + +def test_peer_registrar_get_kv_map_head_mismatch(): + self_rankinfo = make_rankinfo(instance_name="local") + reg = _make_peer_registrar(self_rankinfo) + peer_ri = make_rankinfo( + instance_name="peer", + instance_rank=3, + pp_size=1, + pp_rank=0, + tp_size=1, + tp_rank=0, + kv_heads_per_rank=4, + tokens_per_block=16, + dims_per_head=8, + layer_num_per_pp=[2], + ) + mapper = reg.get_kv_map(peer_ri) + assert isinstance(mapper, HeadMismatchMapper)
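To show how the pieces compose, here is a hedged end-to-end sketch (not part of the patch). It reuses the `make_rankinfo` helper defined in test_peer.py above, and every pointer and block id is invented; a real transfer would also require matching block layouts on both sides:

```python
from tensorrt_llm._torch.disaggregation.native.peer import PeerRegistrar
from tensorrt_llm._torch.disaggregation.resource.kv_extractor import (
    KVPoolAttrs,
    KVRegionExtractorV1,
)

# make_rankinfo is the test helper defined in test_peer.py above.
self_ri = make_rankinfo(instance_name="local", layer_num_per_pp=[2])
peer_ri = make_rankinfo(instance_name="peer", instance_rank=1, layer_num_per_pp=[2])
peer_ri.kv_ptrs = [0x2000]  # pretend device pool base for the peer

self_extractor = KVRegionExtractorV1(KVPoolAttrs(pool_ptrs=[0x1000], block_bytes=[1024]))
reg = PeerRegistrar(self_ri, self_extractor)
reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri)

# Identical layouts on both sides select IdentityMapper; mismatched heads or
# pipeline splits would select HeadMatchMapper / HeadMismatchMapper instead.
mapper = reg.get_kv_map(peer_ri)

src = reg.self_extractor.extract([0, 2])  # local block ids
dst = reg.peer_extractor(peer_ri.instance_name, peer_ri.instance_rank).extract([5, 6])
pair = mapper.map(src, dst)  # SpecRegionPair of raw pointer groups for the transfer
print(pair.src.memory.ptrs, pair.dst.memory.ptrs)
```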