mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Merge 91bc17a32e into 38296a472b
This commit is contained in:
commit
883f9b6694
0
tensorrt_llm/_torch/disaggregation/__init__.py
Normal file
0
tensorrt_llm/_torch/disaggregation/__init__.py
Normal file
0
tensorrt_llm/_torch/disaggregation/base/__init__.py
Normal file
0
tensorrt_llm/_torch/disaggregation/base/__init__.py
Normal file
128
tensorrt_llm/_torch/disaggregation/base/agent.py
Normal file
128
tensorrt_llm/_torch/disaggregation/base/agent.py
Normal file
@ -0,0 +1,128 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from tensorrt_llm import logger
|
||||
|
||||
|
||||
# We deliberately use a non-enum data structure here. This choice ensures that
# members are directly equivalent to the plain strings.
class TransferOp:
    """Direction of a transfer operation, as plain-string constants.

    May be shadowed by the C++ binding's TransferOp at the bottom of this
    module when the binding is available.
    """

    READ = "READ"
    WRITE = "WRITE"
|
||||
|
||||
class MemoryType:
    """Kind of memory a descriptor refers to, as plain-string constants.

    May be shadowed by the C++ binding's MemoryType at module load time.
    """

    DRAM = "DRAM"  # host memory
    VRAM = "VRAM"  # GPU device memory
    BLK = "BLK"  # block device -- naming follows the NIXL memory types
    OBJ = "OBJ"  # object storage
    FILE = "FILE"  # file-backed storage
|
||||
|
||||
@dataclass
class MemoryDesc:
    """A single contiguous memory region."""

    ptr: int  # base address of the region
    size: int  # size in bytes
    device_id: int  # index of the owning device
|
||||
|
||||
@dataclass
class MemoryDescs:
    """A batch of memory regions that share one memory type."""

    type: str  # one of the MemoryType constants
    # Each entry is either a (ptr, size, device_id) tuple or a MemoryDesc.
    descs: List[Union[Tuple[int, int, int], MemoryDesc]]
|
||||
|
||||
@dataclass
class TransferRequest:
    """Describes one transfer between a local agent and a remote agent."""

    op: TransferOp  # READ or WRITE
    src_descs: MemoryDescs  # source regions
    dst_descs: MemoryDescs  # destination regions
    remote_name: str  # name of the remote agent involved
    # Message delivered alongside the transfer; may be None (implementations
    # map None to the empty string).
    sync_message: str
|
||||
|
||||
class TransferStatus(ABC):
    """Handle for tracking an in-flight transfer."""

    @abstractmethod
    def is_completed(self) -> bool:
        """Return True once the transfer has finished (non-blocking)."""
        ...

    @abstractmethod
    def wait(self, timeout: float | None = None) -> None:
        """Block until the transfer completes or *timeout* seconds elapse."""
        ...
|
||||
|
||||
class BaseTransferAgent(ABC):
    """Abstract interface for point-to-point memory transfer agents."""

    @abstractmethod
    def register_memory(self, descs: MemoryDescs) -> None:
        """Register local memory regions so they can take part in transfers."""
        ...

    @abstractmethod
    def deregister_memory(self, descs: MemoryDescs) -> None:
        """Undo a previous register_memory call for the given regions."""
        ...

    @abstractmethod
    def load_remote_agent(self, name: str, agent_desc: str) -> None:
        """Make a remote agent known locally from its serialized descriptor."""
        ...

    @abstractmethod
    def get_local_agent_desc(self) -> str:
        """Return this agent's serialized descriptor for exchange with peers."""
        ...

    @abstractmethod
    def invalidate_remote_agent(self, name: str) -> None:
        """Drop state associated with a previously loaded remote agent."""
        ...

    @abstractmethod
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        """Start the described transfer and return a pollable status handle."""
        ...

    @abstractmethod
    def notify_sync_message(self, name: str, sync_message: str) -> None:
        """Send a synchronization message to the named remote agent."""
        ...

    @abstractmethod
    def check_remote_descs(self, name: str, memory_descs: List[int]) -> bool:
        """Return True if the named remote agent has the given descriptors."""
        ...
|
||||
|
||||
# RegMemoryDescs is Python-only (used for registration with name field)
@dataclass
class RegMemoryDescs:
    """Registration-time variant of MemoryDescs carrying a per-region name."""

    type: str  # one of the MemoryType constants
    descs: List[Tuple[int, int, int, str]]  # (ptr, size, device_id, name)
|
||||
|
||||
def _force_py_nixl_kv_transfer() -> bool:
|
||||
res = os.getenv("TRTLLM_USE_PY_NIXL_KVCACHE", "0") == "1"
|
||||
if res:
|
||||
logger.info("Forcing use of pure Python NIXL KV Transfer Agent implementation.")
|
||||
return res
|
||||
|
||||
|
||||
def _try_load_cpp_binding():
    """Attempt to import the C++ transfer-agent binding.

    Returns the binding module when it is importable and exposes the full
    expected API surface; returns None otherwise.
    """
    expected = (
        "MemoryType",
        "TransferOp",
        "MemoryDesc",
        "MemoryDescs",
        "TransferRequest",
        "TransferStatus",
        "BaseTransferAgent",
    )
    try:
        import tensorrt_llm.tensorrt_llm_transfer_agent_binding as binding
    except ImportError:
        logger.info("tensorrt_llm_transfer_agent_binding module not found.")
        return None
    for attr in expected:
        if not hasattr(binding, attr):
            return None
    return binding
|
||||
|
||||
_cpp_binding = _try_load_cpp_binding()

if _cpp_binding and not _force_py_nixl_kv_transfer():
    # Shadow the pure-Python definitions above with the C++ binding's types.
    MemoryType = _cpp_binding.MemoryType
    TransferOp = _cpp_binding.TransferOp
    MemoryDesc = _cpp_binding.MemoryDesc
    MemoryDescs = _cpp_binding.MemoryDescs
    TransferRequest = _cpp_binding.TransferRequest
    TransferStatus = _cpp_binding.TransferStatus
    BaseTransferAgent = _cpp_binding.BaseTransferAgent
    logger.info("Using Pybind transfer agent binding for Transfer Agent implementation.")
