TensorRT-LLMs/tensorrt_llm/_torch/pipeline_interface.py
Kaiyu Xie 2631f21089
Update (#2978)
Signed-off-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
2025-03-23 16:39:35 +08:00

78 lines
2.8 KiB
Python

from typing import Optional, Union
import torch
from .distributed import PPComm
class PipelineInterface:
    """
    A container class for passing intermediate tensors between pipeline parallel ranks.
    It contains two intermediate tensors: [hidden_states, residual], supporting:
    - Dict access: pp['hidden_states'], pp['residual']
    - Unpacking: hidden, residual = pp
    - PP communication: pp.send(), pp.recv()
    - Slicing: pp[start:end]
    Note: When using this interface in pp, the packing/unpacking and send/recv
    operations must be used symmetrically within stage and between successive ranks.
    """
    # Shared communicator for all instances; None until init_pp_comm() is called.
    _pp_comm = None

    # Field names accepted by string-key access, in unpack/send/recv order.
    _KEYS = ('hidden_states', 'residual')

    def __init__(self,
                 hidden_states: Optional[torch.Tensor] = None,
                 residual: Optional[torch.Tensor] = None):
        """Wrap the two intermediate tensors; either may be None.

        Args:
            hidden_states: Hidden-states tensor, or None if absent.
            residual: Residual tensor, or None if absent.
        """
        self.hidden_states = hidden_states
        self.residual = residual
        # Message tag used for BOTH tensors in send()/recv(). The two are told
        # apart only by order; this presumably relies on PPComm matching
        # same-tag messages between two ranks in send order — TODO confirm.
        self.tag = 1234

    @classmethod
    def init_pp_comm(cls, mapping) -> None:
        """Initialize the shared PPComm once at startup.

        Args:
            mapping: Passed unchanged to the ``PPComm`` constructor.
        """
        cls._pp_comm = PPComm(mapping)

    def __getitem__(self, key: Union[str, slice]):
        """Return a field by name, or a new interface with both fields sliced.

        Args:
            key: ``'hidden_states'`` / ``'residual'``, or a ``slice`` applied
                to every non-None tensor.

        Returns:
            The named tensor (possibly None) for a string key; a new
            ``PipelineInterface`` holding the sliced tensors for a slice key.

        Raises:
            KeyError: If a string key is not a known field name.
            TypeError: If ``key`` is neither a ``str`` nor a ``slice``.
        """
        if isinstance(key, str):
            if key not in self._KEYS:
                raise KeyError(f"Unknown key: {key}")
            return getattr(self, key)
        if isinstance(key, slice):
            return PipelineInterface(
                hidden_states=(None if self.hidden_states is None else
                               self.hidden_states[key]),
                residual=(None if self.residual is None else
                          self.residual[key]))
        raise TypeError(
            f"PipelineInterface indices must be str or slice, "
            f"not {type(key).__name__}")

    def __setitem__(self, key: Union[str, slice],
                    value: Union[torch.Tensor, 'PipelineInterface']):
        """Replace a field by name, or write into a slice of both fields.

        Args:
            key: ``'hidden_states'`` / ``'residual'`` to rebind that field, or
                a ``slice`` to assign into every non-None tensor in place.
            value: For a string key, the new field value. For a slice key,
                either a tensor (written into both fields) or another
                ``PipelineInterface`` whose fields are written into the
                matching fields of this one.

        Raises:
            KeyError: If a string key is not a known field name.
            TypeError: If ``key`` is neither a ``str`` nor a ``slice``.
        """
        if isinstance(key, str):
            if key not in self._KEYS:
                raise KeyError(f"Unknown key: {key}")
            setattr(self, key, value)
            return
        if isinstance(key, slice):
            if isinstance(value, PipelineInterface):
                # Field-wise copy: hidden -> hidden, residual -> residual.
                src_hidden, src_residual = value
            else:
                # A single tensor is written into both fields.
                src_hidden = src_residual = value
            if self.hidden_states is not None:
                self.hidden_states[key] = src_hidden
            if self.residual is not None:
                self.residual[key] = src_residual
            return
        raise TypeError(
            f"PipelineInterface indices must be str or slice, "
            f"not {type(key).__name__}")

    def __iter__(self):
        """Yield ``hidden_states`` then ``residual`` (enables unpacking)."""
        return iter((self.hidden_states, self.residual))

    def _present_tensors(self):
        """Return the non-None tensors in hidden -> residual order."""
        return [tensor for tensor in self if tensor is not None]

    def _require_pp_comm(self):
        """Return the shared communicator, or raise if it was never set up."""
        comm = self._pp_comm
        if comm is None:
            raise RuntimeError(
                "PipelineInterface.init_pp_comm(mapping) must be called "
                "before send()/recv().")
        return comm

    def recv(self) -> None:
        """Receive tensors from previous rank.

        The existing non-None tensors are passed to ``PPComm.recv`` as the
        receive buffers, hidden_states first and then residual. This mirrors
        the order used by ``send()``. It is a no-op when both fields are None.

        Raises:
            RuntimeError: If there is something to receive but
                ``init_pp_comm`` has not been called.
        """
        tensors = self._present_tensors()
        if not tensors:
            return
        comm = self._require_pp_comm()
        for tensor in tensors:
            comm.recv(tensor, tag=self.tag)

    def send(self) -> None:
        """Send tensors to next rank.

        Sends the non-None tensors in hidden_states then residual order,
        all with ``self.tag``. It is a no-op when both fields are None.

        Raises:
            RuntimeError: If there is something to send but ``init_pp_comm``
                has not been called.
        """
        tensors = self._present_tensors()
        if not tensors:
            return
        comm = self._require_pp_comm()
        for tensor in tensors:
            comm.send(tensor, tag=self.tag)