[TRTLLM-9527][feat] Modularization of the transceiver for KV manager v2 (step 4) (#11225)
Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
parent 66caa67357
commit b1268e1b37
tensorrt_llm/_torch/disaggregation/base/region.py (new file, 124 lines)
@@ -0,0 +1,124 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import IntFlag, auto
from typing import List, NamedTuple, Optional


@dataclass(frozen=True)
class IndexRange:
    """
    Represents a closed interval [start, end], with both bounds >= 0.
    Commonly used for indexing layers, heads, or tokens.
    """

    start: int
    end: int

    def __post_init__(self):
        if not (isinstance(self.start, int) and isinstance(self.end, int)):
            raise TypeError("start and end must be integers")
        if self.start < 0 or self.end < 0:
            raise ValueError("start and end must be >= 0")
        if self.end < self.start:
            raise ValueError("end must be >= start")


class MemRegion(NamedTuple):
    """Describes a block of memory by starting pointer and size in bytes."""

    ptr: int
    bytes: int


class MemRegionGroup(NamedTuple):
    """Describes a group of equally sized memory blocks by their starting pointers and the shared size in bytes."""

    ptrs: List[int]
    bytes_per_region: int


class DataRole(IntFlag):
    """Logical role(s) a memory region plays. Supports combinations."""

    KEY = auto()
    VALUE = auto()


class DataLayout(IntFlag):
    """Possible orders for storing data in memory."""

    HND = auto()  # (head, seq_len, dim)
    NHD = auto()  # (seq_len, head, dim)


@dataclass(frozen=True)
class RegionSpec:
    """
    Specifies a (potentially partial) region of the cache.
    Extend this base class for additional axes or specializations.
    """

    layers: Optional[IndexRange] = None


@dataclass(frozen=True)
class KVRegionSpec(RegionSpec):
    """
    Specifies a region within the Key/Value cache, with optional axes.
    """

    role: DataRole = DataRole.KEY | DataRole.VALUE
    heads: Optional[IndexRange] = None
    tokens: Optional[IndexRange] = None


class SpecRegion(NamedTuple):
    """
    Associates a memory region with its semantic specifier.
    """

    memory: MemRegion | MemRegionGroup
    spec: Optional[RegionSpec] = None


class RegionExtractorBase(ABC):
    """
    Interface for extracting region descriptors from some backing store.
    """

    @abstractmethod
    def extract(self, region_ids: Optional[List[int]] = None) -> List[SpecRegion]:
        """
        Args:
            region_ids: (Optional) List of integer region identifiers to extract.
        Returns:
            List of SpecRegions for the corresponding region ids.
        """
        ...


class SpecRegionPair(NamedTuple):
    """
    Maps a source descriptor to a destination descriptor
    (e.g., when copying or reindexing regions).
    """

    src: SpecRegion
    dst: SpecRegion


class RegionMapperBase(ABC):
    """
    Maps a source region descriptor (which may describe a group of memory
    regions) to its corresponding destination descriptor.
    """

    @abstractmethod
    def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair:
        """
        Args:
            src_regions: Source SpecRegion.
            dst_regions: Destination SpecRegion.
        Returns:
            A SpecRegionPair mapping source to destination.
        """
        ...
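For orientation, a minimal sketch of how these primitives compose; the pointer values below are illustrative, not taken from the PR:

# Hypothetical usage: describe layers 0..3 of the key cache, backed by two
# 4 KiB blocks. The pointers are made-up integers, not real device addresses.
spec = KVRegionSpec(layers=IndexRange(start=0, end=3), role=DataRole.KEY)
memory = MemRegionGroup(ptrs=[0x1000, 0x2000], bytes_per_region=4096)
region = SpecRegion(memory=memory, spec=spec)
assert region.spec.role & DataRole.KEY  # IntFlag roles support membership tests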
tensorrt_llm/_torch/disaggregation/native/peer.py (new file, 286 lines)
@@ -0,0 +1,286 @@
from dataclasses import dataclass, field
from typing import Dict, List

from tensorrt_llm import logger
from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo
from tensorrt_llm._torch.disaggregation.native.region.block import (
    HeadMatchMapper,
    HeadMismatchMapper,
    IdentityMapper,
    RegionMapperBase,
)
from tensorrt_llm._torch.disaggregation.resource.kv_extractor import (
    KVPoolAttrs,
    KVRegionExtractorV1,
)


@dataclass
class PeerOverlap:
    overlap_pp_size: int = 0
    overlap_tp_size: int = 0
    overlap_cp_size: int = 0
    duplicate_head_factor: int = 1
    peer_duplicate_head_factor: int = 1
    target_peer_pp_layer_num: List[int] = field(default_factory=list)
    ranks: List[int] = field(default_factory=list)


class PeerRegistrar:

    def __init__(self, self_rank_info: RankInfo, self_extractor: KVRegionExtractorV1):
        self._ri = self_rank_info
        self._peer_ri_cache: Dict[str, RankInfo] = {}
        self._kv_map_cache: Dict[str, RegionMapperBase] = {}
        self._self_ext_cache = self_extractor
        self._peer_ext_cache: Dict[str, KVRegionExtractorV1] = {}
        self._overlap_cache: Dict[str, PeerOverlap] = {}

    def _block_size(self, layer_num: int, ri: RankInfo) -> int:
        return (
            layer_num
            * ri.kv_factor
            * ri.kv_heads_per_rank
            * ri.tokens_per_block
            * ri.dims_per_head
            * ri.element_bytes
        )

    def register(self, peer_name: str, peer_rank: int, peer_ri: RankInfo):
        # TODO: check if peer is valid for registration
        assert self._self_ext_cache is not None
        if not self._check_peer_compatible(peer_ri):
            raise ValueError(
                f"PeerRegistrar.register: peer {peer_name} (rank={peer_rank}) is incompatible with local rank."
            )
        key = self._unique_key(peer_name, peer_rank)
        self._peer_ri_cache[key] = peer_ri
        peer_ri = self.get_peer_rank_info(peer_name, peer_rank)
        layer_num = peer_ri.layer_num_per_pp[peer_ri.pp_rank]
        block_size = self._block_size(layer_num, peer_ri)
        extractor = KVRegionExtractorV1(
            KVPoolAttrs(pool_ptrs=peer_ri.kv_ptrs, block_bytes=[block_size])
        )
        self._peer_ext_cache[key] = extractor

    def peer_extractor(self, peer_name: str, peer_rank: int) -> KVRegionExtractorV1:
        return self._peer_ext_cache[self._unique_key(peer_name, peer_rank)]

    @property
    def self_extractor(self) -> KVRegionExtractorV1:
        assert self._self_ext_cache is not None
        return self._self_ext_cache

    def unregister(self, peer_name: str, peer_rank: int):
        key = self._unique_key(peer_name, peer_rank)
        if key in self._peer_ri_cache:
            del self._peer_ri_cache[key]
        if key in self._peer_ext_cache:
            del self._peer_ext_cache[key]
        if key in self._kv_map_cache:
            del self._kv_map_cache[key]

    def get_peer_rank_info(self, peer_name: str, peer_rank: int):
        return self._peer_ri_cache[self._unique_key(peer_name, peer_rank)]

    @property
    def self_rank_info(self) -> RankInfo:
        return self._ri

    def _unique_key(self, name: str, rank: int) -> str:
        # Delimiter avoids key collisions such as ("a", 12) vs ("a1", 2).
        return f"{name}:{rank}"

    def _check_peer_compatible(self, peer_ri: RankInfo) -> bool:
        if self._ri.is_mla != peer_ri.is_mla:
            logger.warning(
                "PeerRegistrar: compatibility check failed: 'is_mla' differs "
                f"(local={self._ri.is_mla}, peer={peer_ri.is_mla})."
            )
            return False
        if self._ri.cp_size != 1 or peer_ri.cp_size != 1:
            logger.warning(
                "PeerRegistrar: unsupported configuration: context parallelism (cp_size) "
                f"must be 1 for both local and peer ranks (local={self._ri.cp_size}, peer={peer_ri.cp_size})."
            )
            return False
        if self._ri.element_bytes != peer_ri.element_bytes:
            logger.warning(
                "PeerRegistrar: element size mismatch "
                f"(local={self._ri.element_bytes} bytes, peer={peer_ri.element_bytes} bytes)."
            )
            return False
        if self._ri.tokens_per_block != peer_ri.tokens_per_block:
            logger.warning(
                "PeerRegistrar: tokens_per_block mismatch "
                f"(local={self._ri.tokens_per_block}, peer={peer_ri.tokens_per_block})."
            )
            return False
        if self._ri.dims_per_head != peer_ri.dims_per_head:
            logger.warning(
                "PeerRegistrar: dims_per_head mismatch "
                f"(local={self._ri.dims_per_head}, peer={peer_ri.dims_per_head})."
            )
            return False

        self_layers = sum(self._ri.layer_num_per_pp)
        peer_layers = sum(peer_ri.layer_num_per_pp)
        if self_layers != peer_layers:
            logger.warning(
                "PeerRegistrar: total layer count mismatch "
                f"(local={self_layers}, peer={peer_layers})."
            )
            return False

        if self._ri.is_mla:
            if peer_ri.kv_heads_per_rank != 1 or self._ri.kv_heads_per_rank != 1:
                logger.warning(
                    "PeerRegistrar: MLA mode requires exactly 1 KV head per rank for both local and peer."
                    f" (local={self._ri.kv_heads_per_rank}, peer={peer_ri.kv_heads_per_rank})"
                )
                return False
        return True

    def _tp_per_dp(self, info: RankInfo) -> int:
        return (
            info.tp_size // info.dp_size
            if getattr(info, "enable_attention_dp", False)
            else info.tp_size
        )

    def get_kv_map(self, peer_ri: RankInfo):
        key = self._unique_key(peer_ri.instance_name, peer_ri.instance_rank)
        if key in self._kv_map_cache:
            return self._kv_map_cache[key]

        self_tp_per_dp = self._tp_per_dp(self._ri)
        peer_tp_per_dp = self._tp_per_dp(peer_ri)

        is_dup_head = (
            self._ri.kv_heads_per_rank * self_tp_per_dp
            != peer_ri.kv_heads_per_rank * peer_tp_per_dp
        )
        head_match = is_dup_head or self._ri.is_mla or self_tp_per_dp == peer_tp_per_dp
        logger.debug(
            "PeerRegistrar.get_kv_map: "
            f"head_match={head_match}, is_dup_head={is_dup_head}, self_is_mla={self._ri.is_mla}, "
            f"self_tp_per_dp={self_tp_per_dp}, peer_tp_per_dp={peer_tp_per_dp}"
        )
        # Fast path: identity mapping when heads match and both sides use the same pp_size.
        if head_match and self._ri.pp_size == peer_ri.pp_size:
            mapper = IdentityMapper()
            self._kv_map_cache[key] = mapper
            return mapper

        # Compute the overlapping layer range between the local and peer pp stages.
        self_start_layer = sum(self._ri.layer_num_per_pp[: self._ri.pp_rank])
        self_end_layer = self_start_layer + self._ri.layer_num_per_pp[self._ri.pp_rank]
        peer_start_layer = sum(peer_ri.layer_num_per_pp[: peer_ri.pp_rank])
        peer_end_layer = peer_start_layer + peer_ri.layer_num_per_pp[peer_ri.pp_rank]
        start = max(self_start_layer, peer_start_layer)
        end = min(self_end_layer, peer_end_layer)
        transfer_layers = end - start
        self_layer_offset = start - self_start_layer
        peer_layer_offset = start - peer_start_layer

        if head_match:
            mapper = HeadMatchMapper(
                transfer_layers=transfer_layers,
                src_layer_off=self_layer_offset,  # local layer offset
                dst_layer_off=peer_layer_offset,  # peer layer offset
                self_ri=self._ri,
                peer_ri=peer_ri,
            )
            self._kv_map_cache[key] = mapper
            return mapper

        # Head mismatch case: map per-head fragments individually.
        mapper = HeadMismatchMapper(
            transfer_layers=transfer_layers,
            src_layer_off=self_layer_offset,
            peer_layer_off=peer_layer_offset,
            self_ri=self._ri,
            peer_ri=peer_ri,
        )
        self._kv_map_cache[key] = mapper
        return mapper

    @staticmethod
    def _find_overlap(self_val, peer_val, self_rank, peer_rank=None):
        if self_val <= peer_val:
            overlap = peer_val // self_val
            start = self_rank * overlap + (peer_rank * peer_val if peer_rank is not None else 0)
            end = start + overlap
        else:
            ratio = self_val // peer_val
            start = (self_rank // ratio) + (peer_rank * peer_val if peer_rank is not None else 0)
            overlap = 1
            end = start + overlap

        return overlap, start, end

    def get_peer_overlap(self, peer_rank_info: RankInfo, peer_dp_rank: int) -> PeerOverlap:
        peer_ri = peer_rank_info
        key = self._unique_key(peer_ri.instance_name, peer_dp_rank)
        if key in self._overlap_cache:
            return self._overlap_cache[key]

        # Compute pp overlap and target layers.
        self_start_layer = sum(self._ri.layer_num_per_pp[: self._ri.pp_rank])
        self_end_layer = self_start_layer + self._ri.layer_num_per_pp[self._ri.pp_rank]

        pre = 0
        tgt_pp_ranks: List[int] = []
        tgt_pp_layer_num: List[int] = []
        for p in range(peer_ri.pp_size):
            peer_start_layer = pre
            peer_end_layer = peer_start_layer + peer_ri.layer_num_per_pp[p]
            if self_start_layer < peer_end_layer and self_end_layer > peer_start_layer:
                tgt_pp_ranks.append(p)
                tgt_pp_layer_num.append(
                    min(peer_end_layer, self_end_layer) - max(peer_start_layer, self_start_layer)
                )
            pre += peer_ri.layer_num_per_pp[p]

        if not tgt_pp_ranks:
            # No overlap found.
            targets = PeerOverlap()
            self._overlap_cache[key] = targets
            return targets

        peer_start_pp = tgt_pp_ranks[0]
        overlap_pp_size = len(tgt_pp_ranks)
        peer_end_pp = peer_start_pp + overlap_pp_size

        # TP size per dp-group.
        self_tp_per_dp = self._tp_per_dp(self._ri)
        peer_tp_per_dp = self._tp_per_dp(peer_ri)
        self_tp_rank_in_dp = self._ri.tp_rank % self_tp_per_dp

        overlap_tp_size, peer_start_tp, peer_end_tp = self._find_overlap(
            self_tp_per_dp, peer_tp_per_dp, self_tp_rank_in_dp, peer_dp_rank
        )
        overlap_cp_size, peer_start_cp, peer_end_cp = self._find_overlap(
            self._ri.cp_size, peer_ri.cp_size, self._ri.cp_rank
        )

        ranks: List[int] = []
        for pp in range(peer_start_pp, peer_end_pp):
            for cp in range(peer_start_cp, peer_end_cp):
                for tp in range(peer_start_tp, peer_end_tp):
                    ranks.append(pp * peer_ri.tp_size * peer_ri.cp_size + cp * peer_ri.tp_size + tp)

        factor_self = self._ri.kv_heads_per_rank * self_tp_per_dp
        factor_peer = peer_ri.kv_heads_per_rank * peer_tp_per_dp
        dup_head = max(1, factor_self // factor_peer)
        peer_dup_head = max(1, factor_peer // factor_self)

        targets = PeerOverlap(
            overlap_pp_size=overlap_pp_size,
            overlap_tp_size=overlap_tp_size,
            overlap_cp_size=overlap_cp_size,
            duplicate_head_factor=dup_head,
            peer_duplicate_head_factor=peer_dup_head,
            target_peer_pp_layer_num=tgt_pp_layer_num,
            ranks=ranks,
        )
        self._overlap_cache[key] = targets
        return targets
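To make the `_find_overlap` arithmetic concrete, here is a small worked sketch. It is a standalone reimplementation of the same formula for illustration (the function name and values are ours, not part of the PR):

# Worked example of the _find_overlap arithmetic above (standalone sketch).
def find_overlap(self_val, peer_val, self_rank, peer_rank=None):
    base = peer_rank * peer_val if peer_rank is not None else 0
    if self_val <= peer_val:
        overlap = peer_val // self_val            # peer ranks per local rank
        start = self_rank * overlap + base
    else:
        overlap = 1                               # several local ranks share one peer rank
        start = self_rank // (self_val // peer_val) + base
    return overlap, start, start + overlap

# Local tp_per_dp=2 vs peer tp_per_dp=4: local tp rank 1 covers peer tp ranks {2, 3}.
assert find_overlap(2, 4, self_rank=1, peer_rank=0) == (2, 2, 4)
# Local tp_per_dp=4 vs peer tp_per_dp=2: local tp rank 3 maps onto peer tp rank 1.
assert find_overlap(4, 2, self_rank=3) == (1, 1, 2)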
tensorrt_llm/_torch/disaggregation/native/rank_info.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from dataclasses import asdict, dataclass
from typing import List, Optional

import msgpack

from tensorrt_llm._torch.disaggregation.native.region.aux import AuxBufferMeta


@dataclass
class InstanceInfo:
    instance_name: str
    tp_size: int
    pp_size: int
    dp_size: int
    cp_size: int
    kv_heads_per_rank: int
    tokens_per_block: int
    dims_per_head: int
    element_bytes: int
    enable_attention_dp: bool
    is_mla: bool
    layer_num_per_pp: List[int]
    sender_endpoints: List[str]

    def to_bytes(self) -> bytes:
        return msgpack.packb(asdict(self))

    @classmethod
    def from_bytes(cls, data: bytes) -> "InstanceInfo":
        return cls(**msgpack.unpackb(data))


@dataclass
class RankInfo:
    instance_name: str
    instance_rank: int
    tp_size: int
    tp_rank: int
    pp_size: int
    pp_rank: int
    dp_size: int
    dp_rank: int
    cp_size: int
    cp_rank: int
    device_id: int
    kv_heads_per_rank: int
    # KV pool layout: [num_layers, kv_factor, heads, tokens, dims_per_head]
    tokens_per_block: int
    dims_per_head: int
    element_bytes: int
    enable_attention_dp: bool
    is_mla: bool
    layer_num_per_pp: List[int]
    kv_ptrs: List[int]
    aux_ptrs: List[int]
    server_endpoint: str
    self_endpoint: str
    transfer_engine_info: bytes
    aux_meta: Optional[AuxBufferMeta]

    @property
    def kv_factor(self) -> int:
        # MLA stores a single combined KV buffer; standard attention stores K and V.
        return 1 if self.is_mla else 2

    def to_bytes(self) -> bytes:
        data = asdict(self)
        data["aux_meta"] = self.aux_meta.to_dict() if self.aux_meta is not None else None
        return msgpack.packb(data)

    @classmethod
    def from_bytes(cls, data: bytes) -> "RankInfo":
        unpacked = msgpack.unpackb(data)
        if unpacked.get("aux_meta") is not None:
            unpacked["aux_meta"] = AuxBufferMeta.from_dict(unpacked["aux_meta"])
        return cls(**unpacked)
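These msgpack round-trips are the wire format peers exchange at registration time. A minimal sketch with placeholder field values:

# Round-trip sketch for the serialization above (values are placeholders).
info = InstanceInfo(
    instance_name="ctx0", tp_size=2, pp_size=1, dp_size=1, cp_size=1,
    kv_heads_per_rank=8, tokens_per_block=32, dims_per_head=128,
    element_bytes=2, enable_attention_dp=False, is_mla=False,
    layer_num_per_pp=[32], sender_endpoints=["tcp://host:5555"],
)
assert InstanceInfo.from_bytes(info.to_bytes()) == info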
tensorrt_llm/_torch/disaggregation/native/region/aux.py (new file, 183 lines)
@@ -0,0 +1,183 @@
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from typing import Any

import torch

from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest


@dataclass
class AuxBufferMeta:
    ptrs: list[int]
    size: list[int]
    item_sizes: list[int] = field(default_factory=list)
    device: str = "cpu"

    def to_dict(self) -> dict[str, Any]:
        return {
            "ptrs": self.ptrs,
            "size": self.size,
            "item_sizes": self.item_sizes,
            "device": self.device,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AuxBufferMeta":
        return cls(
            ptrs=data["ptrs"],
            size=data["size"],
            item_sizes=data.get("item_sizes", []),
            device=data.get("device", "cpu"),
        )


class AuxBufferBase(ABC):
    """
    Abstract base class defining the interface for auxiliary buffer management.
    """

    @abstractmethod
    def alloc_slot(self) -> int:
        """
        Allocate a free slot and return its index.
        """
        ...

    @abstractmethod
    def free_slot(self, slot: int) -> None:
        """
        Release the specified slot.
        """
        ...

    @property
    @abstractmethod
    def meta(self) -> AuxBufferMeta:
        """
        Retrieve meta-information about the underlying buffer(s).
        Returns buffer info (e.g., pointers, sizes, device).
        """
        ...

    @abstractmethod
    def fill_slot(self, slot: int, request: LlmRequest) -> None:
        """
        Fill/overwrite the contents of the given slot with data from the request.
        """
        ...

    @abstractmethod
    def get_slot_tokens(self, slot: int) -> tuple[list[int], list[int]]:
        """
        Get the token data (e.g., first/draft tokens) from the specified slot.
        """
        ...


class AuxBuffer(AuxBufferBase):

    def __init__(self, max_slot_num: int, beam_width: int, max_draft_len: int, device: str = "cpu"):
        # Public constructor args remain the same; internals are private.
        self._max_slot_num = int(max_slot_num)
        self._beam_width = int(beam_width)
        self._max_draft_len = int(max_draft_len)
        self._device = device

        self._free_slots = deque(range(self._max_slot_num))
        self._occupied_slots: set[int] = set()
        self._slot_token_counts: dict[
            int, tuple[int, int]
        ] = {}  # slot -> (first_tokens_len, draft_tokens_len)

        data_type = torch.int32
        self._first_tokens_buffer = torch.empty(
            self._max_slot_num, self._beam_width, dtype=data_type, device=self._device
        )

        self._draft_tokens_buffer = torch.empty(
            self._max_slot_num, self._max_draft_len, dtype=data_type, device=self._device
        )

        self._meta = AuxBufferMeta(
            ptrs=[self._first_tokens_buffer.data_ptr(), self._draft_tokens_buffer.data_ptr()],
            size=[
                self._first_tokens_buffer.numel() * self._first_tokens_buffer.element_size(),
                self._draft_tokens_buffer.numel() * self._draft_tokens_buffer.element_size(),
            ],
            item_sizes=[
                self._first_tokens_buffer[0].numel() * self._first_tokens_buffer.element_size(),
                self._draft_tokens_buffer[0].numel() * self._draft_tokens_buffer.element_size(),
            ],
            device=self._device,
        )

    def alloc_slot(self) -> int:
        if not self._free_slots:
            raise ValueError(
                f"No free auxiliary buffer slots available (max slots = {self._max_slot_num}). "
                "All slots are currently occupied."
            )
        slot_id = self._free_slots.popleft()
        if slot_id in self._occupied_slots:
            # This should not happen; defensive check.
            raise RuntimeError(
                f"Invariant error: selected slot {slot_id} is already marked as occupied. "
                "This indicates a bug in slot management."
            )
        self._occupied_slots.add(slot_id)
        return slot_id

    def free_slot(self, slot: int) -> None:
        if slot < 0 or slot >= self._max_slot_num:
            raise ValueError(
                f"Invalid slot id {slot}. Valid slot indices are in the range 0..{self._max_slot_num - 1}."
            )
        if slot not in self._occupied_slots:
            raise ValueError(
                f"Attempted to free slot {slot}, but that slot is not currently allocated. "
                "Ensure `alloc_slot` was called and the slot wasn't freed already."
            )
        self._occupied_slots.remove(slot)
        self._free_slots.append(slot)

    @property
    def meta(self) -> AuxBufferMeta:
        return self._meta

    def fill_slot(self, slot: int, request: LlmRequest) -> None:
        if slot not in self._occupied_slots:
            raise ValueError(
                f"Cannot fill slot {slot}: slot is not currently allocated. "
                "Call `alloc_slot` first."
            )
        first_gen_tokens = request.get_last_tokens()
        draft_tokens = request.py_draft_tokens

        if len(first_gen_tokens) > self._beam_width:
            raise ValueError(
                f"`first_gen_tokens` length ({len(first_gen_tokens)}) exceeds `beam_width` ({self._beam_width}). "
                "Consider truncating the token list or increasing the beam_width when creating the `AuxBuffer`."
            )
        if len(draft_tokens) > self._max_draft_len:
            raise ValueError(
                f"`draft_tokens` length ({len(draft_tokens)}) exceeds `max_draft_len` ({self._max_draft_len}). "
                "Consider truncating draft tokens or increasing `max_draft_len` when creating the `AuxBuffer`."
            )

        self._first_tokens_buffer[slot][: len(first_gen_tokens)].copy_(
            torch.tensor(first_gen_tokens, dtype=torch.int32, device=self._device)
        )
        self._draft_tokens_buffer[slot][: len(draft_tokens)].copy_(
            torch.tensor(draft_tokens, dtype=torch.int32, device=self._device)
        )
        self._slot_token_counts[slot] = (len(first_gen_tokens), len(draft_tokens))

    def get_slot_tokens(self, slot: int) -> tuple[list[int], list[int]]:
        first_len, draft_len = self._slot_token_counts.get(
            slot, (self._beam_width, self._max_draft_len)
        )
        first_gen_tokens = self._first_tokens_buffer[slot][:first_len].tolist()
        draft_tokens = self._draft_tokens_buffer[slot][:draft_len].tolist()

        return first_gen_tokens, draft_tokens
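A minimal slot lifecycle sketch. `_StubRequest` is a hypothetical stand-in, not a real `LlmRequest`; it only provides the two members `fill_slot` reads:

class _StubRequest:  # hypothetical stand-in exposing what fill_slot reads
    def get_last_tokens(self):
        return [101, 102]
    py_draft_tokens = [7, 8, 9]

buf = AuxBuffer(max_slot_num=4, beam_width=2, max_draft_len=4, device="cpu")
slot = buf.alloc_slot()                     # first allocation returns slot 0
buf.fill_slot(slot, _StubRequest())
assert buf.get_slot_tokens(slot) == ([101, 102], [7, 8, 9])
buf.free_slot(slot)                         # slot returns to the free list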
tensorrt_llm/_torch/disaggregation/native/region/block.py (new file, 216 lines)
@@ -0,0 +1,216 @@
import numpy as np

from tensorrt_llm._torch.disaggregation.base.region import (
    MemRegionGroup,
    RegionMapperBase,
    SpecRegion,
    SpecRegionPair,
)
from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo


class IdentityMapper(RegionMapperBase):
    """
    ---- mapper_identity ----

    Pass-through mapping. Does not change pointers or sizes.

    src_ptrs: [ S0 ] [ S1 ] [ S2 ] ...
                |      |      |
                v      v      v
    dst_ptrs: [ D0 ] [ D1 ] [ D2 ] ...
    """

    def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair:
        src_group = src_regions.memory
        dst_group = dst_regions.memory
        assert len(src_group.ptrs) == len(dst_group.ptrs), (
            f"Number of regions of src({len(src_group.ptrs)}) and dst({len(dst_group.ptrs)}) must match"
        )
        new_src = MemRegionGroup(
            ptrs=list(src_group.ptrs), bytes_per_region=src_group.bytes_per_region
        )
        new_dst = MemRegionGroup(
            ptrs=list(dst_group.ptrs), bytes_per_region=dst_group.bytes_per_region
        )
        return SpecRegionPair(
            src=SpecRegion(memory=new_src, spec=src_regions.spec),
            dst=SpecRegion(memory=new_dst, spec=dst_regions.spec),
        )


class HeadMatchMapper(RegionMapperBase):
    """
    ---- mapper_head_match ----

    Moves/copies entire contiguous blocks (multi-layer fragments) as single chunks.
    Aligns by whole fragment size (frag_size) and applies a constant source/destination block offset.

    src_ptrs:  [ S0 ]            [ S1 ]            ...
                 |                 |
              + src_off         + src_off
                 |                 |
          [ S0 + src_off ]   [ S1 + src_off ]   -> (each points to a frag of size frag_size)
                       copy whole frag
                 |                 |
                 v                 v
          [ D0 + dst_off ]   [ D1 + dst_off ]   -> (destination frags)
    """

    def __init__(
        self,
        transfer_layers: int,
        src_layer_off: int,
        dst_layer_off: int,
        self_ri: RankInfo,
        peer_ri: RankInfo,
    ):
        self._kv_factor = self_ri.kv_factor
        self._frag_size = self._block_size(transfer_layers, self_ri)
        self._src_block_off = self._block_size(src_layer_off, self_ri)
        self._dst_block_off = self._block_size(dst_layer_off, peer_ri)

    def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair:
        src_group = src_regions.memory
        dst_group = dst_regions.memory
        assert len(src_group.ptrs) == len(dst_group.ptrs), (
            f"Number of regions of src({len(src_group.ptrs)}) and dst({len(dst_group.ptrs)}) must match"
        )
        new_src_ptrs = [src_ptr + self._src_block_off for src_ptr in src_group.ptrs]
        new_dst_ptrs = [dst_ptr + self._dst_block_off for dst_ptr in dst_group.ptrs]
        new_src = MemRegionGroup(ptrs=new_src_ptrs, bytes_per_region=self._frag_size)
        new_dst = MemRegionGroup(ptrs=new_dst_ptrs, bytes_per_region=self._frag_size)
        return SpecRegionPair(
            src=SpecRegion(memory=new_src, spec=src_regions.spec),
            dst=SpecRegion(memory=new_dst, spec=dst_regions.spec),
        )

    def _block_size(self, layer_num: int, ri: RankInfo) -> int:
        return (
            layer_num
            * ri.kv_factor
            * ri.kv_heads_per_rank
            * ri.tokens_per_block
            * ri.dims_per_head
            * ri.element_bytes
        )


class HeadMismatchMapper(RegionMapperBase):
    """
    ---- mapper_head_mismatch ----

    Fine-grained mapping when head counts or TP/DP partitioning differ.
    Splits layers into per-head (or contiguous-heads) fragments and maps them individually.
    Handles kv_factor (e.g., key+value duplication) and TP/DP head offsets.

    Source (layers x heads):
        L0: [S00 S01] [S02 S03] ...
        L1: [S10 S11] [S12 S13] ...

    Destination (layers x heads, different layout possible):
        L0': [D00] [D01] [D02] ...
        L1': [D10] [D11] ...

    Mapping (each arrow = copy cont_heads_frag):
        [S00 S01] -> [D00]
        [S02 S03] -> [D01]
        [S10 S11] -> [D02]
    """

    def __init__(
        self,
        transfer_layers: int,
        src_layer_off: int,
        peer_layer_off: int,
        self_ri: RankInfo,
        peer_ri: RankInfo,
    ):
        self._ri = self_ri
        self._peer_ri = peer_ri
        self._src_layer_off = src_layer_off

        kv_factor = self_ri.kv_factor
        self_tp_per_dp = self_ri.tp_size // self_ri.dp_size
        peer_tp_per_dp = peer_ri.tp_size // peer_ri.dp_size
        self_tp_rank = self_ri.tp_rank
        peer_tp_rank = peer_ri.tp_rank

        bytes_per_head = self._ri.tokens_per_block * self._ri.dims_per_head * self._ri.element_bytes
        self._bytes_cont_heads = (
            min(self._ri.kv_heads_per_rank, peer_ri.kv_heads_per_rank) * bytes_per_head
        )

        self._src_head_off, self._dst_head_off = self._compute_head_offsets(
            self_tp_per_dp,
            peer_tp_per_dp,
            self_tp_rank,
            peer_tp_rank,
            self._bytes_cont_heads,
        )
        self._layer_indices = np.arange(transfer_layers, dtype=np.int64)
        self._kv_indices = np.arange(kv_factor, dtype=np.int64)
        self._peer_layer_off = peer_layer_off

    def map(self, src_regions: SpecRegion, dst_regions: SpecRegion) -> SpecRegionPair:
        src_group = src_regions.memory
        dst_group = dst_regions.memory
        assert len(src_group.ptrs) == len(dst_group.ptrs), (
            f"Number of regions of src({len(src_group.ptrs)}) and dst({len(dst_group.ptrs)}) must match"
        )
        src_bases = np.array(src_group.ptrs, dtype=np.int64)
        dst_bases = np.array(dst_group.ptrs, dtype=np.int64)
        src_frags = self._get_frags(
            bases=src_bases,
            layer_indices=self._src_layer_off + self._layer_indices,
            layer_kv_num=self._get_layer_kv_num(self._ri),
            kv_indices=self._kv_indices,
            head_off=self._src_head_off,
            kv_factor=self._kv_indices.size,
        )
        dst_frags = self._get_frags(
            bases=dst_bases,
            layer_indices=self._peer_layer_off + self._layer_indices,
            layer_kv_num=self._get_layer_kv_num(self._peer_ri),
            kv_indices=self._kv_indices,
            head_off=self._dst_head_off,
            kv_factor=self._kv_indices.size,
        )
        all_src_ptrs = [int(x) for x in src_frags.flatten()]
        all_dst_ptrs = [int(x) for x in dst_frags.flatten()]
        new_src = MemRegionGroup(ptrs=all_src_ptrs, bytes_per_region=self._bytes_cont_heads)
        new_dst = MemRegionGroup(ptrs=all_dst_ptrs, bytes_per_region=self._bytes_cont_heads)
        return SpecRegionPair(
            src=SpecRegion(memory=new_src, spec=src_regions.spec),
            dst=SpecRegion(memory=new_dst, spec=dst_regions.spec),
        )

    @staticmethod
    def _compute_head_offsets(
        self_tp_per_dp: int,
        peer_tp_per_dp: int,
        self_tp_rank: int,
        peer_tp_rank: int,
        bytes_cont_heads: int,
    ) -> tuple[int, int]:
        if self_tp_per_dp == peer_tp_per_dp:
            return 0, 0
        ratio = max(self_tp_per_dp, peer_tp_per_dp) // min(self_tp_per_dp, peer_tp_per_dp)
        if self_tp_per_dp < peer_tp_per_dp:
            return (peer_tp_rank % ratio) * bytes_cont_heads, 0
        else:
            return 0, (self_tp_rank % ratio) * bytes_cont_heads

    @staticmethod
    def _get_layer_kv_num(ri: RankInfo) -> int:
        # Bytes per (layer, kv) slice: heads * tokens * dims * element size.
        return ri.kv_heads_per_rank * ri.tokens_per_block * ri.dims_per_head * ri.element_bytes

    @staticmethod
    def _get_frags(bases, layer_indices, layer_kv_num, kv_indices, head_off, kv_factor):
        # Broadcast base pointers over (layer, kv) indices to produce one
        # fragment address per (block, layer, kv) triple.
        layer_bytes = layer_kv_num * kv_factor
        return (
            bases[:, None, None]
            + layer_bytes * layer_indices[None, :, None]
            + layer_kv_num * kv_indices[None, None, :]
            + head_off
        )
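To see the address arithmetic in `_get_frags` concretely, a small hedged example; the numbers are made up and the code mirrors the broadcast above rather than calling any public API:

import numpy as np

# One block at base 1000; 2 transfer layers; kv_factor 2; 48 bytes per (layer, kv) slice.
bases = np.array([1000], dtype=np.int64)
layer_indices = np.arange(2, dtype=np.int64)   # layers 0 and 1
kv_indices = np.arange(2, dtype=np.int64)      # K and V
layer_kv_num = 48                              # bytes per (layer, kv) slice
layer_bytes = layer_kv_num * kv_indices.size   # 96 bytes per layer

frags = (bases[:, None, None]
         + layer_bytes * layer_indices[None, :, None]
         + layer_kv_num * kv_indices[None, None, :])
# Layer 0: K at 1000, V at 1048; layer 1: K at 1096, V at 1144.
assert frags.flatten().tolist() == [1000, 1048, 1096, 1144]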
tensorrt_llm/_torch/disaggregation/native/region/page.py (new file, 126 lines)
@@ -0,0 +1,126 @@
from dataclasses import dataclass
from typing import List, Set

import numpy as np

BUFFER_ENTRY_DTYPE = np.dtype(
    [
        ("layer_id", np.uint32),
        ("role", np.uint32),
        ("offset", np.uint32),
        ("size", np.uint32),
    ]
)


@dataclass
class PoolDescriptor:
    """
    Pool descriptor containing memory layout and buffer information.

    One pool contains multiple buffer entries, each representing a (layer_id, role) combination.
    """

    base_address: int  # (uint64)
    slot_bytes: int
    num_slots: int

    # Buffer entries: flattened array of all (layer_id, role, offset, size) in this pool
    buffer_entries: np.ndarray  # dtype=BUFFER_ENTRY_DTYPE

    @property
    def num_buffer_entries(self) -> int:
        return len(self.buffer_entries)

    @property
    def pool_bytes(self) -> int:
        return self.slot_bytes * self.num_slots

    @property
    def unique_layers(self) -> Set[int]:
        return set(int(entry["layer_id"]) for entry in self.buffer_entries)

    @property
    def unique_roles(self) -> Set[int]:
        return set(int(entry["role"]) for entry in self.buffer_entries)

    def get_slot_address(self, slot_id: int) -> int:
        if slot_id >= self.num_slots:
            raise ValueError(f"slot_id {slot_id} >= num_slots {self.num_slots}")
        return self.base_address + slot_id * self.slot_bytes

    def get_device_pointer(self, slot_id: int, layer_id: int, role_enum: int) -> int:
        if slot_id >= self.num_slots:
            raise ValueError(f"slot_id {slot_id} >= num_slots {self.num_slots}")

        for entry in self.buffer_entries:
            if entry["layer_id"] == layer_id and entry["role"] == role_enum:
                slot_base = self.base_address + slot_id * self.slot_bytes
                return slot_base + int(entry["offset"])

        raise ValueError(f"Buffer not found: layer_id={layer_id}, role_enum={role_enum}")

    def __repr__(self) -> str:
        return (
            f"PoolDescriptor(base=0x{self.base_address:x}, "
            f"slot_bytes={self.slot_bytes}, num_slots={self.num_slots}, "
            f"layers={len(self.unique_layers)}, roles={len(self.unique_roles)})"
        )


@dataclass
class KVCachePageTable:
    """
    Multi-dimensional KV cache page table.

    Structure:
        KVCachePageTable
        └── PoolGroups (List[List[PoolDescriptor]])
            ├── PoolGroup 0: List of PoolDescriptors
            │   ├── Pool 0: PoolDescriptor
            │   └── Pool 1: PoolDescriptor
            │
            └── PoolGroup 1: List of PoolDescriptors
                ├── Pool 0: PoolDescriptor
                └── Pool 1: PoolDescriptor

    Relationships:
        - pools[pg_idx] = List[PoolDescriptor] (all pools in the same PoolGroup)
        - All pools in pools[pg_idx] share the same lifecycle
    """

    tokens_per_block: int
    num_layers: int
    pools: List[List[PoolDescriptor]]  # pools[pg_idx][pool_idx] -> PoolDescriptor

    @property
    def num_pool_groups(self) -> int:
        return len(self.pools)

    @property
    def total_pools(self) -> int:
        return sum(len(pg_pools) for pg_pools in self.pools)

    @property
    def total_buffer_entries(self) -> int:
        return sum(pool.num_buffer_entries for pg_pools in self.pools for pool in pg_pools)

    @property
    def total_pool_bytes(self) -> int:
        return sum(pool.pool_bytes for pg_pools in self.pools for pool in pg_pools)

    @property
    def total_slots(self) -> int:
        return sum(pool.num_slots for pg_pools in self.pools for pool in pg_pools)

    def get_pool(self, pg_idx: int, pool_idx: int) -> PoolDescriptor:
        return self.pools[pg_idx][pool_idx]

    def get_device_pointer(
        self, pg_idx: int, pool_idx: int, slot_id: int, layer_id: int, role_enum: int
    ) -> int:
        # role_enum is the integer role stored in BUFFER_ENTRY_DTYPE entries.
        pool = self.pools[pg_idx][pool_idx]
        return pool.get_device_pointer(slot_id, layer_id, role_enum)

    def __repr__(self) -> str:
        return (
            f"KVCachePageTable(poolgroups={self.num_pool_groups}, "
            f"pools={self.total_pools}, layers={self.num_layers})"
        )
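A small sketch of the pointer math, with illustrative values only:

import numpy as np

# One pool: 2 layers x (K, V) per slot, 256-byte buffers, 4 slots of 1024 bytes.
entries = np.array(
    [(0, 0, 0, 256), (0, 1, 256, 256), (1, 0, 512, 256), (1, 1, 768, 256)],
    dtype=BUFFER_ENTRY_DTYPE,
)
pool = PoolDescriptor(base_address=0x10000, slot_bytes=1024, num_slots=4,
                      buffer_entries=entries)
# Slot 2 starts at base + 2 * 1024; layer 1, role 1 (V) sits at in-slot offset 768.
assert pool.get_slot_address(2) == 0x10000 + 2048
assert pool.get_device_pointer(slot_id=2, layer_id=1, role_enum=1) == 0x10000 + 2048 + 768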
tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py (new file, 89 lines)
@@ -0,0 +1,89 @@
from dataclasses import dataclass
from typing import List

from tensorrt_llm._torch.disaggregation.base.region import (
    DataLayout,
    MemRegionGroup,
    RegionExtractorBase,
    SpecRegion,
)
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._utils import get_size_in_bytes


@dataclass
class KVPoolAttrs:
    """Attributes for a single (primary) KV memory pool."""

    pool_ptrs: List[int]
    block_bytes: List[int]


class KVRegionExtractorV1(RegionExtractorBase):
    """
    Descriptor and region extractor for the KV cache pool managed by KVCacheManager.
    Provides region descriptors adapting the pool to a block-wise view.
    """

    def __init__(self, kv_arg: KVCacheManager | KVPoolAttrs):
        if isinstance(kv_arg, KVPoolAttrs):
            self._kv_pool_attrs = kv_arg
        elif isinstance(kv_arg, KVCacheManager):
            self._kv_pool_attrs = self._attrs_from_manager(kv_arg)
        else:
            raise TypeError(
                f"kv_arg must be KVCacheManager or KVPoolAttrs, got {type(kv_arg)}"
            )
        self._data_layout = DataLayout.HND

    @staticmethod
    def _attrs_from_manager(manager: KVCacheManager) -> KVPoolAttrs:
        try:
            pools = manager.get_unique_primary_pool()
        except Exception as ex:
            raise ValueError(f"Failed to get pool(s): {ex}") from ex

        pool_list = list(pools) if isinstance(pools, (list, tuple)) else [pools]
        elem_bytes = get_size_in_bytes(1, manager.dtype)
        ptrs, block_sizes = [], []

        for p in pool_list:
            if hasattr(p, "data_ptr") and callable(p.data_ptr):
                try:
                    ptr = int(p.data_ptr())
                except Exception as ex:
                    raise ValueError(f"Failed to call data_ptr(): {ex}") from ex
            elif isinstance(p, int):
                ptr = int(p)
            else:
                raise ValueError(f"Pool object lacks 'data_ptr' and is not int: {p!r}")
            ptrs.append(ptr)

            try:
                if hasattr(p, "__getitem__") and hasattr(p[0], "numel"):
                    n = int(p[0].numel())
                elif hasattr(p, "numel") and callable(p.numel):
                    n = int(p.numel())
                else:
                    raise RuntimeError("Cannot determine element count")
            except Exception as ex:
                raise ValueError(f"Failed to get block size from {p!r}: {ex}") from ex

            block_sizes.append(n * elem_bytes)

        return KVPoolAttrs(pool_ptrs=ptrs, block_bytes=block_sizes)

    def extract(self, region_ids: List[int]) -> SpecRegion:
        """
        Given a list of region_ids, returns a single SpecRegion
        whose memory is a MemRegionGroup containing all blocks described
        by region_ids.
        """
        assert len(self._kv_pool_attrs.pool_ptrs) == 1
        pool_idx = 0
        attrs = self._kv_pool_attrs
        ptrs = [
            attrs.pool_ptrs[pool_idx] + block_id * attrs.block_bytes[0] for block_id in region_ids
        ]
        memory = MemRegionGroup(ptrs=ptrs, bytes_per_region=attrs.block_bytes[0])
        return SpecRegion(memory=memory)
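For instance, with a synthetic pool (no GPU needed; the base pointer is an illustrative value):

# Synthetic pool: base pointer 0x4000, 2 KiB per block.
attrs = KVPoolAttrs(pool_ptrs=[0x4000], block_bytes=[2048])
extractor = KVRegionExtractorV1(attrs)
region = extractor.extract([0, 3, 5])
assert list(region.memory.ptrs) == [0x4000, 0x4000 + 3 * 2048, 0x4000 + 5 * 2048]
assert region.memory.bytes_per_region == 2048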
tensorrt_llm/_torch/disaggregation/resource/kv_extractor_v2.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from collections import defaultdict

import numpy as np

from tensorrt_llm._torch.disaggregation.native.region.page import (
    BUFFER_ENTRY_DTYPE,
    KVCachePageTable,
    PoolDescriptor,
)
from tensorrt_llm.runtime.kv_cache_manager_v2 import CacheTier, KVCacheManager


def build_page_table(manager: KVCacheManager) -> KVCachePageTable:
    storage = manager._storage
    config = manager._init_config

    # Locate the GPU memory tier; fall back to level 0 if none is configured.
    gpu_level = 0
    for level_idx, cache_tier_config in enumerate(config.cache_tiers):
        if cache_tier_config.tier == CacheTier.GPU_MEM:
            gpu_level = level_idx
            break

    # Group buffer attributes by the (pool group, pool) that holds them.
    buffer_by_pool = defaultdict(list)

    # Cache life-cycle-id -> pool-group-index lookups.
    lc_to_pg_cache = {}

    for buffer_id, attr in storage._buffer_attr.items():
        layer_id, role = buffer_id

        lc_id = attr.life_cycle_id
        if lc_id not in lc_to_pg_cache:
            lc_to_pg_cache[lc_id] = storage.get_pool_group_index(lc_id)
        pg_idx = lc_to_pg_cache[lc_id]

        pool_idx = attr.pool_index
        pool_key = (pg_idx, pool_idx)

        buffer_by_pool[pool_key].append((layer_id, role, attr.offset, attr.size))

    pools = []
    num_pool_groups = storage.num_pool_groups
    pool_group_storage = storage._levels[gpu_level].storage._pool_groups

    for pg_idx in range(num_pool_groups):
        pool_group = pool_group_storage[pg_idx]
        num_pools = pool_group.num_pools
        pg_pools = []

        for pool_idx in range(num_pools):
            pool = pool_group._pools[pool_idx]

            base_address = int(pool.slot_address(0))
            slot_bytes = int(pool.slot_size)
            num_slots = int(pool.num_slots)

            pool_key = (pg_idx, pool_idx)
            buffers_info = buffer_by_pool.get(pool_key, [])

            if buffers_info:
                buffer_entries = np.array(buffers_info, dtype=BUFFER_ENTRY_DTYPE)
            else:
                buffer_entries = np.array([], dtype=BUFFER_ENTRY_DTYPE)

            pool_desc = PoolDescriptor(
                base_address=base_address,
                slot_bytes=slot_bytes,
                num_slots=num_slots,
                buffer_entries=buffer_entries,
            )
            pg_pools.append(pool_desc)
        pools.append(pg_pools)

    return KVCachePageTable(
        tokens_per_block=config.tokens_per_block,
        num_layers=len(config.layers),
        pools=pools,
    )
@@ -38,6 +38,10 @@ l0_a10:
- unittest/disaggregated/test_messenger.py
- unittest/disaggregated/test_disagg_cluster_manager_worker.py
- unittest/disaggregated/test_cluster_storage.py
- unittest/disaggregated/test_extractor.py
- unittest/disaggregated/test_extractor_v2.py
- unittest/disaggregated/test_peer.py
- unittest/disaggregated/region/test_block.py
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
@@ -39,6 +39,10 @@ l0_h100:
- unittest/disaggregated/test_remoteDictionary.py
- unittest/disaggregated/test_agent_multi_backends.py
- unittest/disaggregated/test_messenger.py
- unittest/disaggregated/test_extractor.py
- unittest/disaggregated/test_extractor_v2.py
- unittest/disaggregated/test_peer.py
- unittest/disaggregated/region/test_block.py
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_without_reuse
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa_reuse
tests/unittest/disaggregated/region/test_block.py (new file, 146 lines)
@@ -0,0 +1,146 @@
from tensorrt_llm._torch.disaggregation.base.region import (
    MemRegionGroup,
    SpecRegion,
    SpecRegionPair,
)
from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo
from tensorrt_llm._torch.disaggregation.native.region.block import (
    HeadMatchMapper,
    HeadMismatchMapper,
    IdentityMapper,
)


def make_rankinfo(
    kv_heads_per_rank=2,
    tokens_per_block=4,
    dims_per_head=2,
    element_bytes=1,
    tp_size=2,
    tp_rank=0,
    dp_size=1,
    dp_rank=0,
    pp_size=1,
    pp_rank=0,
    cp_size=1,
    cp_rank=0,
    is_mla=False,
):
    return RankInfo(
        instance_name="rank",
        instance_rank=0,
        tp_size=tp_size,
        tp_rank=tp_rank,
        dp_size=dp_size,
        dp_rank=dp_rank,
        pp_size=pp_size,
        pp_rank=pp_rank,
        cp_size=cp_size,
        cp_rank=cp_rank,
        device_id=0,
        kv_heads_per_rank=kv_heads_per_rank,
        tokens_per_block=tokens_per_block,
        dims_per_head=dims_per_head,
        element_bytes=element_bytes,
        enable_attention_dp=False,
        is_mla=is_mla,
        layer_num_per_pp=[1],
        kv_ptrs=[],
        aux_ptrs=[],
        server_endpoint="",
        self_endpoint="",
        transfer_engine_info=b"",
        aux_meta=None,
    )


def test_mem_region_group():
    ptrs = [11, 22, 33]
    bytes_per_region = 16
    region = MemRegionGroup(ptrs=ptrs, bytes_per_region=bytes_per_region)
    assert list(region.ptrs) == ptrs
    assert region.bytes_per_region == bytes_per_region


def test_spec_region_and_spec_region_pair():
    group_src = MemRegionGroup(ptrs=[101, 202], bytes_per_region=8)
    group_dst = MemRegionGroup(ptrs=[303, 404], bytes_per_region=8)
    spec_src = SpecRegion(memory=group_src, spec="spec_src")
    spec_dst = SpecRegion(memory=group_dst, spec="spec_dst")
    assert isinstance(spec_src, SpecRegion)
    assert isinstance(spec_dst, SpecRegion)
    pair = SpecRegionPair(src=spec_src, dst=spec_dst)
    assert isinstance(pair, SpecRegionPair)
    assert pair.src.memory.ptrs == [101, 202]
    assert pair.dst.memory.ptrs == [303, 404]
    assert pair.src.spec == "spec_src"
    assert pair.dst.spec == "spec_dst"


def test_identity_mapper():
    src_group = MemRegionGroup(ptrs=[100, 200], bytes_per_region=32)
    dst_group = MemRegionGroup(ptrs=[300, 400], bytes_per_region=32)
    src_spec = SpecRegion(memory=src_group, spec="a")
    dst_spec = SpecRegion(memory=dst_group, spec="b")
    mapper = IdentityMapper()
    result = mapper.map(src_spec, dst_spec)
    assert isinstance(result, SpecRegionPair)
    assert list(result.src.memory.ptrs) == [100, 200]
    assert list(result.dst.memory.ptrs) == [300, 400]
    assert result.src.memory.bytes_per_region == 32
    assert result.dst.memory.bytes_per_region == 32


def test_head_match_mapper():
    self_ri = make_rankinfo(kv_heads_per_rank=2)
    peer_ri = make_rankinfo(kv_heads_per_rank=2)
    transfer_layers = 2
    src_layer_off = 1
    dst_layer_off = 1
    src_group = MemRegionGroup(ptrs=[10, 20], bytes_per_region=1)
    dst_group = MemRegionGroup(ptrs=[30, 40], bytes_per_region=1)
    src_spec = SpecRegion(memory=src_group, spec="srcspec")
    dst_spec = SpecRegion(memory=dst_group, spec="dstspec")
    mapper = HeadMatchMapper(transfer_layers, src_layer_off, dst_layer_off, self_ri, peer_ri)
    result = mapper.map(src_spec, dst_spec)
    expected_frag_size = (
        transfer_layers
        * mapper._kv_factor
        * self_ri.kv_heads_per_rank
        * self_ri.tokens_per_block
        * self_ri.dims_per_head
        * self_ri.element_bytes
    )
    assert list(result.src.memory.ptrs) == [10 + mapper._src_block_off, 20 + mapper._src_block_off]
    assert list(result.dst.memory.ptrs) == [30 + mapper._dst_block_off, 40 + mapper._dst_block_off]
    assert result.src.memory.bytes_per_region == expected_frag_size
    assert result.dst.memory.bytes_per_region == expected_frag_size


def test_head_mismatch_mapper():
    self_ri = make_rankinfo(kv_heads_per_rank=2, tp_size=2, tp_rank=1)
    peer_ri = make_rankinfo(kv_heads_per_rank=4, tp_size=4, tp_rank=2)
    transfer_layers = 1
    src_layer_off = 0
    peer_layer_off = 1
    src_group = MemRegionGroup(ptrs=[111], bytes_per_region=32)
    dst_group = MemRegionGroup(ptrs=[222], bytes_per_region=32)
    src_spec = SpecRegion(memory=src_group, spec="srcspec")
    dst_spec = SpecRegion(memory=dst_group, spec="dstspec")
    mapper = HeadMismatchMapper(transfer_layers, src_layer_off, peer_layer_off, self_ri, peer_ri)
    result = mapper.map(src_spec, dst_spec)
    expected_frag_count = self_ri.kv_factor * transfer_layers
    assert isinstance(result, SpecRegionPair)
    assert len(result.src.memory.ptrs) == expected_frag_count
    assert len(result.dst.memory.ptrs) == expected_frag_count
    assert all(isinstance(x, int) for x in result.src.memory.ptrs)
    assert all(isinstance(x, int) for x in result.dst.memory.ptrs)
    assert result.src.memory.bytes_per_region == mapper._bytes_cont_heads
    assert result.dst.memory.bytes_per_region == mapper._bytes_cont_heads


def test_rankinfo_kv_factor():
    ri1 = make_rankinfo(is_mla=False)
    ri2 = make_rankinfo(is_mla=True)
    assert ri1.kv_factor == 2
    assert ri2.kv_factor == 1
tests/unittest/disaggregated/test_extractor.py (new file, 101 lines)
@@ -0,0 +1,101 @@
import pytest

from tensorrt_llm._torch.disaggregation.base.region import MemRegionGroup, SpecRegion
from tensorrt_llm._torch.disaggregation.resource.kv_extractor import KVRegionExtractorV1
from tensorrt_llm._torch.pyexecutor.resource_manager import (
    CacheTypeCpp,
    DataType,
    KvCacheConfig,
    KVCacheManager,
    Mapping,
)


class DummyRankInfo:
    instance_name = "dummy"
    instance_rank = 0
    tp_size = 1
    tp_rank = 0
    pp_size = 1
    pp_rank = 0
    dp_size = 1
    dp_rank = 0
    cp_size = 1
    cp_rank = 0
    device_id = 0
    kv_heads_per_rank = 8
    tokens_per_block = 32
    dims_per_head = 16
    element_bytes = 2
    enable_attention_dp = False
    is_mla = False
    layer_num_per_pp = [1]

    @property
    def kv_factor(self) -> int:
        return 1 if self.is_mla else 2


@pytest.mark.cuda
def test_extract():
    num_layers = 1
    num_kv_heads = 8
    head_dim = 16
    tokens_per_block = 32
    max_seq_len = 128
    max_batch_size = 2
    dtype = DataType.HALF
    mapping = Mapping(world_size=1, rank=0, tp_size=1, pp_size=1, gpus_per_node=1)
    kv_cache_config = KvCacheConfig(
        max_tokens=512,
        free_gpu_memory_fraction=0.1,
        max_attention_window=None,
        enable_block_reuse=False,
        event_buffer_max_size=0,
        onboard_blocks=0,
        host_cache_size=0,
        enable_partial_reuse=False,
        copy_on_partial_reuse=False,
        sink_token_length=0,
        max_util_for_resume=1,
    )
    kv_cache_type = CacheTypeCpp.SELF

    manager = KVCacheManager(
        kv_cache_config=kv_cache_config,
        kv_cache_type=kv_cache_type,
        num_layers=num_layers,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        mapping=mapping,
        dtype=dtype,
    )

    extractor = KVRegionExtractorV1(manager)
    region_ids = [0, 1]
    spec_region = extractor.extract(region_ids)

    assert isinstance(spec_region, SpecRegion)
    memory = spec_region.memory
    assert isinstance(memory, MemRegionGroup)
    assert len(memory.ptrs) == len(region_ids)
    assert memory.bytes_per_region > 0

    pool_ptrs = manager.get_unique_primary_pool()
    if hasattr(pool_ptrs, "__getitem__"):
        if hasattr(pool_ptrs[0], "data_ptr"):
            pool_base_ptr = int(pool_ptrs[0].data_ptr())
        else:
            pool_base_ptr = int(pool_ptrs[0])
    else:
        pool_base_ptr = (
            int(pool_ptrs.data_ptr()) if hasattr(pool_ptrs, "data_ptr") else int(pool_ptrs)
        )
    expected_block_bytes = memory.bytes_per_region
    expected_ptrs = [pool_base_ptr + block_id * expected_block_bytes for block_id in region_ids]
    assert list(memory.ptrs) == expected_ptrs

    manager.shutdown()
tests/unittest/disaggregated/test_extractor_v2.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import pytest

from tensorrt_llm._torch.disaggregation.resource.kv_extractor_v2 import build_page_table
from tensorrt_llm.runtime.kv_cache_manager_v2 import (
    AttentionLayerConfig,
    BufferConfig,
    GpuCacheTierConfig,
    KVCacheManager,
    KVCacheManagerConfig,
)


@pytest.fixture
def simple_manager():
    layers = [
        AttentionLayerConfig(
            layer_id=0,
            sliding_window_size=None,
            num_sink_tokens=0,
            buffers=[
                BufferConfig(role=0, size=8192),
                BufferConfig(role=1, size=8192),
            ],
        ),
        AttentionLayerConfig(
            layer_id=1,
            sliding_window_size=None,
            num_sink_tokens=0,
            buffers=[
                BufferConfig(role=0, size=8192),
                BufferConfig(role=1, size=8192),
            ],
        ),
    ]

    cache_tiers = [
        GpuCacheTierConfig(
            quota=100 * 1024 * 1024,  # 100 MiB
        ),
    ]

    config = KVCacheManagerConfig(
        tokens_per_block=64,
        layers=layers,
        vocab_size=50257,
        cache_tiers=cache_tiers,
    )

    return KVCacheManager(config)


def test_build_page_table(simple_manager):
    page_table = build_page_table(simple_manager)

    # Check basic properties
    assert page_table.tokens_per_block == 64
    assert page_table.num_layers == 2
    assert page_table.num_pool_groups >= 1
    assert page_table.total_pools > 0

    # Check pools are created
    assert len(page_table.pools) > 0
    assert all(len(pg_pools) > 0 for pg_pools in page_table.pools)

    # Check the first pool has valid properties
    pool = page_table.pools[0][0]
    assert pool.base_address > 0
    assert pool.slot_bytes > 0
    assert pool.num_slots > 0
    assert pool.pool_bytes == pool.slot_bytes * pool.num_slots

    # Check buffer entries exist
    assert len(pool.buffer_entries) > 0

    print(f"\n  Page table: {page_table}")
    print(f"  Total pools: {page_table.total_pools}")
    print(f"  Pools: {page_table.pools}")
    print(f"  Total size: {page_table.total_pool_bytes / (1024**2):.2f} MB")
tests/unittest/disaggregated/test_peer.py (new file, 297 lines)
@@ -0,0 +1,297 @@
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm._torch.disaggregation.native.peer import PeerOverlap, PeerRegistrar, RankInfo
|
||||
from tensorrt_llm._torch.disaggregation.native.region.block import (
|
||||
HeadMatchMapper,
|
||||
HeadMismatchMapper,
|
||||
IdentityMapper,
|
||||
)
|
||||
from tensorrt_llm._torch.disaggregation.resource.kv_extractor import (
|
||||
KVPoolAttrs,
|
||||
KVRegionExtractorV1,
|
||||
)
|
||||
|
||||
|
||||
def make_rankinfo(
    instance_name="self",
    instance_rank=0,
    tp_size=2,
    tp_rank=0,
    pp_size=1,
    pp_rank=0,
    dp_size=1,
    dp_rank=0,
    cp_size=1,
    cp_rank=0,
    kv_heads_per_rank=2,
    tokens_per_block=16,
    dims_per_head=8,
    element_bytes=2,
    is_mla=False,
    enable_attention_dp=False,
    layer_num_per_pp=None,
):
    if layer_num_per_pp is None:
        layer_num_per_pp = [2] * pp_size
    return RankInfo(
        instance_name=instance_name,
        instance_rank=instance_rank,
        tp_size=tp_size,
        tp_rank=tp_rank,
        pp_size=pp_size,
        pp_rank=pp_rank,
        dp_size=dp_size,
        dp_rank=dp_rank,
        cp_size=cp_size,
        cp_rank=cp_rank,
        device_id=0,
        kv_heads_per_rank=kv_heads_per_rank,
        tokens_per_block=tokens_per_block,
        dims_per_head=dims_per_head,
        element_bytes=element_bytes,
        enable_attention_dp=enable_attention_dp,
        is_mla=is_mla,
        layer_num_per_pp=layer_num_per_pp,
        kv_ptrs=[],
        aux_ptrs=[],
        server_endpoint="",
        self_endpoint="",
        transfer_engine_info=b"",
        aux_meta=None,
    )

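# pool_ptrs=[1234] / block_bytes=[1024] are dummy values; the registrar logic
# under test presumably treats them as opaque metadata and never dereferences
# them.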
def _make_peer_registrar_and_peer_ri(self_ri, peer_ri):
    pool_attrs = KVPoolAttrs(pool_ptrs=[1234], block_bytes=[1024])
    real_kv_extractor = KVRegionExtractorV1(pool_attrs)
    reg = PeerRegistrar(self_ri, real_kv_extractor)
    return reg, peer_ri

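# The peer splits the same two layers across pp_size=2 (one layer per stage),
# so self's single stage overlaps both peer pp ranks.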
def test_basic_overlap():
    self_ri = make_rankinfo(
        "self",
        pp_size=1,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[2],
    )
    peer_ri = make_rankinfo(
        "peer",
        pp_size=2,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[1, 1],
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, peer_dp_rank=0)
    assert overlap.overlap_pp_size == 2
    assert overlap.target_peer_pp_layer_num == [1, 1]
    assert overlap.overlap_tp_size == 1
    assert overlap.ranks == [0, 1]

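# Self advertises zero layers (layer_num_per_pp=[0]), so no pp overlap with
# the peer exists and every overlap field comes back empty.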
def test_no_overlap():
    self_ri = make_rankinfo(
        "self",
        pp_size=1,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[0],
    )
    peer_ri = make_rankinfo(
        "peer",
        pp_size=1,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[2],
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, 0)
    assert overlap.overlap_pp_size == 0
    assert overlap.target_peer_pp_layer_num == []
    assert overlap.ranks == []

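# Self is the second of two pp stages (layer split [1, 2]); the peer holds
# all three layers in one stage, so self's slice falls inside peer stage 0.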
def test_pp_ratio_peer_smaller():
    self_ri = make_rankinfo(
        "self",
        pp_size=2,
        pp_rank=1,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[1, 2],
    )
    peer_ri = make_rankinfo(
        "peer",
        pp_size=1,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        cp_size=1,
        cp_rank=0,
        layer_num_per_pp=[3],
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, 0)
    assert overlap.overlap_pp_size > 0
    assert sum(overlap.target_peer_pp_layer_num) > 0
    assert all(r >= 0 for r in overlap.ranks)

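# tp 2 -> 4: each self tp rank can cover at most two of the peer's four tp
# ranks, hence the overlap width is bounded by 2.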
def test_tp_overlap():
    self_ri = make_rankinfo(
        "self", tp_size=2, tp_rank=1, pp_size=1, pp_rank=0, cp_size=1, cp_rank=0
    )
    peer_ri = make_rankinfo(
        "peer", tp_size=4, tp_rank=0, pp_size=1, pp_rank=0, cp_size=1, cp_rank=0
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, 0)
    assert overlap.overlap_tp_size in [1, 2]
    assert all(isinstance(r, int) for r in overlap.ranks)

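# Same shape as the tp case above, but along the context-parallel axis
# (cp 2 -> 4).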
def test_cp_overlap():
    self_ri = make_rankinfo(
        "self", cp_size=2, cp_rank=1, pp_size=1, pp_rank=0, tp_size=1, tp_rank=0
    )
    peer_ri = make_rankinfo(
        "peer", cp_size=4, cp_rank=0, pp_size=1, pp_rank=0, tp_size=1, tp_rank=0
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, 0)
    assert overlap.overlap_cp_size in [1, 2]
    assert all(isinstance(r, int) for r in overlap.ranks)

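# Self sits at the last coordinate of a 2x2x2 (pp, tp, cp) grid while the
# peer grid is 4x4x4: every axis overlaps two peer coordinates, so
# 2 * 2 * 2 = 8 peer ranks are expected.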
def test_multiple_overlap():
    self_ri = make_rankinfo(
        "self",
        pp_size=2,
        pp_rank=1,
        tp_size=2,
        tp_rank=1,
        cp_size=2,
        cp_rank=1,
        layer_num_per_pp=[1, 2],
    )
    peer_ri = make_rankinfo(
        "peer",
        pp_size=4,
        pp_rank=2,
        tp_size=4,
        tp_rank=3,
        cp_size=4,
        cp_rank=0,
        layer_num_per_pp=[1, 1, 1, 1],
    )
    reg, peer_ri = _make_peer_registrar_and_peer_ri(self_ri, peer_ri)
    overlap = reg.get_peer_overlap(peer_ri, peer_dp_rank=0)
    expected_overlap = PeerOverlap(
        overlap_pp_size=2,
        overlap_tp_size=2,
        overlap_cp_size=2,
        duplicate_head_factor=1,
        peer_duplicate_head_factor=2,
        target_peer_pp_layer_num=[1, 1],
        ranks=[26, 27, 30, 31, 42, 43, 46, 47],
    )
    assert overlap == expected_overlap

def _make_peer_registrar(self_rankinfo):
    pool_attrs = KVPoolAttrs(pool_ptrs=[1234], block_bytes=[1024])
    real_kv_extractor = KVRegionExtractorV1(pool_attrs)
    reg = PeerRegistrar(self_rankinfo, real_kv_extractor)
    return reg

def test_peer_registrar_register_and_get():
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(instance_name="peer", instance_rank=1, layer_num_per_pp=[2])
    reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri)
    assert reg.get_peer_rank_info("peer", 1) == peer_ri

def test_peer_registrar_unregister():
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(instance_name="peer", instance_rank=1, layer_num_per_pp=[2])
    reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri)
    reg.unregister(peer_ri.instance_name, peer_ri.instance_rank)
    with pytest.raises(KeyError):
        reg.get_peer_rank_info("peer", 1)

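# The registrar was built with is_mla=False, so an MLA peer must be rejected
# at registration time.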
def test_peer_registrar_incompatible_peer_raises():
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(instance_name="peer", instance_rank=3, is_mla=True)
    with pytest.raises(ValueError):
        reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri)

def test_peer_registrar_self_rank_info_property():
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    assert reg.self_rank_info == self_rankinfo

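# A peer with the same topology and head layout as self (make_rankinfo
# defaults) maps blocks one-to-one, so an IdentityMapper is expected.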
def test_peer_registrar_get_kv_map_identity():
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(instance_name="peer", instance_rank=1, layer_num_per_pp=[2])
    reg.register(peer_ri.instance_name, peer_ri.instance_rank, peer_ri)
    mapper = reg.get_kv_map(peer_ri)
    assert isinstance(mapper, IdentityMapper)

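# Same kv_heads_per_rank as self but a different pp split ([1, 1] vs [2]):
# heads line up while layers must be re-targeted, hence a HeadMatchMapper.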
def test_peer_registrar_get_kv_map_head_match():
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(
        instance_name="peer",
        instance_rank=2,
        pp_size=2,
        pp_rank=1,
        layer_num_per_pp=[1, 1],
        tokens_per_block=16,
        dims_per_head=8,
    )
    mapper = reg.get_kv_map(peer_ri)
    assert isinstance(mapper, HeadMatchMapper)

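# The peer packs 4 KV heads per rank vs self's default 2, so per-block head
# strides differ and a HeadMismatchMapper is required.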
def test_peer_registrar_get_kv_map_head_mismatch():
    self_rankinfo = make_rankinfo(instance_name="local")
    reg = _make_peer_registrar(self_rankinfo)
    peer_ri = make_rankinfo(
        instance_name="peer",
        instance_rank=3,
        pp_size=1,
        pp_rank=0,
        tp_size=1,
        tp_rank=0,
        kv_heads_per_rank=4,
        tokens_per_block=16,
        dims_per_head=8,
        layer_num_per_pp=[2],
    )
    mapper = reg.get_kv_map(peer_ri)
    assert isinstance(mapper, HeadMismatchMapper)