elif _cpp_binding:
    # BUG FIX: previously this case fell into the "failed to import" warning
    # even though the binding imported fine and the user merely opted into the
    # pure-Python path via TRTLLM_USE_PY_NIXL_KVCACHE=1.
    logger.info(
        "Pybind transfer agent binding available but pure Python implementation "
        "was forced via TRTLLM_USE_PY_NIXL_KVCACHE."
    )
else:
    logger.warning(
        "Failed to import Pybind transfer agent binding, using pure Python implementation."
    )
144
tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
Normal file
144
tensorrt_llm/_torch/disaggregation/base/kv_transfer.py
Normal file
@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from tensorrt_llm import DisaggregatedParams
|
||||
|
||||
|
||||
@dataclass
class KVSlice:
    """Supports transmitting only part of the request cache, e.g, chunks or layers."""

    start_token_idx: Optional[int] = None  # first token covered; None = unbounded
    end_token_idx: Optional[int] = None  # NOTE(review): inclusivity not specified here
    start_layer: Optional[int] = None  # first layer covered; None = all layers
    end_layer: Optional[int] = None  # last/limit layer; see note on end_token_idx
    blocks: List[int] = field(default_factory=list)  # KV-cache block ids in this slice
    is_last_slice: bool = False  # True for the final slice of a request
||||
|
||||
class SessionStatus(Enum):
    """Status of a transfer session.

    NOTE(review): lifecycle ordering below is inferred from the member names;
    confirm against the concrete session implementations.
    """

    INIT = "INIT"  # session created, not yet ready to move data
    READY = "READY"  # ready for transfers to start
    TRANSFERRING = "TRANSFERRING"  # KV data in flight
    TRANSFERRED = "TRANSFERRED"  # KV data delivered
    AUX_TRANSFERRED = "AUX_TRANSFERRED"  # auxiliary data delivered
    COMPLETED = "COMPLETED"  # all work finished successfully
    CANCELED = "CANCELED"  # aborted
    ERROR = "ERROR"  # failed
||||
|
||||
# Identifier of an individual transfer task within a session.
TaskIdType = int


@dataclass
class SessionState:
    """State of a transfer session."""

    status: SessionStatus  # current lifecycle stage
    finished_tasks: List[TaskIdType]  # ids of tasks that have completed
||||
|
||||
@dataclass
class SessionArgsBase:
    """Base arguments for transfer sessions."""

    # Disaggregated-serving parameters associated with the request.
    params: DisaggregatedParams
|
||||
|
||||
class SenderBase(ABC):
    """Base class for sending KV cache data.

    Marker interface for now; concrete transports define the actual API.
    """

    ...
|
||||
|
||||
class ReceiverBase(ABC):
    """Base class for receiving KV cache data.

    Marker interface for now; concrete transports define the actual API.
    """

    ...
|
||||
|
||||
class TxSessionBase(ABC):
    """Abstract transmit-side session for KV cache transfer."""

    def __init__(self, sender: SenderBase, args: SessionArgsBase):
        """
        Initializes the transmission session.
        :param sender: The sender instance responsible for sending data.
        :param args: The session arguments.
        """
        # NOTE(review): `sender` is not stored here; subclasses are expected
        # to keep their own reference if they need it.
        self._base_args = args

    @property
    @abstractmethod
    def state(self) -> SessionState:
        """
        Returns the current state of the session.
        """
        ...

    @abstractmethod
    def poll_task(self, id: TaskIdType) -> SessionStatus:
        """
        Polls the status of a specific task by its ID.
        :param id: The task ID to poll.
        """
        ...

    @abstractmethod
    def send(self, slice: KVSlice) -> TaskIdType:
        """
        Sends a slice of KV cache data and returns the task ID.
        :param slice: The KV slice to send.
        """
        ...

    @property
    @abstractmethod
    def exception(self) -> Optional[Exception]:
        """
        Returns any exception that occurred during the session.
        """
        ...
|
||||
|
||||
class RxSessionBase(ABC):
    """Abstract receive-side session for KV cache transfer."""

    def __init__(self, receiver: ReceiverBase, args: SessionArgsBase):
        """
        Initializes the reception session.
        :param receiver: The receiver instance responsible for receiving data.
        :param args: The session arguments.
        """
        # NOTE(review): `receiver` is not stored here; subclasses are expected
        # to keep their own reference if they need it.
        self._base_args = args

    @property
    @abstractmethod
    def state(self) -> SessionState:
        """
        Returns the current state of the session.
        """
        ...

    @abstractmethod
    def poll_task(self, id: TaskIdType) -> SessionStatus:
        """
        Polls the status of a specific task by its ID.
        :param id: The task ID to poll.
        """
        ...

    @abstractmethod
    def receive(self, slice: KVSlice) -> TaskIdType:
        """
        Receives a slice of KV cache data and returns the task ID.
        :param slice: The KV slice to receive.
        """
        ...

    @property
    @abstractmethod
    def exception(self) -> Optional[Exception]:
        """Returns any exception that occurred during the session."""
        ...
