mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
78 lines
2.8 KiB
Python
78 lines
2.8 KiB
Python
from typing import Optional, Union
|
|
|
|
import torch
|
|
|
|
from .distributed import PPComm
|
|
|
|
|
|
class PipelineInterface:
    """
    A container class for passing intermediate tensors between pipeline parallel ranks.

    It contains two intermediate tensors: [hidden_states, residual], supporting:
    - Dict access: pp['hidden_states'], pp['residual']
    - Unpacking: hidden, residual = pp
    - PP communication: pp.send(), pp.recv()
    - Slicing: pp[start:end]

    Note: When using this interface in pp, the packing/unpacking and send/recv
    operations must be used symmetrically within a stage and between successive ranks.
    """
    # Communicator shared by every instance; stays None until init_pp_comm() runs.
    _pp_comm = None

    def __init__(self,
                 hidden_states: Optional[torch.Tensor] = None,
                 residual: Optional[torch.Tensor] = None):
        """Store the two (optional) tensors carried by this interface.

        Args:
            hidden_states: Tensor stored under the 'hidden_states' key, or None.
            residual: Tensor stored under the 'residual' key, or None.
        """
        self.hidden_states = hidden_states
        self.residual = residual
        # Tag forwarded to every PPComm send/recv call issued by this instance.
        self.tag = 1234

    @classmethod
    def init_pp_comm(cls, mapping):
        """Initialize PPComm once at startup"""
        cls._pp_comm = PPComm(mapping)

    def _require_pp_comm(self):
        """Return the communicator, raising a descriptive error if it is unset."""
        comm = self._pp_comm
        if comm is None:
            raise RuntimeError(
                "PipelineInterface.init_pp_comm(mapping) must be called "
                "before send()/recv()")
        return comm

    def __getitem__(
            self, key: Union[str, slice]
    ) -> Union[Optional[torch.Tensor], "PipelineInterface"]:
        """Look up a tensor by name, or slice both tensors at once.

        Args:
            key: 'hidden_states' or 'residual' for name access, or a slice
                that is applied to every tensor that is not None.

        Returns:
            The stored tensor (possibly None) for a string key, or a new
            PipelineInterface holding the sliced tensors for a slice key.

        Raises:
            KeyError: If a string key is not one of the two known names.
            TypeError: If the key is neither a string nor a slice.
        """
        if isinstance(key, str):
            if key == 'hidden_states':
                return self.hidden_states
            if key == 'residual':
                return self.residual
            raise KeyError(f"Unknown key: {key}")
        if isinstance(key, slice):
            # Missing tensors stay None in the sliced interface.
            return PipelineInterface(
                hidden_states=self.hidden_states[key]
                if self.hidden_states is not None else None,
                residual=self.residual[key]
                if self.residual is not None else None)
        # Fail loudly rather than silently handing back None.
        raise TypeError(
            f"PipelineInterface keys must be str or slice, "
            f"not {type(key).__name__}")

    def __setitem__(self, key: Union[str, slice],
                    value: torch.Tensor) -> None:
        """Replace a tensor by name, or write into a slice of the tensors.

        Args:
            key: 'hidden_states' or 'residual' to rebind that attribute, or a
                slice selecting the range to overwrite.
            value: The replacement tensor. For a slice key the same value is
                written into every tensor that is not None.

        Raises:
            KeyError: If a string key is not one of the two known names.
            TypeError: If the key is neither a string nor a slice.
        """
        if isinstance(key, str):
            if key == 'hidden_states':
                self.hidden_states = value
            elif key == 'residual':
                self.residual = value
            else:
                raise KeyError(f"Unknown key: {key}")
        elif isinstance(key, slice):
            if self.hidden_states is not None:
                self.hidden_states[key] = value
            if self.residual is not None:
                self.residual[key] = value
        else:
            # Fail loudly rather than silently dropping the write.
            raise TypeError(
                f"PipelineInterface keys must be str or slice, "
                f"not {type(key).__name__}")

    def __iter__(self):
        """Yield hidden_states then residual, enabling `hidden, residual = pp`."""
        return iter((self.hidden_states, self.residual))

    def recv(self) -> None:
        """Receive tensors from previous rank.

        Tensors that are None are skipped; hidden_states is received before
        residual.

        Raises:
            RuntimeError: If a tensor must be received but init_pp_comm() has
                not been called.
        """
        for tensor in (self.hidden_states, self.residual):
            if tensor is not None:
                self._require_pp_comm().recv(tensor, tag=self.tag)

    def send(self) -> None:
        """Send tensors to next rank.

        Tensors that are None are skipped; hidden_states is sent before
        residual.

        Raises:
            RuntimeError: If a tensor must be sent but init_pp_comm() has not
                been called.
        """
        for tensor in (self.hidden_states, self.residual):
            if tensor is not None:
                self._require_pp_comm().send(tensor, tag=self.tag)
|