[TRTLLM-9527][feat] Modularization of the transceiver for KV manager v2 (step 4) (#11225)

Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
This commit is contained in:
Shi Xiaowei 2026-02-06 20:15:18 +08:00 committed by GitHub
parent 66caa67357
commit b1268e1b37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 1806 additions and 0 deletions

View File

@ -0,0 +1,124 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntFlag, auto
from typing import List, NamedTuple, Optional
@dataclass(frozen=True)
class IndexRange:
    """
    A closed, non-negative interval [start, end].

    Typically used to address a contiguous run of layers, heads, or tokens.
    """

    start: int
    end: int

    def __post_init__(self):
        # Validate eagerly so an invalid range can never be constructed.
        for bound in (self.start, self.end):
            if not isinstance(bound, int):
                raise TypeError("start and end must be integers")
        if min(self.start, self.end) < 0:
            raise ValueError("start and end must be >= 0")
        if self.end < self.start:
            raise ValueError("end must be >= start")
class MemRegion(NamedTuple):
    """A single contiguous memory region: starting pointer plus size in bytes."""

    # Starting address of the region (an integer pointer value).
    ptr: int
    # Size of the region in bytes. (Name shadows the builtin `bytes`; kept
    # for interface stability.)
    bytes: int
class MemRegionGroup(NamedTuple):
    """A group of equally-sized memory regions: one starting pointer per
    region, plus the per-region size in bytes shared by all of them."""

    # Starting address of each region in the group.
    ptrs: List[int]
    # Size in bytes common to every region in the group.
    bytes_per_region: int
class DataRole(IntFlag):
    """Logical role(s) a memory region plays. Flags may be OR-combined,
    e.g. ``DataRole.KEY | DataRole.VALUE`` for a region holding both."""

    KEY = auto()
    VALUE = auto()
class DataLayout(IntFlag):
    """Possible orders for storing cache data in memory."""

    HND = auto()  # (head, seq_len, dim)
    NHD = auto()  # (seq_len, head, dim)
@dataclass(frozen=True)
class RegionSpec:
    """
    Semantic specifier for a (potentially partial) region of the cache.

    Extend this base class for additional axes or specialization.
    """

    # Layer range covered by the region; None leaves the layer axis unspecified.
    layers: Optional[IndexRange] = None
@dataclass(frozen=True)
class KVRegionSpec(RegionSpec):
    """
    Specifies a region within the Key/Value cache, with optional axes.
    """

    # Role(s) the region holds; defaults to both key and value data.
    role: DataRole = DataRole.KEY | DataRole.VALUE
    # Head range; None leaves the head axis unspecified.
    heads: Optional[IndexRange] = None
    # Token range; None leaves the token axis unspecified.
    tokens: Optional[IndexRange] = None
class SpecRegion(NamedTuple):
    """
    Associates a memory region (or group of regions) with its semantic specifier.
    """

    # The backing memory: a single region or a group of equally-sized regions.
    memory: MemRegion | MemRegionGroup
    # Semantic description of the memory; None when no spec is attached.
    spec: Optional[RegionSpec] = None
class RegionExtractorBase(ABC):
    """
    Interface for extracting region descriptors from some backing store.
    """

    @abstractmethod
    def extract(self, region_ids: Optional[List[int]] = None) -> List[SpecRegion]:
        """
        Extract region descriptors for the requested regions.

        Args:
            region_ids: (Optional) List of integer region identifiers to extract.

        Returns:
            List of SpecRegions for the corresponding regions.
        """
        ...
class SpecRegionPair(NamedTuple):
    """
    Maps a source descriptor to a destination descriptor
    (e.g., when copying or reindexing regions).
    """

    # Source region descriptor.
    src: SpecRegion
    # Destination region descriptor.
    dst: SpecRegion
class RegionMapperBase(ABC):
    """
    Maps a source region descriptor to its corresponding destination descriptor.
    """

    @abstractmethod
    def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair:
        """
        Args:
            src_regions: Source SpecRegion (its memory may describe many regions).
            dst_regions: Destination SpecRegion.

        Returns:
            A SpecRegionPair mapping source memory onto destination memory.
        """
        ...

View File

@ -0,0 +1,286 @@
from dataclasses import dataclass, field
from typing import Dict, List
from tensorrt_llm import logger
from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo
from tensorrt_llm._torch.disaggregation.native.region.block import (
HeadMatchMapper,
HeadMismatchMapper,
IdentityMapper,
RegionMapperBase,
)
from tensorrt_llm._torch.disaggregation.resource.kv_extractor import (
KVPoolAttrs,
KVRegionExtractorV1,
)
@dataclass
class PeerOverlap:
    """Describes how the local rank's cache overlaps ranks of a peer instance.

    An all-default instance (sizes 0, empty lists) means "no overlap".
    """

    # Number of peer pipeline-parallel stages the local stage overlaps.
    overlap_pp_size: int = 0
    # Number of peer tensor-parallel ranks the local TP rank overlaps.
    overlap_tp_size: int = 0
    # Number of peer context-parallel ranks the local CP rank overlaps.
    overlap_cp_size: int = 0
    # Factor by which local heads duplicate peer heads (>= 1).
    duplicate_head_factor: int = 1
    # Factor by which peer heads duplicate local heads (>= 1).
    peer_duplicate_head_factor: int = 1
    # Per overlapping peer PP stage: number of layers shared with the local stage.
    target_peer_pp_layer_num: List[int] = field(default_factory=list)
    # Flat list of overlapping peer instance ranks.
    ranks: List[int] = field(default_factory=list)
class PeerRegistrar:
    """
    Registry of peer ranks participating in KV-cache transfer.

    Caches, per (peer instance name, peer rank): the peer's RankInfo, a
    KVRegionExtractorV1 over the peer's KV pool, the region mapper used to
    map local blocks onto peer blocks, and the computed rank overlap.
    """

    def __init__(self, self_rank_info: RankInfo, self_extractor: KVRegionExtractorV1):
        self._ri = self_rank_info
        self._peer_ri_cache: Dict[str, RankInfo] = {}
        self._kv_map_cache: Dict[str, RegionMapperBase] = {}
        self._self_ext_cache = self_extractor
        self._peer_ext_cache: Dict[str, KVRegionExtractorV1] = {}
        self._overlap_cache: Dict[str, PeerOverlap] = {}

    def _block_size(self, layer_num: int, ri: RankInfo) -> int:
        # Bytes of one KV cache block covering `layer_num` layers on rank `ri`.
        return (
            layer_num
            * ri.kv_factor
            * ri.kv_heads_per_rank
            * ri.tokens_per_block
            * ri.dims_per_head
            * ri.element_bytes
        )

    def register(self, peer_name: str, peer_rank: int, peer_ri: RankInfo):
        """
        Validate a peer rank's configuration and cache its RankInfo plus a
        region extractor over the peer's KV pool.

        Raises:
            ValueError: if the peer is incompatible with the local rank
                (see `_check_peer_compatible`).
        """
        # TODO: check if peer is valid for registration
        assert self._self_ext_cache is not None
        if not self._check_peer_compatible(peer_ri):
            raise ValueError(
                f"PeerRegistrar.register: peer {peer_name} (rank={peer_rank}) is incompatible with local rank."
            )
        key = self._unique_key(peer_name, peer_rank)
        self._peer_ri_cache[key] = peer_ri
        # Re-read from the cache so the extractor is built from exactly what
        # was stored.
        peer_ri = self.get_peer_rank_info(peer_name, peer_rank)
        # One extractor per peer: block size derived from the layers held by
        # the peer's own pipeline stage.
        layer_num = peer_ri.layer_num_per_pp[peer_ri.pp_rank]
        block_size = self._block_size(layer_num, peer_ri)
        extractor = KVRegionExtractorV1(
            KVPoolAttrs(pool_ptrs=peer_ri.kv_ptrs, block_bytes=[block_size])
        )
        self._peer_ext_cache[key] = extractor

    def peer_extractor(self, peer_name: str, peer_rank: int) -> KVRegionExtractorV1:
        """Return the cached extractor for a registered peer (KeyError if absent)."""
        return self._peer_ext_cache[self._unique_key(peer_name, peer_rank)]

    @property
    def self_extractor(self) -> KVRegionExtractorV1:
        # Extractor over the local KV pool, supplied at construction time.
        assert self._self_ext_cache is not None
        return self._self_ext_cache

    def unregister(self, peer_name: str, peer_rank: int):
        """Drop all cached state for a peer; a no-op if it was never registered.

        NOTE: the overlap cache is keyed by dp rank (see get_peer_overlap) and
        is not cleared here.
        """
        key = self._unique_key(peer_name, peer_rank)
        if key in self._peer_ri_cache:
            del self._peer_ri_cache[key]
        if key in self._peer_ext_cache:
            del self._peer_ext_cache[key]
        if key in self._kv_map_cache:
            del self._kv_map_cache[key]

    def get_peer_rank_info(self, peer_name: str, peer_rank: int) -> RankInfo:
        """Return the cached RankInfo for a registered peer (KeyError if absent)."""
        return self._peer_ri_cache[self._unique_key(peer_name, peer_rank)]

    @property
    def self_rank_info(self) -> RankInfo:
        return self._ri

    def _unique_key(self, name: str, rank: int) -> str:
        # NOTE(review): plain concatenation can collide, e.g. ("a1", 1) and
        # ("a", 11) both map to "a11" — consider a separator if instance
        # names may end in digits.
        return name + str(rank)

    def _check_peer_compatible(self, peer_ri: RankInfo) -> bool:
        """Return True when the peer's cache geometry can interoperate with ours.

        Logs a warning describing the first mismatch found and returns False.
        """
        if self._ri.is_mla != peer_ri.is_mla:
            logger.warning(
                "PeerRegistrar: compatibility check failed: 'is_mla' differs "
                f"(local={self._ri.is_mla}, peer={peer_ri.is_mla})."
            )
            return False
        if self._ri.cp_size != 1 or peer_ri.cp_size != 1:
            logger.warning(
                "PeerRegistrar: unsupported configuration: context parallelism (cp_size) "
                f"must be 1 for both local and peer ranks (local={self._ri.cp_size}, peer={peer_ri.cp_size})."
            )
            return False
        if self._ri.element_bytes != peer_ri.element_bytes:
            logger.warning(
                "PeerRegistrar: element size mismatch "
                f"(local={self._ri.element_bytes} bytes, peer={peer_ri.element_bytes} bytes)."
            )
            return False
        if self._ri.tokens_per_block != peer_ri.tokens_per_block:
            logger.warning(
                "PeerRegistrar: tokens_per_block mismatch "
                f"(local={self._ri.tokens_per_block}, peer={peer_ri.tokens_per_block})."
            )
            return False
        if self._ri.dims_per_head != peer_ri.dims_per_head:
            logger.warning(
                "PeerRegistrar: dims_per_head mismatch "
                f"(local={self._ri.dims_per_head}, peer={peer_ri.dims_per_head})."
            )
            return False
        # Both sides must cover the same total layer count, however it is
        # partitioned across pipeline stages.
        self_layers = sum(self._ri.layer_num_per_pp)
        peer_layers = sum(peer_ri.layer_num_per_pp)
        if self_layers != peer_layers:
            logger.warning(
                "PeerRegistrar: total layer count mismatch "
                f"(local={self_layers}, peer={peer_layers})."
            )
            return False
        if self._ri.is_mla:
            if peer_ri.kv_heads_per_rank != 1 or self._ri.kv_heads_per_rank != 1:
                logger.warning(
                    "PeerRegistrar: MLA mode requires exactly 1 KV head per rank for both local and peer."
                    f" (local={self._ri.kv_heads_per_rank}, peer={peer_ri.kv_heads_per_rank})"
                )
                return False
        return True

    def _tp_per_dp(self, info: RankInfo) -> int:
        # TP ranks within one dp group; the full tp_size when attention-DP is
        # disabled (or the attribute is absent).
        return (
            info.tp_size // info.dp_size
            if getattr(info, "enable_attention_dp", False)
            else info.tp_size
        )

    def get_kv_map(self, peer_ri: RankInfo) -> RegionMapperBase:
        """
        Return (and cache) the region mapper for transfers to `peer_ri`.

        Chooses IdentityMapper when layouts and pp partitioning match,
        HeadMatchMapper when only the pp partitioning differs, and
        HeadMismatchMapper otherwise.
        """
        key = self._unique_key(peer_ri.instance_name, peer_ri.instance_rank)
        if key in self._kv_map_cache:
            return self._kv_map_cache[key]
        self_tp_per_dp = self._tp_per_dp(self._ri)
        peer_tp_per_dp = self._tp_per_dp(peer_ri)
        # True when the head totals across a dp group differ, i.e. one side
        # duplicates the other's heads.
        is_dup_head = (
            self._ri.kv_heads_per_rank * self_tp_per_dp
            != peer_ri.kv_heads_per_rank * peer_tp_per_dp
        )
        # NOTE(review): head_match is also True in the duplicated-head case —
        # confirm duplicated heads are meant to take the whole-block path.
        head_match = is_dup_head or self._ri.is_mla or self_tp_per_dp == peer_tp_per_dp
        logger.debug(
            "KVMapperFactory.get_kv_map: "
            f"head_match={head_match}, is_dup_head={is_dup_head}, self_is_mla={self._ri.is_mla}, "
            f"self_tp_per_dp={self_tp_per_dp}, peer_tp_per_dp={peer_tp_per_dp}"
        )
        # Fast path: matching head layout and same pp partitioning.
        if head_match and self._ri.pp_size == peer_ri.pp_size:
            mapper = IdentityMapper()
            self._kv_map_cache[key] = mapper
            return mapper
        # compute overlapping layers
        self_start_layer = sum(self._ri.layer_num_per_pp[: self._ri.pp_rank])
        self_end_layer = self_start_layer + self._ri.layer_num_per_pp[self._ri.pp_rank]
        peer_start_layer = sum(peer_ri.layer_num_per_pp[: peer_ri.pp_rank])
        peer_end_layer = peer_start_layer + peer_ri.layer_num_per_pp[peer_ri.pp_rank]
        start = max(self_start_layer, peer_start_layer)
        end = min(self_end_layer, peer_end_layer)
        transfer_layers = end - start
        self_layer_offset = start - self_start_layer
        peer_layer_offset = start - peer_start_layer
        if head_match:
            mapper = HeadMatchMapper(
                transfer_layers=transfer_layers,
                src_layer_off=self_layer_offset,  # local layer offset
                dst_layer_off=peer_layer_offset,  # peer layer offset
                self_ri=self._ri,
                peer_ri=peer_ri,
            )
            self._kv_map_cache[key] = mapper
            return mapper
        # head mismatch case
        mapper = HeadMismatchMapper(
            transfer_layers=transfer_layers,
            src_layer_off=self_layer_offset,
            peer_layer_off=peer_layer_offset,
            self_ri=self._ri,
            peer_ri=peer_ri,
        )
        self._kv_map_cache[key] = mapper
        return mapper

    @staticmethod
    def _find_overlap(self_val, peer_val, self_rank, peer_rank=None):
        """
        Compute how a local parallel dimension of size `self_val` overlaps a
        peer dimension of size `peer_val` for the local rank `self_rank`.

        Returns:
            (overlap, start, end): number of overlapping peer ranks and the
            half-open peer-rank range [start, end). When `peer_rank` is given,
            the range is offset by `peer_rank * peer_val` (e.g. to skip
            earlier peer dp groups).
        """
        if self_val <= peer_val:
            # Peer dimension is larger (or equal): each local rank maps onto a
            # contiguous run of peer ranks.
            overlap = peer_val // self_val
            start = self_rank * overlap + (peer_rank * peer_val if peer_rank is not None else 0)
            end = start + overlap
        else:
            # Local dimension is larger: several local ranks share one peer rank.
            ratio = self_val // peer_val
            start = (self_rank // ratio) + (peer_rank * peer_val if peer_rank is not None else 0)
            overlap = 1
            end = start + overlap
        return overlap, start, end

    def get_peer_overlap(self, peer_rank_info: RankInfo, peer_dp_rank: int) -> PeerOverlap:
        """
        Compute (and cache) which ranks in dp group `peer_dp_rank` of the peer
        instance hold cache data overlapping this local rank.

        Returns a default (empty) PeerOverlap when the pipeline stages share
        no layers.
        """
        peer_ri = peer_rank_info
        # Cache key uses the peer dp rank, not the peer instance rank.
        key = self._unique_key(peer_ri.instance_name, peer_dp_rank)
        if key in self._overlap_cache:
            return self._overlap_cache[key]
        # compute pp overlap and target layers
        self_start_layer = sum(self._ri.layer_num_per_pp[: self._ri.pp_rank])
        self_end_layer = self_start_layer + self._ri.layer_num_per_pp[self._ri.pp_rank]
        pre = 0
        tgt_pp_ranks: List[int] = []
        tgt_pp_layer_num: List[int] = []
        for p in range(peer_ri.pp_size):
            peer_start_layer = pre
            peer_end_layer = peer_start_layer + peer_ri.layer_num_per_pp[p]
            # Half-open interval intersection test against the local stage.
            if self_start_layer < peer_end_layer and self_end_layer > peer_start_layer:
                tgt_pp_ranks.append(p)
                tgt_pp_layer_num.append(
                    min(peer_end_layer, self_end_layer) - max(peer_start_layer, self_start_layer)
                )
            pre += peer_ri.layer_num_per_pp[p]
        if tgt_pp_ranks == []:
            # no overlap found
            targets = PeerOverlap()
            self._overlap_cache[key] = targets
            return targets
        # Overlapping peer pp stages are contiguous: [peer_start_pp, peer_end_pp).
        peer_start_pp = tgt_pp_ranks[0]
        overlap_pp_size = len(tgt_pp_ranks)
        peer_end_pp = peer_start_pp + overlap_pp_size
        # tp per dp-group
        self_tp_per_dp = self._tp_per_dp(self._ri)
        peer_tp_per_dp = self._tp_per_dp(peer_ri)
        self_tp_rank_in_dp = self._ri.tp_rank % self_tp_per_dp
        overlap_tp_size, peer_start_tp, peer_end_tp = self._find_overlap(
            self_tp_per_dp, peer_tp_per_dp, self_tp_rank_in_dp, peer_dp_rank
        )
        overlap_cp_size, peer_start_cp, peer_end_cp = self._find_overlap(
            self._ri.cp_size, peer_ri.cp_size, self._ri.cp_rank
        )
        # Flatten the overlapping (pp, cp, tp) coordinates into peer rank ids,
        # assuming peer rank layout rank = pp * (tp*cp) + cp * tp + tp_idx.
        ranks: List[int] = []
        for pp in range(peer_start_pp, peer_end_pp):
            for cp in range(peer_start_cp, peer_end_cp):
                for tp in range(peer_start_tp, peer_end_tp):
                    ranks.append(pp * peer_ri.tp_size * peer_ri.cp_size + cp * peer_ri.tp_size + tp)
        # Head duplication factors between the two head layouts (>= 1 each).
        factor_self = self._ri.kv_heads_per_rank * self_tp_per_dp
        factor_peer = peer_ri.kv_heads_per_rank * peer_tp_per_dp
        dup_head = max(1, factor_self // factor_peer)
        peer_dup_head = max(1, factor_peer // factor_self)
        targets = PeerOverlap(
            overlap_pp_size=overlap_pp_size,
            overlap_tp_size=overlap_tp_size,
            overlap_cp_size=overlap_cp_size,
            duplicate_head_factor=dup_head,
            peer_duplicate_head_factor=peer_dup_head,
            target_peer_pp_layer_num=tgt_pp_layer_num,
            ranks=ranks,
        )
        self._overlap_cache[key] = targets
        return targets

View File

@ -0,0 +1,75 @@
from dataclasses import asdict, dataclass
from typing import List, Optional
import msgpack
from tensorrt_llm._torch.disaggregation.native.region.aux import AuxBufferMeta
@dataclass
class InstanceInfo:
    """Instance-wide static configuration exchanged between instances.

    Serialized with msgpack via `to_bytes`/`from_bytes`.
    """

    instance_name: str
    tp_size: int
    pp_size: int
    dp_size: int
    cp_size: int
    kv_heads_per_rank: int
    tokens_per_block: int
    dims_per_head: int
    element_bytes: int
    enable_attention_dp: bool
    is_mla: bool
    # Number of model layers held by each pipeline-parallel stage.
    layer_num_per_pp: List[int]
    # Endpoints of the instance's sender services.
    sender_endpoints: List[str]

    def to_bytes(self) -> bytes:
        """Serialize to a msgpack byte string."""
        return msgpack.packb(asdict(self))

    @classmethod
    def from_bytes(cls, data: bytes) -> "InstanceInfo":
        """Reconstruct an InstanceInfo from the output of `to_bytes`."""
        return cls(**msgpack.unpackb(data))
@dataclass
class RankInfo:
    """Per-rank configuration, KV pool pointers, and transfer endpoints
    exchanged between peers.

    Serialized with msgpack via `to_bytes`/`from_bytes`; `aux_meta` gets
    special dict handling because msgpack cannot pack the dataclass directly.
    """

    instance_name: str
    instance_rank: int
    tp_size: int
    tp_rank: int
    pp_size: int
    pp_rank: int
    dp_size: int
    dp_rank: int
    cp_size: int
    cp_rank: int
    device_id: int
    kv_heads_per_rank: int
    # Presumed per-block KV layout: [numLayers, kv_factor, heads, tokens,
    # dims_per_head] — TODO confirm against the pool allocator.
    tokens_per_block: int
    dims_per_head: int
    element_bytes: int
    enable_attention_dp: bool
    is_mla: bool
    # Number of model layers held by each pipeline-parallel stage.
    layer_num_per_pp: List[int]
    # Base pointers of this rank's KV pools.
    kv_ptrs: List[int]
    # Base pointers of this rank's auxiliary buffers.
    aux_ptrs: List[int]
    server_endpoint: str
    self_endpoint: str
    # Opaque transfer-engine handshake blob.
    transfer_engine_info: bytes
    # Metadata of the auxiliary buffers, if any.
    aux_meta: Optional[AuxBufferMeta]

    @property
    def kv_factor(self) -> int:
        # 2 entries per layer (key + value) normally; MLA stores a single one.
        return 2 if not self.is_mla else 1

    def to_bytes(self) -> bytes:
        """Serialize to a msgpack byte string (aux_meta flattened to a dict)."""
        data = asdict(self)
        data["aux_meta"] = self.aux_meta.to_dict() if self.aux_meta is not None else None
        return msgpack.packb(data)

    @classmethod
    def from_bytes(cls, data: bytes) -> "RankInfo":
        """Reconstruct a RankInfo from the output of `to_bytes`."""
        unpacked = msgpack.unpackb(data)
        if unpacked.get("aux_meta") is not None:
            unpacked["aux_meta"] = AuxBufferMeta.from_dict(unpacked["aux_meta"])
        return cls(**unpacked)

View File

@ -0,0 +1,183 @@
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from typing import Any
import torch
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
@dataclass
class AuxBufferMeta:
    """Serializable description of a set of auxiliary buffers."""

    # Base pointer of each underlying buffer.
    ptrs: list[int]
    # Total size in bytes of each buffer.
    size: list[int]
    # Per-slot size in bytes of each buffer (may be empty).
    item_sizes: list[int] = field(default_factory=list)
    # Device the buffers live on.
    device: str = "cpu"

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict form suitable for serialization."""
        payload = {
            "ptrs": self.ptrs,
            "size": self.size,
            "item_sizes": self.item_sizes,
            "device": self.device,
        }
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AuxBufferMeta":
        """Rebuild an instance from `to_dict` output; optional keys fall back
        to their defaults."""
        item_sizes = data.get("item_sizes", [])
        device = data.get("device", "cpu")
        return cls(ptrs=data["ptrs"], size=data["size"], item_sizes=item_sizes, device=device)
class AuxBufferBase(ABC):
    """
    Abstract base class defining the interface for auxiliary buffer management.

    Implementations own fixed-capacity slot storage: slots are allocated,
    filled from an LlmRequest, read back, and released.
    """

    @abstractmethod
    def alloc_slot(self) -> int:
        """
        Allocate a free slot and return its index.
        """
        ...

    @abstractmethod
    def free_slot(self, slot: int) -> None:
        """
        Release the specified slot.
        """
        ...

    @property
    @abstractmethod
    def meta(self) -> AuxBufferMeta:
        """
        Retrieve meta-information about the underlying buffer(s).
        Returns buffer info (e.g., pointers, sizes, device).
        """
        ...

    @abstractmethod
    def fill_slot(self, slot: int, request: LlmRequest) -> None:
        """
        Fill/overwrite the contents of the given slot with data from the request.
        """
        ...

    @abstractmethod
    def get_slot_tokens(self, slot: int) -> tuple[list[int], list[int]]:
        """
        Get the token data (e.g., first/draft tokens) from the specified slot.
        """
        ...
class AuxBuffer(AuxBufferBase):
    """
    Fixed-capacity auxiliary token buffer with simple slot management.

    Two pre-allocated int32 torch buffers hold, per slot, the latest generated
    tokens (one per beam, via `LlmRequest.get_last_tokens`) and the draft
    tokens. Raw pointers and sizes are published via `meta` so an external
    transfer engine can read or write slots directly.
    """

    def __init__(self, max_slot_num: int, beam_width: int, max_draft_len: int, device: str = "cpu"):
        # public constructor args remain the same, internals are private
        self._max_slot_num = int(max_slot_num)
        self._beam_width = int(beam_width)
        self._max_draft_len = int(max_draft_len)
        self._device = device
        # FIFO pool of free slot ids plus an occupancy set for O(1) checks.
        self._free_slots = deque(range(self._max_slot_num))
        self._occupied_slots: set[int] = set()
        self._slot_token_counts: dict[
            int, tuple[int, int]
        ] = {}  # slot -> (first_tokens_len, draft_tokens_len)
        data_type = torch.int32
        # NOTE: torch.empty leaves the buffers uninitialized; a slot's content
        # is only meaningful after fill_slot (or an external write into it).
        self._first_tokens_buffer = torch.empty(
            self._max_slot_num, self._beam_width, dtype=data_type, device=self._device
        )
        self._draft_tokens_buffer = torch.empty(
            self._max_slot_num, self._max_draft_len, dtype=data_type, device=self._device
        )
        self._meta = AuxBufferMeta(
            ptrs=[self._first_tokens_buffer.data_ptr(), self._draft_tokens_buffer.data_ptr()],
            size=[
                self._first_tokens_buffer.numel() * self._first_tokens_buffer.element_size(),
                self._draft_tokens_buffer.numel() * self._draft_tokens_buffer.element_size(),
            ],
            item_sizes=[
                self._first_tokens_buffer[0].numel() * self._first_tokens_buffer.element_size(),
                self._draft_tokens_buffer[0].numel() * self._draft_tokens_buffer.element_size(),
            ],
            device=self._device,
        )

    def alloc_slot(self) -> int:
        """Allocate a free slot and return its index.

        Raises:
            ValueError: when all slots are occupied.
            RuntimeError: on an internal bookkeeping inconsistency.
        """
        if not self._free_slots:
            raise ValueError(
                f"No free auxiliary buffer slots available (max slots = {self._max_slot_num}). "
                "All slots are currently occupied."
            )
        slot_id = self._free_slots.popleft()
        if slot_id in self._occupied_slots:
            # This should not happen — defensive check.
            raise RuntimeError(
                f"Invariant error: selected slot {slot_id} is already marked as occupied. "
                "This indicates a bug in slot management."
            )
        self._occupied_slots.add(slot_id)
        return slot_id

    def free_slot(self, slot: int) -> None:
        """Release `slot` back to the free pool.

        Raises:
            ValueError: when `slot` is out of range or not currently allocated.
        """
        # Fixed: validate the range first. Previously the occupancy check ran
        # first, so an out-of-range id was misreported as "not currently
        # allocated" and the range-check message was unreachable.
        if slot < 0 or slot >= self._max_slot_num:
            raise ValueError(
                f"Invalid slot id {slot}. Valid slot indices are in the range 0..{self._max_slot_num - 1}."
            )
        if slot not in self._occupied_slots:
            raise ValueError(
                f"Attempted to free slot {slot}, but that slot is not currently allocated. "
                "Ensure `alloc_slot` was called and the slot wasn't freed already."
            )
        self._occupied_slots.remove(slot)
        self._free_slots.append(slot)

    @property
    def meta(self) -> AuxBufferMeta:
        """Pointers/sizes of the backing buffers (see AuxBufferMeta)."""
        return self._meta

    def fill_slot(self, slot: int, request: LlmRequest) -> None:
        """Copy the request's latest tokens and draft tokens into `slot`.

        Raises:
            ValueError: if the slot is unallocated, or a token list exceeds the
                configured beam_width / max_draft_len capacity.
        """
        if slot not in self._occupied_slots:
            raise ValueError(
                f"Cannot fill slot {slot}: slot is not currently allocated. "
                "Call `alloc_slot` first."
            )
        first_gen_tokens = request.get_last_tokens()
        draft_tokens = request.py_draft_tokens
        if len(first_gen_tokens) > self._beam_width:
            raise ValueError(
                f"`first_gen_tokens` length ({len(first_gen_tokens)}) exceeds `beam_width` ({self._beam_width}). "
                "Consider truncating the token list or increasing the beam_width when creating the `AuxBuffer`."
            )
        if len(draft_tokens) > self._max_draft_len:
            raise ValueError(
                f"`draft_tokens` length ({len(draft_tokens)}) exceeds `max_draft_len` ({self._max_draft_len}). "
                "Consider truncating draft tokens or increasing `max_draft_len` when creating the `AuxBuffer`."
            )
        # Only the leading len(...) entries are written; anything beyond the
        # recorded lengths is never read back by get_slot_tokens.
        self._first_tokens_buffer[slot][: len(first_gen_tokens)].copy_(
            torch.tensor(first_gen_tokens, dtype=torch.int32, device=self._device)
        )
        self._draft_tokens_buffer[slot][: len(draft_tokens)].copy_(
            torch.tensor(draft_tokens, dtype=torch.int32, device=self._device)
        )
        self._slot_token_counts[slot] = (len(first_gen_tokens), len(draft_tokens))

    def get_slot_tokens(self, slot: int) -> tuple[list[int], list[int]]:
        """Return (first_gen_tokens, draft_tokens) stored in `slot`.

        For slots with no locally recorded fill, the full beam_width /
        max_draft_len extents are returned — presumably to read slots
        populated externally through the raw pointers in `meta` (TODO confirm;
        a never-written slot yields uninitialized data).
        """
        first_len, draft_len = self._slot_token_counts.get(
            slot, (self._beam_width, self._max_draft_len)
        )
        first_gen_tokens = self._first_tokens_buffer[slot][:first_len].tolist()
        draft_tokens = self._draft_tokens_buffer[slot][:draft_len].tolist()
        return first_gen_tokens, draft_tokens

View File

@ -0,0 +1,216 @@
import numpy as np
from tensorrt_llm._torch.disaggregation.base.region import (
MemRegionGroup,
RegionMapperBase,
SpecRegion,
SpecRegionPair,
)
from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo
class IdentityMapper(RegionMapperBase):
    """
    ---- mapper_identity ----
    Pass-through mapping. Do not change pointers or sizes.
    src_ptrs: [ S0 ] [ S1 ] [ S2 ] ...
                |      |      |
                v      v      v
    dst_ptrs: [ D0 ] [ D1 ] [ D2 ] ...
    """

    def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair:
        """Return a fresh pair whose memory matches the inputs one-to-one."""
        src_group = src_regions.memory
        dst_group = dst_regions.memory
        assert len(src_group.ptrs) == len(dst_group.ptrs), (
            f"Number of regions of src({len(src_group.ptrs)}) and dst({len(dst_group.ptrs)}) must match"
        )
        # Pointer lists are copied so callers cannot mutate the originals
        # through the returned pair.
        cloned_src = SpecRegion(
            memory=MemRegionGroup(
                ptrs=list(src_group.ptrs), bytes_per_region=src_group.bytes_per_region
            ),
            spec=src_regions.spec,
        )
        cloned_dst = SpecRegion(
            memory=MemRegionGroup(
                ptrs=list(dst_group.ptrs), bytes_per_region=dst_group.bytes_per_region
            ),
            spec=dst_regions.spec,
        )
        return SpecRegionPair(src=cloned_src, dst=cloned_dst)
class HeadMatchMapper(RegionMapperBase):
    """
    ---- mapper_head_match ----
    Move/copy entire contiguous block(s) (multi-layer fragment) as a single chunk.
    Align by whole fragment size (frag_size) and apply a constant source/destination block offset.
    src_ptrs: [ S0 ]            [ S1 ]            ...
                |                 |
            + src_off          + src_off
                |                 |
    [ S0 + src_off ]   [ S1 + src_off ]  -> (each points to a frag of size frag_size)
          copy whole frag
                |                 |
                v                 v
    [ D0 + dst_off ]   [ D1 + dst_off ]  -> (destination frags)
    """

    def __init__(
        self,
        transfer_layers: int,
        src_layer_off: int,
        dst_layer_off: int,
        self_ri: RankInfo,
        peer_ri: RankInfo,
    ):
        self._kv_factor = self_ri.kv_factor
        # Whole-fragment size plus the constant byte offsets into the source
        # and destination blocks, all derived from layer counts.
        self._frag_size = self._bytes_for_layers(transfer_layers, self_ri)
        self._src_block_off = self._bytes_for_layers(src_layer_off, self_ri)
        self._dst_block_off = self._bytes_for_layers(dst_layer_off, peer_ri)

    @staticmethod
    def _bytes_for_layers(layer_num: int, ri: RankInfo) -> int:
        """Bytes occupied by `layer_num` layers within one cache block of `ri`."""
        per_layer = (
            ri.kv_factor
            * ri.kv_heads_per_rank
            * ri.tokens_per_block
            * ri.dims_per_head
            * ri.element_bytes
        )
        return layer_num * per_layer

    def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair:
        """Shift every block pointer by its side's offset; fragments share one size."""
        src_group = src_regions.memory
        dst_group = dst_regions.memory
        assert len(src_group.ptrs) == len(dst_group.ptrs), (
            f"Number of regions of src({len(src_group.ptrs)}) and dst({len(dst_group.ptrs)}) must match"
        )
        shifted_src = [base + self._src_block_off for base in src_group.ptrs]
        shifted_dst = [base + self._dst_block_off for base in dst_group.ptrs]
        return SpecRegionPair(
            src=SpecRegion(
                memory=MemRegionGroup(ptrs=shifted_src, bytes_per_region=self._frag_size),
                spec=src_regions.spec,
            ),
            dst=SpecRegion(
                memory=MemRegionGroup(ptrs=shifted_dst, bytes_per_region=self._frag_size),
                spec=dst_regions.spec,
            ),
        )
class HeadMismatchMapper(RegionMapperBase):
    """
    ---- mapper_head_mismatch ----
    Fine-grained mapping when head counts or TP/DP partitioning differ.
    Split layers into per-head (or contiguous-heads) fragments and map them individually.
    Handles kv_factor (e.g., key+value duplication) and TP/DP head offsets.
    Source (layers x heads):
        L0: [S00 S01] [S02 S03] ...
        L1: [S10 S11] [S12 S13] ...
    Destination (layers x heads, different layout possible):
        L0': [D00] [D01] [D02] ...
        L1': [D10] [D11] ...
    Mapping (each arrow = copy cont_heads_frag):
        [S00 S01] -> [D00]
        [S02 S03] -> [D01]
        [S10 S11] -> [D02]
    """

    def __init__(
        self,
        transfer_layers: int,
        src_layer_off: int,
        peer_layer_off: int,
        self_ri: RankInfo,
        peer_ri: RankInfo,
    ):
        self._ri = self_ri
        self._peer_ri = peer_ri
        self._src_layer_off = src_layer_off
        kv_factor = self_ri.kv_factor
        # NOTE(review): tp-per-dp is computed unconditionally here, while
        # PeerRegistrar._tp_per_dp divides by dp_size only when
        # enable_attention_dp is set — confirm the two are meant to differ.
        self_tp_per_dp = self_ri.tp_size // self_ri.dp_size
        peer_tp_per_dp = peer_ri.tp_size // peer_ri.dp_size
        self_tp_rank = self_ri.tp_rank
        peer_tp_rank = peer_ri.tp_rank
        # Bytes held by a single head within one block.
        bytes_per_head = self._ri.tokens_per_block * self._ri.dims_per_head * self._ri.element_bytes
        # Fragment size: the smaller side's per-rank head count, in bytes.
        self._bytes_cont_heads = (
            min(self._ri.kv_heads_per_rank, peer_ri.kv_heads_per_rank) * bytes_per_head
        )
        self._src_head_off, self._dst_head_off = self._compute_head_offsets(
            self_tp_per_dp,
            peer_tp_per_dp,
            self_tp_rank,
            peer_tp_rank,
            self._bytes_cont_heads,
        )
        # Index vectors reused by map() to enumerate (layer, kv) fragments.
        self._layer_indices = np.arange(transfer_layers, dtype=np.int64)
        self._kv_indices = np.arange(kv_factor, dtype=np.int64)
        self._peer_layer_off = peer_layer_off

    def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair:
        """Expand each block into per-(layer, kv) fragment pointers on both sides.

        The flattened source and destination pointer lists correspond
        pairwise; every fragment is `self._bytes_cont_heads` bytes.
        """
        src_group = src_regions.memory
        dst_group = dst_regions.memory
        assert len(src_group.ptrs) == len(dst_group.ptrs), (
            f"Number of regions of src({len(src_group.ptrs)}) and dst({len(dst_group.ptrs)}) must match"
        )
        src_bases = np.array(src_group.ptrs, dtype=np.int64)
        dst_bases = np.array(dst_group.ptrs, dtype=np.int64)
        src_frags = self._get_frags(
            bases=src_bases,
            layer_indices=self._src_layer_off + self._layer_indices,
            layer_kv_num=self._get_layer_kv_num(self._ri),
            kv_indices=self._kv_indices,
            head_off=self._src_head_off,
            kv_factor=self._kv_indices.size,
        )
        dst_frags = self._get_frags(
            bases=dst_bases,
            layer_indices=self._peer_layer_off + self._layer_indices,
            layer_kv_num=self._get_layer_kv_num(self._peer_ri),
            kv_indices=self._kv_indices,
            head_off=self._dst_head_off,
            kv_factor=self._kv_indices.size,
        )
        # Flatten in matching (block, layer, kv) order so src/dst pair up.
        all_src_ptrs = [int(x) for x in src_frags.flatten()]
        all_dst_ptrs = [int(x) for x in dst_frags.flatten()]
        new_src = MemRegionGroup(ptrs=all_src_ptrs, bytes_per_region=self._bytes_cont_heads)
        new_dst = MemRegionGroup(ptrs=all_dst_ptrs, bytes_per_region=self._bytes_cont_heads)
        return SpecRegionPair(
            src=SpecRegion(memory=new_src, spec=src_regions.spec),
            dst=SpecRegion(memory=new_dst, spec=dst_regions.spec),
        )

    @staticmethod
    def _compute_head_offsets(
        self_tp_per_dp: int,
        peer_tp_per_dp: int,
        self_tp_rank: int,
        peer_tp_rank: int,
        bytes_cont_heads: int,
    ) -> tuple[int, int]:
        """Byte offsets selecting which contiguous head-fragment of the larger
        side corresponds to the smaller side's rank; (0, 0) when equal."""
        if self_tp_per_dp == peer_tp_per_dp:
            return 0, 0
        ratio = max(self_tp_per_dp, peer_tp_per_dp) // min(self_tp_per_dp, peer_tp_per_dp)
        if self_tp_per_dp < peer_tp_per_dp:
            # Local side holds more heads per rank: offset into the source.
            return (peer_tp_rank % ratio) * bytes_cont_heads, 0
        else:
            # Peer side holds more heads per rank: offset into the destination.
            return 0, (self_tp_rank % ratio) * bytes_cont_heads

    @staticmethod
    def _get_layer_kv_num(ri: RankInfo) -> int:
        # Bytes of one (layer, kv) stripe: all heads x tokens x dims x elem size.
        return ri.kv_heads_per_rank * ri.tokens_per_block * ri.dims_per_head * ri.element_bytes

    @staticmethod
    def _get_frags(bases, layer_indices, layer_kv_num, kv_indices, head_off, kv_factor):
        """Broadcast base pointers into a (blocks, layers, kv) pointer grid."""
        layer_num = layer_kv_num * kv_factor
        return (
            bases[:, None, None]
            + layer_num * layer_indices[None, :, None]
            + layer_kv_num * kv_indices[None, None, :]
            + head_off
        )

View File

@ -0,0 +1,126 @@
from dataclasses import dataclass
from typing import List, Set
import numpy as np
# Structured dtype for one buffer entry inside a pool: identifies the
# (layer_id, role) the buffer belongs to and where it lives within a slot.
BUFFER_ENTRY_DTYPE = np.dtype(
    [
        ("layer_id", np.uint32),
        ("role", np.uint32),
        ("offset", np.uint32),  # byte offset of the buffer within its slot
        ("size", np.uint32),  # buffer size in bytes
    ]
)
@dataclass
class PoolDescriptor:
    """
    Pool descriptor containing memory layout and buffer information.

    One pool contains multiple buffer entries, each representing a
    (layer_id, role) combination that lives at a fixed byte offset inside
    every slot of the pool.
    """

    base_address: int  # device address of slot 0 (uint64 range)
    slot_bytes: int  # size of one slot in bytes
    num_slots: int  # number of slots in the pool
    # Buffer entries: flattened array of all (layer_id, role, offset, size) in this pool
    buffer_entries: np.ndarray  # dtype=BUFFER_ENTRY_DTYPE

    @property
    def pool_bytes(self) -> int:
        """Total size of the pool in bytes."""
        return self.slot_bytes * self.num_slots

    @property
    def unique_layers(self) -> Set[int]:
        """Distinct layer ids that have a buffer in this pool."""
        return set(int(entry["layer_id"]) for entry in self.buffer_entries)

    @property
    def unique_roles(self) -> Set[int]:
        """Distinct role enums that have a buffer in this pool."""
        return set(int(entry["role"]) for entry in self.buffer_entries)

    def _check_slot(self, slot_id: int) -> None:
        """Raise ValueError unless 0 <= slot_id < num_slots.

        Fixed: negative slot ids previously passed the `>= num_slots` check
        and silently produced addresses below the pool base.
        """
        if slot_id >= self.num_slots:
            raise ValueError(f"slot_id {slot_id} >= num_slots {self.num_slots}")
        if slot_id < 0:
            raise ValueError(f"slot_id {slot_id} must be >= 0")

    def get_slot_address(self, slot_id: int) -> int:
        """Device address of the start of slot `slot_id`.

        Raises:
            ValueError: if `slot_id` is out of range.
        """
        self._check_slot(slot_id)
        return self.base_address + slot_id * self.slot_bytes

    def get_device_pointer(self, slot_id: int, layer_id: int, role_enum: int) -> int:
        """Device address of the (layer_id, role_enum) buffer inside slot `slot_id`.

        Raises:
            ValueError: if the slot id is out of range or no matching buffer exists.
        """
        self._check_slot(slot_id)
        # Linear scan: entry counts per pool are expected to be small.
        for entry in self.buffer_entries:
            if entry["layer_id"] == layer_id and entry["role"] == role_enum:
                slot_base = self.base_address + slot_id * self.slot_bytes
                return slot_base + int(entry["offset"])
        raise ValueError(f"Buffer not found: layer_id={layer_id}, role_enum={role_enum}")

    def __repr__(self) -> str:
        # Fixed: the closing parenthesis was missing from the original format.
        return (
            f"PoolDescriptor(base=0x{self.base_address:x}, "
            f"slot_bytes={self.slot_bytes}, num_slots={self.num_slots}, "
            f"layers={len(self.unique_layers)}, roles={len(self.unique_roles)})"
        )
@dataclass
class KVCachePageTable:
    """
    Multi-dimensional KV cache page table.

    Structure:
        KVCachePageTable
            PoolGroups (List[List[PoolDescriptor]])
                PoolGroup 0: List of PoolDescriptors
                    Pool 0: PoolDescriptor
                    Pool 1: PoolDescriptor
                PoolGroup 1: List of PoolDescriptors
                    Pool 0: PoolDescriptor
                    Pool 1: PoolDescriptor

    Relationships:
    - pools[pg_idx] = List[PoolDescriptor] (all pools in same PoolGroup)
    - All pools in pools[pg_idx] share the same lifecycle
    """

    tokens_per_block: int
    num_layers: int
    pools: List[List[PoolDescriptor]]  # pools[pg_idx][pool_idx] → PoolDescriptor

    @property
    def num_pool_groups(self) -> int:
        """Number of pool groups."""
        return len(self.pools)

    @property
    def total_pools(self) -> int:
        """Total pool count across all pool groups."""
        return sum(len(pg_pools) for pg_pools in self.pools)

    @property
    def total_buffer_entries(self) -> int:
        """Total buffer-entry count across all pools.

        Fixed: PoolDescriptor defines no `num_buffer_entries` attribute (the
        original raised AttributeError); the count is the length of its
        `buffer_entries` array.
        """
        return sum(len(pool.buffer_entries) for pg_pools in self.pools for pool in pg_pools)

    @property
    def total_pool_bytes(self) -> int:
        """Total bytes across all pools."""
        return sum(pool.pool_bytes for pg_pools in self.pools for pool in pg_pools)

    @property
    def total_slots(self) -> int:
        """Total slot count across all pools."""
        return sum(pool.num_slots for pg_pools in self.pools for pool in pg_pools)

    def get_pool(self, pg_idx: int, pool_idx: int) -> PoolDescriptor:
        """Return the descriptor for pool `pool_idx` of pool group `pg_idx`."""
        return self.pools[pg_idx][pool_idx]

    def get_device_pointer(
        self, pg_idx: int, pool_idx: int, slot_id: int, layer_id: int, role: str
    ) -> int:
        """Resolve a device pointer for (pool group, pool, slot, layer, role).

        An integer `role` is treated as an already-resolved role enum and
        passed through unchanged (backward compatible: the int path previously
        could not work at all).

        NOTE(review): `role_to_enum` is not defined on this class or in this
        module, so a string `role` raises AttributeError — confirm where the
        str -> enum mapping is supposed to live.
        """
        pool = self.pools[pg_idx][pool_idx]
        role_enum = role if isinstance(role, int) else self.role_to_enum(role)
        return pool.get_device_pointer(slot_id, layer_id, role_enum)

    def __repr__(self) -> str:
        return (
            f"KVCachePageTable(poolgroups={self.num_pool_groups}, "
            f"pools={self.total_pools}, layers={self.num_layers})"
        )

View File

@ -0,0 +1,89 @@
from dataclasses import dataclass
from typing import List
from tensorrt_llm._torch.disaggregation.base.region import (
DataLayout,
MemRegionGroup,
RegionExtractorBase,
SpecRegion,
)
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._utils import get_size_in_bytes
@dataclass
class KVPoolAttrs:
    """Attributes for a single (primary) KV memory pool."""

    # Base pointer of each pool.
    pool_ptrs: List[int]
    # Size in bytes of one cache block in each pool.
    block_bytes: List[int]
class KVRegionExtractorV1(RegionExtractorBase):
    """
    Descriptor and region extractor for the KV cache pool managed by
    KVCacheManager. Provides region descriptors for adapting a block-wise view.
    """

    def __init__(self, kv_arg: KVCacheManager | KVPoolAttrs):
        """
        Args:
            kv_arg: either pre-computed pool attributes, or a KVCacheManager
                from which the attributes are derived.

        Raises:
            TypeError: if `kv_arg` is neither type.
        """
        if isinstance(kv_arg, KVPoolAttrs):
            self._kv_pool_attrs = kv_arg
        elif isinstance(kv_arg, KVCacheManager):
            self._kv_pool_attrs = self._attrs_from_manager(kv_arg)
        else:
            raise TypeError(
                f"kv_cache_manager must be KVCacheManager or KVPoolAttrs, got {type(kv_arg)}"
            )
        self._data_layout = DataLayout.HND

    @staticmethod
    def _attrs_from_manager(manager: KVCacheManager) -> KVPoolAttrs:
        """Derive pool pointers and per-block byte sizes from a KVCacheManager.

        Raises:
            ValueError: if pools or their pointers/sizes cannot be resolved.
        """
        try:
            pools = manager.get_unique_primary_pool()
        except Exception as ex:
            # Fixed: chain the original exception (`from ex`) instead of
            # discarding the traceback context.
            raise ValueError(f"Failed to get pool(s): {ex}") from ex
        pool_list = list(pools) if isinstance(pools, (list, tuple)) else [pools]
        elem_bytes = get_size_in_bytes(1, manager.dtype)
        ptrs, block_sizes = [], []
        for p in pool_list:
            # A pool may be a tensor-like object exposing data_ptr(), or a raw
            # integer address.
            if hasattr(p, "data_ptr") and callable(p.data_ptr):
                try:
                    ptr = int(p.data_ptr())
                except Exception as ex:
                    raise ValueError(f"Fail to call data_ptr(): {ex}") from ex
            elif isinstance(p, int):
                ptr = int(p)
            else:
                raise ValueError(f"Pool object lacks 'data_ptr' and is not int: {p!r}")
            ptrs.append(ptr)
            try:
                # Prefer the per-block element count (first slice) when the
                # pool is indexable; fall back to the whole pool's numel().
                if hasattr(p, "__getitem__") and hasattr(p[0], "numel"):
                    n = int(p[0].numel())
                elif hasattr(p, "numel") and callable(p.numel):
                    n = int(p.numel())
                else:
                    raise RuntimeError("Cannot determine element count")
            except Exception as ex:
                raise ValueError(f"Failed to get block size from {p!r}: {ex}") from ex
            block_sizes.append(n * elem_bytes)
        return KVPoolAttrs(pool_ptrs=ptrs, block_bytes=block_sizes)

    def extract(self, region_ids: List[int]) -> SpecRegion:
        """
        Given a list of region_ids, returns a single SpecRegion,
        whose memory is a MemRegionGroup containing all blocks described
        by region_ids.

        NOTE(review): this narrows RegionExtractorBase.extract — region_ids is
        required here and a single SpecRegion (not a list) is returned.
        Confirm the intended interface.
        """
        # Only a single primary pool is currently supported.
        assert len(self._kv_pool_attrs.pool_ptrs) == 1
        pool_idx = 0
        attrs = self._kv_pool_attrs
        # Block addresses are computed as base + block_id * block_bytes.
        ptrs = [
            attrs.pool_ptrs[pool_idx] + block_id * attrs.block_bytes[0] for block_id in region_ids
        ]
        memory = MemRegionGroup(ptrs=ptrs, bytes_per_region=attrs.block_bytes[0])
        return SpecRegion(memory=memory)

View File

@ -0,0 +1,77 @@
from collections import defaultdict
import numpy as np
from tensorrt_llm._torch.disaggregation.native.region.page import (
BUFFER_ENTRY_DTYPE,
KVCachePageTable,
PoolDescriptor,
)
from tensorrt_llm.runtime.kv_cache_manager_v2 import CacheTier, KVCacheManager
def build_page_table(manager: KVCacheManager) -> KVCachePageTable:
    """Build a KVCachePageTable snapshot from a v2 KVCacheManager.

    Collects, for every pool of the GPU tier, its base address, slot geometry,
    and the buffer records (layer_id, role, offset, size) hosted inside it.

    NOTE(review): reaches into private manager state (_storage, _init_config,
    _levels, _pool_groups, _pools, _buffer_attr); any change to those internals
    breaks this function — confirm the coupling is intended.
    """
    storage = manager._storage
    config = manager._init_config
    # Locate the GPU tier index; silently defaults to level 0 if no
    # GPU_MEM tier is configured.
    gpu_level = 0
    for level_idx, cache_tier_config in enumerate(config.cache_tiers):
        if cache_tier_config.tier == CacheTier.GPU_MEM:
            gpu_level = level_idx
            break
    # Group buffer records by the (pool_group, pool) pair that hosts them.
    buffer_by_pool = defaultdict(list)
    lc_to_pg_cache = {}  # life_cycle_id -> pool-group index (memoized lookup)
    for buffer_id, attr in storage._buffer_attr.items():
        layer_id, role = buffer_id
        lc_id = attr.life_cycle_id
        if lc_id not in lc_to_pg_cache:
            lc_to_pg_cache[lc_id] = storage.get_pool_group_index(lc_id)
        pg_idx = lc_to_pg_cache[lc_id]
        pool_idx = attr.pool_index
        pool_key = (pg_idx, pool_idx)
        buffer_by_pool[pool_key].append((layer_id, role, attr.offset, attr.size))
    # One PoolDescriptor per pool, nested per pool group.
    pools = []
    num_pool_groups = storage.num_pool_groups
    pool_group_storage = storage._levels[gpu_level].storage._pool_groups
    for pg_idx in range(num_pool_groups):
        pool_group = pool_group_storage[pg_idx]
        num_pools = pool_group.num_pools
        pg_pools = []
        for pool_idx in range(num_pools):
            pool = pool_group._pools[pool_idx]
            base_address = int(pool.slot_address(0))  # address of slot 0 == pool base
            slot_bytes = int(pool.slot_size)
            num_slots = int(pool.num_slots)
            pool_key = (pg_idx, pool_idx)
            buffers_info = buffer_by_pool.get(pool_key, [])
            # Structured array of (layer_id, role, offset, size) records;
            # pools with no buffers get a zero-length array of the same dtype.
            if buffers_info:
                buffer_entries = np.array(buffers_info, dtype=BUFFER_ENTRY_DTYPE)
            else:
                buffer_entries = np.array([], dtype=BUFFER_ENTRY_DTYPE)
            pool_desc = PoolDescriptor(
                base_address=base_address,
                slot_bytes=slot_bytes,
                num_slots=num_slots,
                buffer_entries=buffer_entries,
            )
            pg_pools.append(pool_desc)
        pools.append(pg_pools)
    return KVCachePageTable(
        tokens_per_block=config.tokens_per_block,
        num_layers=len(config.layers),
        pools=pools,
    )

View File

@ -38,6 +38,10 @@ l0_a10:
- unittest/disaggregated/test_messenger.py
- unittest/disaggregated/test_disagg_cluster_manager_worker.py
- unittest/disaggregated/test_cluster_storage.py
- unittest/disaggregated/test_extractor.py
- unittest/disaggregated/test_extractor_v2.py
- unittest/disaggregated/test_peer.py
- unittest/disaggregated/region/test_block.py
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]

View File

@ -39,6 +39,10 @@ l0_h100:
- unittest/disaggregated/test_remoteDictionary.py
- unittest/disaggregated/test_agent_multi_backends.py
- unittest/disaggregated/test_messenger.py
- unittest/disaggregated/test_extractor.py
- unittest/disaggregated/test_extractor_v2.py
- unittest/disaggregated/test_peer.py
- unittest/disaggregated/region/test_block.py
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse

View File

@ -0,0 +1,146 @@
from tensorrt_llm._torch.disaggregation.base.region import (
MemRegionGroup,
SpecRegion,
SpecRegionPair,
)
from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo
from tensorrt_llm._torch.disaggregation.native.region.block import (
HeadMatchMapper,
HeadMismatchMapper,
IdentityMapper,
)
def make_rankinfo(
    kv_heads_per_rank=2,
    tokens_per_block=4,
    dims_per_head=2,
    element_bytes=1,
    tp_size=2,
    tp_rank=0,
    dp_size=1,
    dp_rank=0,
    pp_size=1,
    pp_rank=0,
    cp_size=1,
    cp_rank=0,
    is_mla=False,
):
    """Build a RankInfo test fixture with small defaults; endpoints and
    pointer lists are left empty since the mapper tests never use them."""
    return RankInfo(
        instance_name="rank",
        instance_rank=0,
        tp_size=tp_size,
        tp_rank=tp_rank,
        dp_size=dp_size,
        dp_rank=dp_rank,
        pp_size=pp_size,
        pp_rank=pp_rank,
        cp_size=cp_size,
        cp_rank=cp_rank,
        device_id=0,
        kv_heads_per_rank=kv_heads_per_rank,
        tokens_per_block=tokens_per_block,
        dims_per_head=dims_per_head,
        element_bytes=element_bytes,
        enable_attention_dp=False,
        is_mla=is_mla,
        layer_num_per_pp=[1],
        kv_ptrs=[],
        aux_ptrs=[],
        server_endpoint="",
        self_endpoint="",
        transfer_engine_info=b"",
        aux_meta=None,
    )
def test_mem_region_group():
    """MemRegionGroup stores its pointer list and per-region size unchanged."""
    group = MemRegionGroup(ptrs=[11, 22, 33], bytes_per_region=16)
    assert list(group.ptrs) == [11, 22, 33]
    assert group.bytes_per_region == 16
def test_spec_region_and_spec_region_pair():
    """SpecRegion wraps a MemRegionGroup + spec; SpecRegionPair holds src/dst."""
    spec_src = SpecRegion(
        memory=MemRegionGroup(ptrs=[101, 202], bytes_per_region=8), spec="spec_src"
    )
    spec_dst = SpecRegion(
        memory=MemRegionGroup(ptrs=[303, 404], bytes_per_region=8), spec="spec_dst"
    )
    assert isinstance(spec_src, SpecRegion) and isinstance(spec_dst, SpecRegion)
    pair = SpecRegionPair(src=spec_src, dst=spec_dst)
    assert isinstance(pair, SpecRegionPair)
    # Both sides are reachable through the pair, unmodified.
    assert pair.src.memory.ptrs == [101, 202]
    assert pair.dst.memory.ptrs == [303, 404]
    assert pair.src.spec == "spec_src"
    assert pair.dst.spec == "spec_dst"
def test_identity_mapper():
    """IdentityMapper returns the src/dst regions untouched, as a pair."""
    src = SpecRegion(memory=MemRegionGroup(ptrs=[100, 200], bytes_per_region=32), spec="a")
    dst = SpecRegion(memory=MemRegionGroup(ptrs=[300, 400], bytes_per_region=32), spec="b")
    pair = IdentityMapper().map(src, dst)
    assert isinstance(pair, SpecRegionPair)
    assert list(pair.src.memory.ptrs) == [100, 200]
    assert list(pair.dst.memory.ptrs) == [300, 400]
    assert pair.src.memory.bytes_per_region == 32
    assert pair.dst.memory.bytes_per_region == 32
def test_head_match_mapper():
    """With equal heads per rank on both sides, HeadMatchMapper shifts every
    block pointer by a fixed layer offset and sizes each region to cover the
    transferred layers."""
    self_ri = make_rankinfo(kv_heads_per_rank=2)
    peer_ri = make_rankinfo(kv_heads_per_rank=2)
    transfer_layers = 2
    src_layer_off = 1
    dst_layer_off = 1
    src_group = MemRegionGroup(ptrs=[10, 20], bytes_per_region=1)
    dst_group = MemRegionGroup(ptrs=[30, 40], bytes_per_region=1)
    src_spec = SpecRegion(memory=src_group, spec="srcspec")
    dst_spec = SpecRegion(memory=dst_group, spec="dstspec")
    mapper = HeadMatchMapper(transfer_layers, src_layer_off, dst_layer_off, self_ri, peer_ri)
    result = mapper.map(src_spec, dst_spec)
    # Bytes of `transfer_layers` full layers: kv_factor * heads * tokens * dim * elem_bytes.
    expected_off = (
        transfer_layers
        * mapper._kv_factor
        * self_ri.kv_heads_per_rank
        * self_ri.tokens_per_block
        * self_ri.dims_per_head
        * self_ri.element_bytes
    )
    assert list(result.src.memory.ptrs) == [10 + mapper._src_block_off, 20 + mapper._src_block_off]
    assert list(result.dst.memory.ptrs) == [30 + mapper._dst_block_off, 40 + mapper._dst_block_off]
    assert result.src.memory.bytes_per_region == expected_off
    assert result.dst.memory.bytes_per_region == expected_off
def test_head_mismatch_mapper():
    """With different heads per rank (tp 2 vs tp 4), HeadMismatchMapper splits
    each block into per-(kv_factor, layer) fragments of contiguous heads."""
    self_ri = make_rankinfo(kv_heads_per_rank=2, tp_size=2, tp_rank=1)
    peer_ri = make_rankinfo(kv_heads_per_rank=4, tp_size=4, tp_rank=2)
    transfer_layers = 1
    src_layer_off = 0
    peer_layer_off = 1
    src_group = MemRegionGroup(ptrs=[111], bytes_per_region=32)
    dst_group = MemRegionGroup(ptrs=[222], bytes_per_region=32)
    src_spec = SpecRegion(memory=src_group, spec="srcspec")
    dst_spec = SpecRegion(memory=dst_group, spec="dstspec")
    mapper = HeadMismatchMapper(transfer_layers, src_layer_off, peer_layer_off, self_ri, peer_ri)
    result = mapper.map(src_spec, dst_spec)
    # One fragment per (K/V role, transferred layer).
    expected_frag_count = self_ri.kv_factor * transfer_layers
    assert isinstance(result, SpecRegionPair)
    assert len(result.src.memory.ptrs) == expected_frag_count
    assert len(result.dst.memory.ptrs) == expected_frag_count
    assert all(isinstance(x, int) for x in result.src.memory.ptrs)
    assert all(isinstance(x, int) for x in result.dst.memory.ptrs)
    assert result.src.memory.bytes_per_region == mapper._bytes_cont_heads
    assert result.dst.memory.bytes_per_region == mapper._bytes_cont_heads
def test_rankinfo_kv_factor():
    """kv_factor is 2 for standard attention and 1 when MLA is enabled."""
    assert make_rankinfo(is_mla=False).kv_factor == 2
    assert make_rankinfo(is_mla=True).kv_factor == 1

View File

@ -0,0 +1,101 @@
import pytest
from tensorrt_llm._torch.disaggregation.base.region import MemRegionGroup, SpecRegion
from tensorrt_llm._torch.disaggregation.resource.kv_extractor import KVRegionExtractorV1
from tensorrt_llm._torch.pyexecutor.resource_manager import (
CacheTypeCpp,
DataType,
KvCacheConfig,
KVCacheManager,
Mapping,
)
class DummyRankInfo:
    """Minimal stand-in for RankInfo: single-rank topology, 8 KV heads,
    32 tokens/block, fp16-sized elements. All values are class attributes."""

    instance_name = "dummy"
    instance_rank = 0
    tp_size = 1
    tp_rank = 0
    pp_size = 1
    pp_rank = 0
    dp_size = 1
    dp_rank = 0
    cp_size = 1
    cp_rank = 0
    device_id = 0
    kv_heads_per_rank = 8
    tokens_per_block = 32
    dims_per_head = 16
    element_bytes = 2
    enable_attention_dp = False
    is_mla = False
    layer_num_per_pp = [1]

    @property
    def kv_factor(self) -> int:
        # 2 when K and V are stored separately; 1 for MLA's shared cache.
        return 2 if not self.is_mla else 1
@pytest.mark.cuda
def test_extract():
    """End-to-end: KVRegionExtractorV1 over a real (GPU) KVCacheManager must
    return one SpecRegion whose pointers are base + block_id * block_bytes."""
    num_layers = 1
    num_kv_heads = 8
    head_dim = 16
    tokens_per_block = 32
    max_seq_len = 128
    max_batch_size = 2
    dtype = DataType.HALF
    mapping = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1, gpus_per_node=1)
    # Reuse and offloading disabled so the pool layout stays trivial.
    kv_cache_config = KvCacheConfig(
        max_tokens=512,
        free_gpu_memory_fraction=0.1,
        max_attention_window=None,
        enable_block_reuse=False,
        event_buffer_max_size=0,
        onboard_blocks=0,
        host_cache_size=0,
        enable_partial_reuse=False,
        copy_on_partial_reuse=False,
        sink_token_length=0,
        max_util_for_resume=1,
    )
    kv_cache_type = CacheTypeCpp.SELF
    manager = KVCacheManager(
        kv_cache_config=kv_cache_config,
        kv_cache_type=kv_cache_type,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        mapping=mapping,
        dtype=dtype,
    )
    extractor = KVRegionExtractorV1(manager)
    region_ids = [0, 1]
    spec_region = extractor.extract(region_ids)
    assert isinstance(spec_region, SpecRegion)
    memory = spec_region.memory
    assert isinstance(memory, MemRegionGroup)
    assert len(memory.ptrs) == len(region_ids)
    assert memory.bytes_per_region > 0
    # Resolve the pool base pointer through whichever shape the manager
    # returns: a sequence of tensors/ints, a tensor, or a raw int.
    pool_ptrs = manager.get_unique_primary_pool()
    if hasattr(pool_ptrs, "__getitem__"):
        if hasattr(pool_ptrs[0], "data_ptr"):
            pool_base_ptr = int(pool_ptrs[0].data_ptr())
        else:
            pool_base_ptr = int(pool_ptrs[0])
    else:
        pool_base_ptr = (
            int(pool_ptrs.data_ptr()) if hasattr(pool_ptrs, "data_ptr") else int(pool_ptrs)
        )
    expected_block_bytes = memory.bytes_per_region
    expected_ptrs = [pool_base_ptr + block_id * expected_block_bytes for block_id in region_ids]
    assert list(memory.ptrs) == expected_ptrs
    manager.shutdown()

View File

@ -0,0 +1,78 @@
import pytest
from tensorrt_llm._torch.disaggregation.resource.kv_extractor_v2 import build_page_table
from tensorrt_llm.runtime.kv_cache_manager_v2 import (
AttentionLayerConfig,
BufferConfig,
GpuCacheTierConfig,
KVCacheManager,
KVCacheManagerConfig,
)
@pytest.fixture
def simple_manager():
    """Fixture: a v2 KVCacheManager with 2 attention layers (K+V buffers of
    8 KiB each), 64 tokens/block, and a single 100 MB GPU cache tier."""
    layers = [
        AttentionLayerConfig(
            layer_id=0,
            sliding_window_size=None,
            num_sink_tokens=0,
            buffers=[
                BufferConfig(role=0, size=8192),
                BufferConfig(role=1, size=8192),
            ],
        ),
        AttentionLayerConfig(
            layer_id=1,
            sliding_window_size=None,
            num_sink_tokens=0,
            buffers=[
                BufferConfig(role=0, size=8192),
                BufferConfig(role=1, size=8192),
            ],
        ),
    ]
    cache_tiers = [
        GpuCacheTierConfig(
            quota=100 * 1024 * 1024,  # 100MB
        ),
    ]
    config = KVCacheManagerConfig(
        tokens_per_block=64,
        layers=layers,
        vocab_size=50257,
        cache_tiers=cache_tiers,
    )
    return KVCacheManager(config)
def test_build_page_table(simple_manager):
    """build_page_table over the simple_manager fixture yields a consistent
    page table: config values carried over, non-empty pools with coherent
    geometry and at least one buffer entry."""
    page_table = build_page_table(simple_manager)
    # Check basic properties
    assert page_table.tokens_per_block == 64
    assert page_table.num_layers == 2
    assert page_table.num_pool_groups >= 1
    assert page_table.total_pools > 0
    # Check pools are created
    assert len(page_table.pools) > 0
    assert all(len(pg_pools) > 0 for pg_pools in page_table.pools)
    # Check first pool has valid properties
    pool = page_table.pools[0][0]
    assert pool.base_address > 0
    assert pool.slot_bytes > 0
    assert pool.num_slots > 0
    assert pool.pool_bytes == pool.slot_bytes * pool.num_slots
    # Check buffer entries exist
    assert len(pool.buffer_entries) > 0
    print(f"\n Page table: {page_table}")
    print(f" Total pools: {page_table.total_pools}")
    print(f" Pools: {page_table.pools}")
    print(f" Total size: {page_table.total_pool_bytes / (1024**2):.2f} MB")

View File

@ -0,0 +1,297 @@
import pytest
from tensorrt_llm._torch.disaggregation.native.peer import PeerOverlap, PeerRegistrar, RankInfo
from tensorrt_llm._torch.disaggregation.native.region.block import (
HeadMatchMapper,
HeadMismatchMapper,
IdentityMapper,
)
from tensorrt_llm._torch.disaggregation.resource.kv_extractor import (
KVPoolAttrs,
KVRegionExtractorV1,
)
def make_rankinfo(
    instance_name="self",
    instance_rank=0,
    tp_size=2,
    tp_rank=0,
    pp_size=1,
    pp_rank=0,
    dp_size=1,
    dp_rank=0,
    cp_size=1,
    cp_rank=0,
    kv_heads_per_rank=2,
    tokens_per_block=16,
    dims_per_head=8,
    element_bytes=2,
    is_mla=False,
    enable_attention_dp=False,
    layer_num_per_pp=None,
):
    """Build a RankInfo test fixture for the peer-registrar tests; defaults
    to 2 layers per PP rank when layer_num_per_pp is not given."""
    if layer_num_per_pp is None:
        layer_num_per_pp = [2] * pp_size
    return RankInfo(
        instance_name=instance_name,
        instance_rank=instance_rank,
        tp_size=tp_size,
        tp_rank=tp_rank,
        pp_size=pp_size,
        pp_rank=pp_rank,
        dp_size=dp_size,
        dp_rank=dp_rank,
        cp_size=cp_size,
        cp_rank=cp_rank,
        device_id=0,
        kv_heads_per_rank=kv_heads_per_rank,
        tokens_per_block=tokens_per_block,
        dims_per_head=dims_per_head,
        element_bytes=element_bytes,
        enable_attention_dp=enable_attention_dp,
        is_mla=is_mla,
        layer_num_per_pp=layer_num_per_pp,
        kv_ptrs=[],
        aux_ptrs=[],
        server_endpoint="",
        self_endpoint="",
        transfer_engine_info=b"",
        aux_meta=None,
    )
def _make_peer_registrar_and_peer_ri(self_ri, peer_ri):
    """Build a PeerRegistrar around self_ri (dummy KV extractor) and pass peer_ri through."""
    extractor = KVRegionExtractorV1(KVPoolAttrs(pool_ptrs=[1234], block_bytes=[1024]))
    return PeerRegistrar(self_ri, extractor), peer_ri
def test_basic_overlap():
    """Self pp=1 with 2 layers vs peer pp=2 with 1 layer each: the overlap
    spans both peer PP stages and covers peer ranks 0 and 1."""
    self_ri = make_rankinfo(
        "self",
        pp_size=1,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[2],
    )
    peer_ri = make_rankinfo(
        "peer",
        pp_size=2,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[1, 1],
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, peer_dp_rank=0)
    assert overlap.overlap_pp_size == 2
    assert overlap.target_peer_pp_layer_num == [1, 1]
    assert overlap.overlap_tp_size == 1
    assert overlap.ranks == [0, 1]
def test_no_overlap():
    """Self rank holding zero layers overlaps with no peer rank."""
    self_ri = make_rankinfo(
        "self",
        pp_size=1,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[0],
    )
    peer_ri = make_rankinfo(
        "peer",
        pp_size=1,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[2],
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, 0)
    assert overlap.overlap_pp_size == 0
    assert overlap.target_peer_pp_layer_num == []
    assert overlap.ranks == []
def test_pp_ratio_peer_smaller():
    """Self pp=2 vs peer pp=1: some overlap must still be found; only sanity
    bounds are asserted since the exact split depends on layer mapping."""
    self_ri = make_rankinfo(
        "self",
        pp_size=2,
        pp_rank=1,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[1, 2],
    )
    peer_ri = make_rankinfo(
        "peer",
        pp_size=1,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[3],
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, 0)
    assert overlap.overlap_pp_size > 0
    assert sum(overlap.target_peer_pp_layer_num) > 0
    assert all(r >= 0 for r in overlap.ranks)
def test_tp_overlap():
    """TP mismatch (self tp=2 vs peer tp=4): overlap width is bounded by
    self's TP size and resolved ranks are plain ints."""
    self_ri = make_rankinfo(
        "self", tp_size=2, tp_rank=1, pp_size=1, pp_rank=0, cp_size=1, cp_rank=0
    )
    peer_ri = make_rankinfo(
        "peer", tp_size=4, tp_rank=0, pp_size=1, pp_rank=0, cp_size=1, cp_rank=0
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, 0)
    assert overlap.overlap_tp_size in [1, 2]
    assert all(isinstance(r, int) for r in overlap.ranks)
def test_cp_overlap():
    """CP mismatch (self cp=2 vs peer cp=4): mirror of the TP case."""
    self_ri = make_rankinfo(
        "self", cp_size=2, cp_rank=1, pp_size=1, pp_rank=0, tp_size=1, tp_rank=0
    )
    peer_ri = make_rankinfo(
        "peer", cp_size=4, cp_rank=0, pp_size=1, pp_rank=0, tp_size=1, tp_rank=0
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, 0)
    assert overlap.overlap_cp_size in [1, 2]
    assert all(isinstance(r, int) for r in overlap.ranks)
def test_multiple_overlap():
    """Simultaneous pp/tp/cp mismatch: pins the full expected PeerOverlap,
    including the exact flattened peer rank set."""
    self_ri = make_rankinfo(
        "self",
        pp_size=2,
        pp_rank=1,
        tp_size=2,
        tp_rank=1,
        cp_size=2,
        cp_rank=1,
        layer_num_per_pp=[1, 2],
    )
    peer_ri = make_rankinfo(
        "peer",
        pp_size=4,
        pp_rank=2,
        tp_size=4,
        tp_rank=3,
        cp_size=4,
        cp_rank=0,
        layer_num_per_pp=[1, 1, 1, 1],
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, peer_dp_rank=0)
    expected_overlap = PeerOverlap(
        overlap_pp_size=2,
        overlap_tp_size=2,
        overlap_cp_size=2,
        duplicate_head_factor=1,
        peer_duplicate_head_factor=2,
        target_peer_pp_layer_num=[1, 1],
        ranks=[26, 27, 30, 31, 42, 43, 46, 47],
    )
    assert overlap == expected_overlap
def _make_peer_registrar(self_rankinfo):
    """Build a PeerRegistrar for self_rankinfo backed by a dummy single-pool extractor."""
    extractor = KVRegionExtractorV1(KVPoolAttrs(pool_ptrs=[1234], block_bytes=[1024]))
    return PeerRegistrar(self_rankinfo, extractor)
def test_peer_registrar_register_and_get():
    """A registered peer RankInfo is returned by (name, rank) lookup."""
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(instance_name="peer", instance_rank=1, layer_num_per_pp=[2])
    reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri)
    assert reg.get_peer_rank_info("peer", 1) == peer_ri
def test_peer_registrar_unregister():
    """After unregister, looking the peer up again raises KeyError."""
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(instance_name="peer", instance_rank=1, layer_num_per_pp=[2])
    reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri)
    reg.unregister(peer_ri.instance_name, peer_ri.instance_rank)
    with pytest.raises(KeyError):
        reg.get_peer_rank_info("peer", 1)