||||
228
tensorrt_llm/_torch/disaggregation/native/messenger.py
Normal file
228
tensorrt_llm/_torch/disaggregation/native/messenger.py
Normal file
@ -0,0 +1,228 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Callable, Optional
|
||||
|
||||
import zmq
|
||||
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip
|
||||
|
||||
|
||||
class MessengerInterface(ABC):
    """
    Abstract base class for messenger implementations.
    """

    @abstractmethod
    def start(self):
        """
        Start the messenger service.
        """
        ...

    @abstractmethod
    def send(self, messages: list[bytes], recipient: Optional[bytes] = None):
        """
        Send messages to a recipient.
        :param messages: List of byte messages to send.
        :param recipient: Optional recipient identifier.
        """
        ...

    @abstractmethod
    def send_encoded(self, *messages, encoding: str = "ascii"):
        """
        Send messages after encoding them.
        :param messages: Messages to send.
        :param encoding: Encoding format.
        """
        ...

    @abstractmethod
    def receive(self) -> list[bytes]:
        """
        Receive messages.
        :return: List of byte messages received.
        """
        ...

    @abstractmethod
    def start_listener(self, on_message: Callable[[list[bytes]], None]):
        """
        Start a listener thread to handle incoming messages.
        :param on_message: Callback function to process received messages.
        """
        ...

    @abstractmethod
    def stop(self):
        """
        Stop the messenger service.
        """
        ...

    @property
    @abstractmethod
    def endpoint(self) -> str:
        """
        Get the endpoint of the messenger.
        :return: Endpoint string.
        """
        ...
||||
|
||||
|
||||
def decode_message(
    message: list[bytes], encoding: str = "ascii", err_mode: str = "strict"
) -> tuple:
    """Decode a multipart message (list of byte frames) into a tuple of strings.

    :param message: Frames to decode; every element must be ``bytes``.
    :param encoding: Text encoding applied to each frame.
    :param err_mode: Error handling passed through to ``bytes.decode``.
    :raises ValueError: If *message* is not a list of bytes objects.
    """
    is_valid = isinstance(message, list) and all(
        isinstance(frame, bytes) for frame in message
    )
    if not is_valid:
        raise ValueError("Input must be a list of bytes")
    decoded = [frame.decode(encoding, errors=err_mode) for frame in message]
    return tuple(decoded)
|
||||
|
||||
|
||||
class ZMQMessenger(MessengerInterface):
    """ZMQ-backed messenger supporting ROUTER/DEALER/REQ/REP socket modes.

    ROUTER/REP instances bind and can run a background listener thread;
    DEALER/REQ instances connect to a given endpoint.
    """

    SOCKET_MODES = {
        "ROUTER": zmq.ROUTER,  # Handles multiple connections and routes messages by address.
        "DEALER": zmq.DEALER,  # Load balances outgoing messages and receives replies fairly.
        "REQ": zmq.REQ,  # Sends requests and waits for replies (synchronous).
        "REP": zmq.REP,  # Receives requests and sends replies (synchronous).
    }

    def __init__(self, mode: str, endpoint: Optional[str] = None):
        """
        :param mode: One of SOCKET_MODES; ROUTER/REP bind, DEALER/REQ connect.
        :param endpoint: ZMQ endpoint; defaults to a wildcard TCP port on the
            local IP (binding modes resolve the actual port after bind).
        """
        self._context = zmq.Context()
        self._mode = mode
        self._socket = self._context.socket(self.SOCKET_MODES[mode])
        self._endpoint: Optional[str] = None
        self._lock = Lock()
        self._closed = False
        self._stop_event = Event()
        self._listener_thread: Optional[Thread] = None
        self._initialize_control_sockets()

        if endpoint is None:
            endpoint = f"tcp://{get_local_ip()}:*"

        if mode in ["ROUTER", "REP"]:
            self._socket.bind(endpoint)
            # Resolve a wildcard port to the concrete bound endpoint.
            self._endpoint = self._socket.getsockopt_string(zmq.LAST_ENDPOINT)
        elif mode in ["DEALER", "REQ"]:
            self._socket.connect(endpoint)
            self._endpoint = endpoint

        logger.info(f"Initialized ZMQMessenger(mode={mode}, endpoint={self._endpoint})")

    def _initialize_control_sockets(self):
        # PAIR sockets over inproc, used by stop() to wake the listener thread.
        self._control_socket = self._context.socket(zmq.PAIR)
        self._internal_socket = self._context.socket(zmq.PAIR)
        inproc_endpoint = "inproc://stop_listener"
        self._control_socket.bind(inproc_endpoint)
        self._internal_socket.connect(inproc_endpoint)

    def start(self):
        # No deferred setup; the socket is ready after __init__.
        pass

    def send(self, messages: list[bytes], recipient: Optional[bytes] = None):
        """Send a multipart message, prefixing the recipient identity frame
        when given (required when sending through a ROUTER socket)."""
        if recipient:
            self._socket.send_multipart([recipient] + messages)
        else:
            self._socket.send_multipart(messages)

    def send_encoded(self, *messages, encoding: str = "ascii"):
        """Stringify and encode each argument, then send as one multipart message."""
        encoded_messages = [str(message).encode(encoding) for message in messages]
        self.send(encoded_messages)

    def receive(self) -> list[bytes]:
        """Blocking receive of one multipart message."""
        return self._socket.recv_multipart()

    def start_listener(
        self,
        on_message: Callable[[list[bytes]], None],
        on_error: Optional[Callable[[Exception], None]] = None,
    ):
        """Spawn a daemon thread that polls the socket and dispatches messages.

        :param on_message: Callback invoked with each received multipart
            message; returning False stops the listener.
        :param on_error: Optional callback invoked with any exception raised
            while polling or handling a message. Without it, a callback error
            stops the listener.
        :raises RuntimeError: If a listener is already running.
        """
        assert self._mode in ["ROUTER", "REP"], (
            "Listener can only be started in ROUTER or REP modes"
        )
        if self._listener_thread and self._listener_thread.is_alive():
            raise RuntimeError("Listener already running")

        def listener():
            poller = zmq.Poller()
            poller.register(self._socket, zmq.POLLIN)
            poller.register(self._control_socket, zmq.POLLIN)

            while not self._stop_event.is_set():
                try:
                    # BUG FIX: poll() is now inside the try so ZMQErrors raised
                    # while polling are routed to the handlers below instead of
                    # escaping the thread.
                    events = dict(poller.poll(timeout=100))
                    if self._control_socket in events:
                        # stop() sent a wake-up message over the inproc pair.
                        self._stop_event.set()
                    elif self._socket in events:
                        messages = self.receive()
                        try:
                            persist = on_message(messages)
                            if persist is False:
                                self._stop_event.set()
                        except Exception as e:
                            logger.error(f"Error in on_message callback: {e}")
                            if on_error:
                                on_error(e)
                            else:
                                self._stop_event.set()
                except zmq.ZMQError as e:
                    logger.error(f"ZMQ Error in listener: {e}")
                    if on_error:
                        on_error(e)
                    break
                except Exception as e:
                    logger.error(f"Error in listener: {e}")
                    if on_error:
                        on_error(e)
                    break

            self._stop_event.set()

        self._listener_thread = Thread(target=listener, daemon=True)
        self._listener_thread.start()

    def stop(self, timeout=5):
        """Stop the listener (if any), close all sockets, and terminate the
        ZMQ context. Idempotent; safe to call from __del__ and __exit__."""

        def _close_socket(socket: zmq.Socket):
            try:
                if not socket.closed:
                    socket.close()
            except Exception as e:
                logger.error(f"Error closing socket: {e}")

        with self._lock:
            if self._closed:
                return
            self._closed = True
            logger.debug("Stopping ZMQMessenger...")

            self._stop_event.set()
            # Wake the listener thread via the inproc control channel.
            self._internal_socket.send(b"STOP")
            if self._listener_thread:
                self._listener_thread.join(timeout)
                if self._listener_thread.is_alive():
                    logger.warning("Listener thread did not terminate within timeout")

            _close_socket(self._socket)
            _close_socket(self._internal_socket)
            _close_socket(self._control_socket)

            try:
                # BUG FIX: the original condition was inverted
                # (`if self._context.closed: term()`), so the context was
                # never actually terminated.
                if not self._context.closed:
                    self._context.term()
            except Exception as e:
                logger.error(f"Error terminating ZMQ context: {e}")

    @property
    def endpoint(self) -> str:
        """The bound/connected endpoint; set during __init__."""
        assert self._endpoint is not None
        return self._endpoint

    def __del__(self):
        self.stop()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
|
||||
34
tensorrt_llm/_torch/disaggregation/native/utils.py
Normal file
34
tensorrt_llm/_torch/disaggregation/native/utils.py
Normal file
@ -0,0 +1,34 @@
|
||||
def get_local_ip() -> str:
    """Best-effort detection of a non-loopback local IPv4 address.

    Tries, in order: (1) a UDP "connect" to a public address, which makes the
    OS pick the outbound interface without sending any packets; (2) interface
    enumeration via the optional netifaces package; (3) resolving our own
    hostname. Falls back to 127.0.0.1.
    """
    # BUG FIX: import hoisted out of the first try block -- the hostname
    # fallback below also uses `socket` and previously depended on the first
    # block's import having executed.
    import socket

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))
            ip = s.getsockname()[0]
        if not ip.startswith("127."):
            return ip
    except OSError:
        pass

    try:
        import netifaces

        # Skip loopback (127.*) and link-local (169.254.*) addresses.
        for iface in netifaces.interfaces():
            addrs = netifaces.ifaddresses(iface)
            if netifaces.AF_INET in addrs:
                for addr in addrs[netifaces.AF_INET]:
                    ip = addr.get("addr", "")
                    if not ip.startswith("127.") and not ip.startswith("169.254"):
                        return ip
    except Exception:
        pass

    try:
        ip = socket.gethostbyname(socket.gethostname())
        if not ip.startswith("127."):
            return ip
    except OSError:
        pass

    return "127.0.0.1"
|
||||
0
tensorrt_llm/_torch/disaggregation/nixl/__init__.py
Normal file
0
tensorrt_llm/_torch/disaggregation/nixl/__init__.py
Normal file
132
tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py
Normal file
132
tensorrt_llm/_torch/disaggregation/nixl/_agent_cpp.py
Normal file
@ -0,0 +1,132 @@
|
||||
from tensorrt_llm._utils import nvtx_range
|
||||
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import ( # noqa: E402
|
||||
AgentDesc,
|
||||
BaseAgentConfig,
|
||||
MemoryDescs,
|
||||
MemoryType,
|
||||
TransferState,
|
||||
)
|
||||
from tensorrt_llm.tensorrt_llm_transfer_agent_binding import (
|
||||
NixlTransferAgent as CppNixlTransferAgent,
|
||||
)
|
||||
|
||||
from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus
|
||||
|
||||
|
||||
class BindingsNixlTransferStatus(TransferStatus):
    """TransferStatus wrapper using C++ bindings with GIL release."""

    def __init__(self, cpp_status):
        # Underlying status object from the C++ transfer agent binding.
        self._cpp_status = cpp_status

    def is_completed(self) -> bool:
        """Check if transfer is completed (releases GIL)."""
        return self._cpp_status.is_completed()

    @nvtx_range("BindingsNixlTransferStatus.wait")
    def wait(self) -> bool:
        """Wait for transfer to complete (releases GIL).

        NOTE(review): diverges from the base ``TransferStatus.wait(timeout=None)
        -> None`` contract -- takes no timeout and returns True on SUCCESS.
        """
        return self._cpp_status.wait() == TransferState.SUCCESS