def test_peer_registrar_incompatible_peer_raises():
    """Registering an MLA peer against a non-MLA self is rejected with ValueError."""
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(instance_name="peer", instance_rank=3, is_mla=True)
    with pytest.raises(ValueError):
        reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri)
def test_peer_registrar_self_rank_info_property():
    """self_rank_info exposes the RankInfo the registrar was built with."""
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    assert reg.self_rank_info == self_rankinfo
def test_peer_registrar_get_kv_map_identity():
    """A peer with identical layout maps via IdentityMapper."""
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(instance_name="peer", instance_rank=1, layer_num_per_pp=[2])
    reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri)
    mapper = reg.get_kv_map(peer_ri)
    assert isinstance(mapper, IdentityMapper)
def test_peer_registrar_get_kv_map_head_match():
    """Same heads per rank but different PP layout maps via HeadMatchMapper."""
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(
        instance_name="peer",
        instance_rank=2,
        pp_size=2,
        pp_rank=1,
        layer_num_per_pp=[1, 1],
        tokens_per_block=16,
        dims_per_head=8,
    )
    mapper = reg.get_kv_map(peer_ri)
    assert isinstance(mapper, HeadMatchMapper)
def test_peer_registrar_get_kv_map_head_mismatch():
    """A peer with a different kv_heads_per_rank maps via HeadMismatchMapper."""
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(
        instance_name="peer",
        instance_rank=3,
        pp_size=1,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        kv_heads_per_rank=4,
        tokens_per_block=16,
        dims_per_head=8,
        layer_num_per_pp=[2],
    )
    mapper = reg.get_kv_map(peer_ri)
    assert isinstance(mapper, HeadMismatchMapper)