|
||||
|
||||
class BindingsNixlTransferAgent(BaseTransferAgent):
    """NixlTransferAgent using C++ bindings with GIL release support.

    This implementation uses the standalone nixl_bindings C++ module which releases
    the GIL during blocking operations like wait().

    The nixl_bindings module is independent from the main trtllm bindings,
    so trtllm can function normally even without NIXL.
    """

    def __init__(self, name: str, use_prog_thread: bool = True, num_workers: int = 1):
        """
        :param name: Unique name of this agent.
        :param use_prog_thread: Whether to enable the backend progress thread.
        :param num_workers: Number of worker threads for the backend.
        """
        config = BaseAgentConfig(
            name,
            use_prog_thread,
            multi_thread=False,
            use_listen_thread=False,
            num_workers=num_workers,
        )
        self._cpp_agent = CppNixlTransferAgent(config)
        self.name = name

    def register_memory(self, descs: RegMemoryDescs):
        """Register memory regions."""
        cpp_descs = self._convert_reg_memory_descs(descs)
        self._cpp_agent.register_memory(cpp_descs)

    def deregister_memory(self, descs: RegMemoryDescs):
        """Deregister memory regions."""
        cpp_descs = self._convert_reg_memory_descs(descs)
        self._cpp_agent.deregister_memory(cpp_descs)

    def load_remote_agent(self, name: str, agent_desc: bytes):
        """Load a remote agent by its descriptor (bytes)."""
        # AgentDesc expects std::string which can hold binary data
        desc_str = agent_desc if isinstance(agent_desc, bytes) else agent_desc.encode()
        cpp_desc = AgentDesc(desc_str)
        self._cpp_agent.load_remote_agent(name, cpp_desc)

    def load_remote_agent_by_connection(self, name: str, connection_info: str):
        """Load a remote agent by connection info."""
        self._cpp_agent.load_remote_agent_by_connection(name, connection_info)

    def get_local_agent_desc(self) -> bytes:
        """Get the local agent descriptor as bytes."""
        agent_desc = self._cpp_agent.get_local_agent_desc()
        return agent_desc.backend_agent_desc  # Returns bytes

    def get_local_connection_info(self) -> str:
        """Get the local connection info."""
        return self._cpp_agent.get_local_connection_info()

    def invalidate_remote_agent(self, name: str):
        """Invalidate a remote agent."""
        self._cpp_agent.invalidate_remote_agent(name)

    def check_remote_descs(self, name: str, memory_descs: MemoryDescs) -> bool:
        """Check if remote descriptors are available.

        memory_descs should be C++ MemoryDescs type.
        NOTE(review): the base class declares ``memory_descs: List[int]`` --
        the signatures should be reconciled.
        """
        return self._cpp_agent.check_remote_descs(name, memory_descs)

    def notify_sync_message(self, name: str, sync_message: str):
        """Send a sync message to a remote agent."""
        self._cpp_agent.notify_sync_message(name, sync_message)

    def get_notified_sync_messages(self):
        """Get notified sync messages."""
        return self._cpp_agent.get_notified_sync_messages()

    @nvtx_range("BindingsNixlTransferAgent.submit_transfer_requests")
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        """Submit transfer requests and return status.

        request should be a C++ TransferRequest (from tensorrt_llm_transfer_agent_binding).
        """
        cpp_status = self._cpp_agent.submit_transfer_requests(request)
        return BindingsNixlTransferStatus(cpp_status)

    def _convert_reg_memory_descs(self, descs: RegMemoryDescs) -> "MemoryDescs":
        """Convert Python RegMemoryDescs to C++ MemoryDescs.

        RegMemoryDescs.descs is List[Tuple[int, int, int, str]] = (ptr, size, device_id, name)
        Extract first 3 elements for C++ batch constructor.
        """
        mem_type = self._convert_memory_type(descs.type)
        # Extract (ptr, size, device_id) from 4-tuple, discard name
        tuples = [(d[0], d[1], d[2]) for d in descs.descs]
        return MemoryDescs(mem_type, tuples)

    def _convert_memory_type(self, py_type: str) -> "MemoryType":
        """Convert Python memory type string to C++ MemoryType."""
        type_map = {
            "DRAM": MemoryType.DRAM,
            "VRAM": MemoryType.VRAM,
            "GPU": MemoryType.VRAM,  # alias accepted for convenience
            "BLK": MemoryType.BLK,
            "OBJ": MemoryType.OBJ,
            "FILE": MemoryType.FILE,
        }
        # Unknown strings silently fall back to VRAM.
        return type_map.get(py_type.upper(), MemoryType.VRAM)
|
||||
99
tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py
Normal file
99
tensorrt_llm/_torch/disaggregation/nixl/_agent_py.py
Normal file
@ -0,0 +1,99 @@
|
||||
import time
|
||||
|
||||
from nixl import nixl_agent, nixl_agent_config, nixl_xfer_handle # noqa: E402
|
||||
|
||||
from tensorrt_llm._utils import nvtx_range
|
||||
|
||||
# Import base classes for type compatibility
|
||||
from ..base.agent import BaseTransferAgent, RegMemoryDescs, TransferRequest, TransferStatus
|
||||
|
||||
|
||||
class NixlTransferStatus(TransferStatus):
    """TransferStatus using Python nixl library."""

    def __init__(self, agent: nixl_agent, handle: nixl_xfer_handle):
        self.agent = agent  # owning nixl agent, used to query transfer state
        self.handle = handle  # transfer handle returned by the agent

    def is_completed(self) -> bool:
        """Return True once the transfer state is DONE (non-blocking)."""
        status = self.agent.check_xfer_state(self.handle)
        return status == "DONE"

    def wait(self, timeout: float | None = None) -> bool:
        """Poll until the transfer leaves the PROC state.

        Sleeps with exponential backoff (0.1ms up to 10ms) between polls so
        the GIL is released while waiting.

        :param timeout: Optional wall-clock limit in seconds. Added for
            consistency with ``TransferStatus.wait(timeout=None)``; the
            default None preserves the previous unbounded-wait behavior.
        :return: True on success, False if the transfer errored or timed out.
        """
        sleep_time = 0.0001  # 0.1ms
        max_sleep_time = 0.01  # 10ms
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            status = self.agent.check_xfer_state(self.handle)
            if status != "PROC":
                # DONE -> success; ERROR -> transfer failed.
                return status != "ERROR"
            if deadline is not None and time.monotonic() >= deadline:
                return False  # timed out while still in progress
            # sleep to release GIL
            time.sleep(sleep_time)
            sleep_time = min(sleep_time * 2, max_sleep_time)
|
||||
|
||||
|
||||
class NixlTransferAgent(BaseTransferAgent):
    """NixlTransferAgent using Python nixl library."""

    def __init__(self, name: str, use_prog_thread: bool = True, num_workers: int = 1):
        """
        Initialize NixlTransferAgent.
        :param name: Name of the agent.
        :param use_prog_thread: Whether to enable the progress thread, if available.
        :param num_workers: Specify number of threads for the supported multi-threaded backends.
        """
        self.name = name
        # Only the UCX backend is configured here.
        self.backends = ["UCX"]
        agent_config = nixl_agent_config(
            enable_prog_thread=use_prog_thread, backends=self.backends, num_threads=num_workers
        )
        self.agent = nixl_agent(name, agent_config)

    def register_memory(self, descs: RegMemoryDescs):
        """Register memory regions with the nixl agent.

        Tuple descs are expected as (ptr, size, device_id, name) 4-tuples.
        """
        if isinstance(descs.descs[0], tuple):
            assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}"
        reg_descs = self.agent.get_reg_descs(descs.descs, descs.type)
        assert reg_descs is not None, "Failed to get reg_descs"
        self.agent.register_memory(reg_descs)

    def deregister_memory(self, descs: RegMemoryDescs):
        """Deregister previously registered memory regions."""
        if isinstance(descs.descs[0], tuple):
            assert len(descs.descs[0]) == 4, f"Expected 4 elements per desc, got {descs.descs[0]}"
        reg_descs = self.agent.get_reg_descs(descs.descs, descs.type)
        assert reg_descs is not None, "Failed to get reg_descs"
        self.agent.deregister_memory(reg_descs)

    def load_remote_agent(self, name: str, agent_desc: bytes):
        """Register a remote agent from its serialized metadata.

        NOTE(review): `name` is unused here; nixl's add_remote_agent derives
        the peer name from the metadata itself -- confirm they always match.
        """
        self.agent.add_remote_agent(agent_desc)

    def get_local_agent_desc(self):
        """Return this agent's serialized metadata for exchange with peers."""
        return self.agent.get_agent_metadata()

    def invalidate_remote_agent(self, name: str):
        """Remove a previously added remote agent."""
        self.agent.remove_remote_agent(name)

    def check_remote_descs(self, name: str, memory_descs: list[int]) -> bool:
        # Not supported by the pure-Python nixl path.
        raise NotImplementedError

    def notify_sync_message(self, name: str, sync_message: str):
        # Not supported by the pure-Python nixl path.
        raise NotImplementedError

    @nvtx_range("NixlTransferAgent.submit_transfer_requests")
    def submit_transfer_requests(self, request: TransferRequest) -> TransferStatus:
        """Initialize and start a transfer; returns a pollable status handle."""
        src_xfer_descs = self.agent.get_xfer_descs(request.src_descs.descs, request.src_descs.type)
        dst_xfer_descs = self.agent.get_xfer_descs(request.dst_descs.descs, request.dst_descs.type)
        assert src_xfer_descs is not None, "Failed to get src_xfer_descs"
        assert dst_xfer_descs is not None, "Failed to get dst_xfer_descs"
        # nixl requires a string; map None to the empty message.
        sync_message = "" if request.sync_message is None else request.sync_message
        handle = self.agent.initialize_xfer(
            request.op,
            src_xfer_descs,
            dst_xfer_descs,
            request.remote_name,
            sync_message,
        )
        status = self.agent.transfer(handle)
        assert status != "ERROR"
        return NixlTransferStatus(self.agent, handle)
|
||||
46
tensorrt_llm/_torch/disaggregation/nixl/agent.py
Normal file
46
tensorrt_llm/_torch/disaggregation/nixl/agent.py
Normal file
@ -0,0 +1,46 @@
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
from ..base.agent import _force_py_nixl_kv_transfer
|
||||
|
||||
"""NIXL Transfer Agent implementations.
|
||||
|
||||
This module provides two implementations:
|
||||
1. BindingsNixlTransferAgent - Uses the standalone nixl_bindings C++ module with GIL release support
|
||||
2. NixlTransferAgent - Uses the Python nixl library directly (fallback)
|
||||
|
||||
The standalone nixl_bindings module is separate from the main trtllm bindings,
|
||||
so trtllm can still function normally even without NIXL dependencies.
|
||||
"""
|
||||
|
||||
|
||||
def _load_agent(module_name, required_attributes):
|
||||
try:
|
||||
module = __import__(module_name, fromlist=required_attributes, level=0)
|
||||
if all(hasattr(module, attr) for attr in required_attributes):
|
||||
return module
|
||||
except ImportError:
|
||||
logger.info("Failed to import module: %s", module_name)
|
||||
return None
|
||||
|
||||
|
||||
# Probe both implementations; the C++ binding is preferred unless the user
# forces pure Python via TRTLLM_USE_PY_NIXL_KVCACHE.
_py_agent = _load_agent(
    module_name="tensorrt_llm._torch.disaggregation.nixl._agent_py",
    required_attributes=["NixlTransferAgent", "NixlTransferStatus"],
)

_cpp_agent = _load_agent(
    module_name="tensorrt_llm._torch.disaggregation.nixl._agent_cpp",
    required_attributes=["BindingsNixlTransferAgent", "BindingsNixlTransferStatus"],
)

# Determine which Transfer Agent implementation to use
if _cpp_agent and not _force_py_nixl_kv_transfer():
    # Re-export the C++-backed classes under the generic names.
    NixlTransferStatus = _cpp_agent.BindingsNixlTransferStatus
    NixlTransferAgent = _cpp_agent.BindingsNixlTransferAgent
    logger.info("Using C++ NIXL Transfer Agent implementation.")
elif _py_agent:
    NixlTransferStatus = _py_agent.NixlTransferStatus
    NixlTransferAgent = _py_agent.NixlTransferAgent
    logger.info("Using Python NIXL Transfer Agent implementation.")
else:
    # Neither path is usable -- fail loudly at import time.
    raise ImportError("Both C++ and Python NIXL Transfer Agents failed to load.")
|
||||
156
tests/unittest/disaggregated/test_agent.py
Normal file
156
tests/unittest/disaggregated/test_agent.py
Normal file
@ -0,0 +1,156 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm._torch.disaggregation.base.agent import (
|
||||
MemoryDescs,
|
||||
MemoryType,
|
||||
RegMemoryDescs,
|
||||
TransferOp,
|
||||
TransferRequest,
|
||||
)
|
||||
from tensorrt_llm._torch.disaggregation.nixl.agent import NixlTransferAgent
|
||||
|
||||
|
||||
def _convert_to_memory_descs(reg_descs: RegMemoryDescs) -> MemoryDescs:
    """Convert registration descs into transfer-ready MemoryDescs.

    Drops the trailing name from each (ptr, size, device_id, name) entry and
    maps the memory-type string onto a MemoryType constant (unknown strings
    fall back to VRAM).
    """
    stripped = [(desc[0], desc[1], desc[2]) for desc in reg_descs.descs]

    type_map = {
        "DRAM": MemoryType.DRAM,
        "VRAM": MemoryType.VRAM,
        "GPU": MemoryType.VRAM,
        "BLK": MemoryType.BLK,
        "OBJ": MemoryType.OBJ,
        "FILE": MemoryType.FILE,
    }
    mem_type = type_map.get(reg_descs.type.upper(), MemoryType.VRAM)
    return MemoryDescs(mem_type, stripped)
|
||||
|
||||
|
||||
@dataclass
class MemoryManager:
    """Allocates and tracks test buffers backed by torch tensors."""

    # Keeps allocations alive so their data_ptr() stays valid for the test.
    allocated_memory: list[torch.Tensor] = field(default_factory=list)

    def allocate_memory(
        self, size: int, name: str, memory_type=MemoryType.VRAM, device_id: int = 0
    ) -> RegMemoryDescs:
        """Allocate *size* bytes (uint8) on CUDA for VRAM or on CPU otherwise,
        and return its registration descriptor."""
        device = torch.device(f"cuda:{device_id}" if memory_type == MemoryType.VRAM else "cpu")

        # Allocate memory block using torch.Tensor and track it
        block = torch.zeros(size, dtype=torch.uint8, device=device)
        self.allocated_memory.append(block)

        # Return RegMemoryDescs with position arguments
        memory_descs = RegMemoryDescs(
            type=memory_type, descs=[(block.data_ptr(), block.numel(), device_id, name)]
        )
        return memory_descs

    def clear_memory(self):
        """Clear all tracked memory blocks."""
        self.allocated_memory.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def memory_manager():
|
||||
return MemoryManager()
|
||||
|
||||
|
||||
@pytest.fixture(params=[256, 512])
|
||||
def memory_size(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=["DRAM", "VRAM"])
|
||||
def memory_type(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def alloc(memory_manager, memory_size, memory_type):
|
||||
"""Allocate memory for source and destination, based on the memory_size and memory_type parameters."""
|
||||
assert memory_size > 0, "Memory size must be a positive integer."
|
||||
src_descs = memory_manager.allocate_memory(
|
||||
size=memory_size, name="src_mem", memory_type=memory_type
|
||||
)
|
||||
dst_descs = memory_manager.allocate_memory(
|
||||
size=memory_size, name="dst_mem", memory_type=memory_type
|
||||
)
|
||||
return src_descs, dst_descs
|
||||
|
||||
|
||||
@pytest.fixture
def transfer_agent_src():
    """NIXL agent acting as the transfer source."""
    agent = NixlTransferAgent(name="src_agent")
    return agent
|
||||
|
||||
|
||||
@pytest.fixture
def transfer_agent_dst():
    """NIXL agent acting as the transfer destination."""
    agent = NixlTransferAgent(name="dst_agent")
    return agent
|
||||
|
||||
|
||||
def test_transfer_between_agents(
    transfer_agent_src,
    transfer_agent_dst,
    memory_manager,
    alloc,
    memory_size,
    memory_type,
):
    """End-to-end test of data transfer between two agents with parameterized memory sizes and types."""
    # Debug log the parameters being tested
    logger.info(f"Testing with memory_size={memory_size}, memory_type={memory_type}")

    # Unpack source and destination memory descriptions
    memory_descs_src, memory_descs_dst = alloc

    # Fill source memory with a known repeating pattern (0..9) so the
    # destination contents can be verified after the transfer.
    src_data = memory_manager.allocated_memory[0]
    assert memory_size > 0, "Memory size must be positive."
    tensor = torch.arange(memory_size, dtype=torch.uint8) % 10
    src_data.copy_(tensor)

    # Register memory with source and destination agents
    transfer_agent_src.register_memory(memory_descs_src)
    transfer_agent_dst.register_memory(memory_descs_dst)

    # Exchange agent descriptors in both directions so each side can address
    # the other by its registered name.
    src_agent_desc = transfer_agent_src.get_local_agent_desc()
    transfer_agent_dst.load_remote_agent("src_agent", src_agent_desc)

    dst_agent_desc = transfer_agent_dst.get_local_agent_desc()
    transfer_agent_src.load_remote_agent("dst_agent", dst_agent_desc)

    # Create and submit the transfer request (WRITE pushes src -> dst).
    # NOTE(review): TransferRequest.sync_message is annotated as str but None
    # is passed here — confirm None is an accepted value.
    transfer_request = TransferRequest(
        op=TransferOp.WRITE,
        src_descs=_convert_to_memory_descs(memory_descs_src),
        dst_descs=_convert_to_memory_descs(memory_descs_dst),
        remote_name="dst_agent",
        sync_message=None,
    )
    transfer_status = transfer_agent_src.submit_transfer_requests(transfer_request)
    transfer_status.wait()

    # Validate transfer completion
    assert transfer_status.is_completed(), "Transfer did not complete successfully."

    # Validate that the destination data matches the source data
    dst_data = memory_manager.allocated_memory[1]
    assert torch.equal(dst_data, src_data), "Destination data does not match source data."

    # Clean up by deregistering memory and clearing allocations
    transfer_agent_src.deregister_memory(memory_descs_src)
    transfer_agent_dst.deregister_memory(memory_descs_dst)
    memory_manager.clear_memory()

    # Drop the remote-agent bindings established above.
    transfer_agent_src.invalidate_remote_agent("dst_agent")
    transfer_agent_dst.invalidate_remote_agent("src_agent")
|
||||
|
||||
|
||||
# Allow running this file directly (e.g. `python test_agent.py`) in addition
# to invocation via the pytest CLI.
if __name__ == "__main__":
    pytest.main()
|
||||
127
tests/unittest/disaggregated/test_messenger.py
Normal file
127
tests/unittest/disaggregated/test_messenger.py
Normal file
@ -0,0 +1,127 @@
|
||||
import socket
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from tensorrt_llm._torch.disaggregation.native.messenger import ZMQMessenger, decode_message
|
||||
from tensorrt_llm._torch.disaggregation.native.utils import get_local_ip
|
||||
|
||||
# Table-driven cases for decode_message. Each entry maps a raw multi-frame
# message plus decode options to either the expected decoded tuple or the
# exception type the call is expected to raise.
TEST_CASES = [
    {
        # Both frames are valid UTF-8 bytes and decode to a tuple of str.
        "name": "valid_message",
        "message": [b"hello", b"world"],
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": ("hello", "world"),
        "raises": None,
    },
    {
        # A str frame instead of bytes is expected to raise ValueError.
        "name": "invalid_input",
        "message": ["hello", b"world"],
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": None,
        "raises": ValueError,
    },
    {
        # 0xFF is not valid UTF-8; strict mode propagates the decode error.
        "name": "decoding_error",
        "message": [b"\xff"],
        "encoding": "utf-8",
        "err_mode": "strict",
        "expected": None,
        "raises": UnicodeDecodeError,
    },
    {
        # Same invalid byte, but err_mode="ignore" drops it, yielding "".
        "name": "decoding_with_ignore",
        "message": [b"\xff"],
        "encoding": "utf-8",
        "err_mode": "ignore",
        "expected": ("",),
        "raises": None,
    },
]
|
||||
|
||||
|
||||
class TestDecodeMessage(unittest.TestCase):
    """Table-driven tests for decode_message, one subcase per TEST_CASES entry."""

    @parameterized.expand([(case["name"], case) for case in TEST_CASES])
    def test_decode_message(self, name, case):
        """Run one table entry and check its expected result or exception."""
        decode_kwargs = {"encoding": case["encoding"], "err_mode": case["err_mode"]}
        expected_error = case["raises"]

        if expected_error is None:
            decoded = decode_message(case["message"], **decode_kwargs)
            self.assertEqual(decoded, case["expected"])
        else:
            with self.assertRaises(expected_error):
                decode_message(case["message"], **decode_kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
def dynamic_endpoint():
    """Build a tcp:// endpoint on a port the OS reports as currently free.

    NOTE(review): the probe socket is closed before the endpoint is used, so
    another process could in principle grab the port in between — acceptable
    for tests, but worth knowing.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("", 0))  # port 0 lets the OS choose a free port
        free_port = probe.getsockname()[1]
    finally:
        probe.close()
    return f"tcp://{get_local_ip()}:{free_port}"
|
||||
|
||||
|
||||
@pytest.fixture
def create_messenger_pair(dynamic_endpoint):
    """Factory fixture yielding a builder for two ZMQ messengers on one endpoint.

    Server-side modes (ROUTER/REP for the first, DEALER/REQ for the second)
    receive the shared endpoint; any other mode gets None.
    """

    def _create_messenger_pair(mode1, mode2):
        def endpoint_for(mode, accepted_modes):
            # Hand out the shared endpoint only to the expected modes.
            return dynamic_endpoint if mode in accepted_modes else None

        first = ZMQMessenger(mode1, endpoint=endpoint_for(mode1, ["ROUTER", "REP"]))
        second = ZMQMessenger(mode2, endpoint=endpoint_for(mode2, ["DEALER", "REQ"]))
        return first, second

    yield _create_messenger_pair
|
||||
|
||||
|
||||
def test_router_dealer(create_messenger_pair):
    """Test ROUTER and DEALER communication.

    The DEALER sends one frame; the ROUTER's listener callback collects every
    frame it receives. Instead of a single fixed sleep (which is flaky on a
    loaded machine), we poll for arrival with a bounded deadline.
    """
    router, dealer = create_messenger_pair("ROUTER", "DEALER")

    received_messages = []

    def on_message(messages):
        # Invoked by the router's listener with the frames of one message.
        received_messages.extend(messages)

    router.start_listener(on_message)

    dealer.send([b"Hello, ROUTER!"])

    # Wait up to 2 s for delivery; exits as soon as anything arrives.
    deadline = time.monotonic() + 2.0
    while not received_messages and time.monotonic() < deadline:
        time.sleep(0.01)

    assert len(received_messages) > 0
    assert b"Hello, ROUTER!" in received_messages

    router.stop()
    dealer.stop()
|
||||
|
||||
|
||||
def test_req_rep(create_messenger_pair):
    """Test REQ and REP communication."""
    rep, req = create_messenger_pair("REP", "REQ")

    def echo(messages):
        # REP must answer every request; send the frames straight back.
        rep.send(messages)

    rep.start_listener(echo)

    payload = [b"Hello, REP!"]
    req.send(payload)
    assert req.receive() == payload

    req.stop()
    rep.stop()
|
||||
|
||||
|
||||
# Run via pytest rather than unittest.main(): unittest discovery only finds
# TestCase classes and would silently skip the fixture-based pytest tests
# (test_router_dealer, test_req_rep) defined at module level.
if __name__ == "__main__":
    pytest.main()
|
||||
Loading…
Reference in New Issue
Block a user