fix: change pp broadcast pattern for LPs (#4130)

Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
Author: Erin
Date:   2025-05-08 20:07:13 -07:00 (committed by GitHub)
parent 91bf5e6a8e
commit cdf5ae1547


@@ -1165,20 +1165,32 @@ class PyExecutor:
                                  req_id)
 
     @nvtx_range("_broadcast_new_requests")
-    def _broadcast_new_requests(self, new_requests):
+    def _broadcast_new_requests(
+            self,
+            new_requests: List[ExecutorRequest],
+            py_request_objects: tuple[str, dict] = None
+    ) -> tuple[List[ExecutorRequest], Optional[tuple[str, dict]]]:
+        """Broadcasts new_requests and optional Python-only metadata (`py_request_objects`) across pipeline stages.
+        `py_request_objects` is a tuple of (attribute_name, {request_id: object}).
+        """
+        payloads = (new_requests, py_request_objects
+                    ) if py_request_objects is not None else new_requests
         if not self.dist.has_pp:
-            return self.dist.broadcast(new_requests, root=0)
+            result = self.dist.broadcast(payloads, root=0)
+            return result if isinstance(result, tuple) else (result, None)
 
         # broadcast within first tp group before send/recv chain to other tp groups
         if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
-            new_requests = self.dist.tp_broadcast(new_requests, root=0)
+            payloads = self.dist.tp_broadcast(payloads, root=0)
 
         # tag = [0, num_micro_batches - 1] used for new_tokens send/recv
         tag = self.num_micro_batches
 
         # 1. send metadata: len(num_requests) and serialized buffer size
+        new_requests = payloads[0] if isinstance(payloads, tuple) else payloads
         if self.dist.is_first_pp_rank and len(new_requests) > 0:
-            buf = np.array(bytearray(dill.dumps(new_requests)))
+            buf = np.array(bytearray(dill.dumps(payloads)))
             buf_size = len(buf)
         else:
             buf, buf_size = None, 0
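
Note (not in the diff): the new path bundles `new_requests` with `py_request_objects` into a single payload and ships it as a dill-serialized NumPy byte buffer over the pipeline-parallel send/recv chain. Below is a minimal, standalone sketch of that bundling; `pack_payloads` and the sample values are illustrative only.

```python
# Illustrative sketch only: bundle a request list with optional Python-only
# metadata and serialize it into a byte buffer, mirroring the payload handling
# in the hunk above. pack_payloads and the sample values are made up.
from typing import Any, Optional

import dill  # can serialize arbitrary Python objects (e.g., callables)
import numpy as np


def pack_payloads(new_requests: list,
                  py_request_objects: Optional[tuple[str, dict]] = None) -> np.ndarray:
    """Bundle requests with optional (attr_name, {request_id: obj}) metadata and serialize."""
    payloads: Any = (new_requests, py_request_objects
                     ) if py_request_objects is not None else new_requests
    # bytearray -> np.array yields a uint8 buffer suitable for point-to-point send
    return np.array(bytearray(dill.dumps(payloads)))


buf = pack_payloads(["req0", "req1"], ("logits_processor", {0: object()}))
print(buf.dtype, buf.size)  # uint8 <serialized length>
```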
@@ -1202,10 +1214,15 @@ class PyExecutor:
             self.dist.send(buf, self.dist.next_pp_rank, tag)
 
         if not self.dist.is_first_pp_rank:
-            new_requests = dill.loads(buf.tobytes())  # nosec B301
+            buf_data = dill.loads(buf.tobytes())  # nosec B301
+            if isinstance(buf_data, tuple):
+                new_requests, py_request_objects = buf_data
+            else:
+                new_requests = buf_data
             assert len(new_requests) == num_new_requests
 
-        return new_requests
+        return new_requests, py_request_objects
 
     @nvtx_range("_fetch_new_requests")
     def _fetch_new_requests(self):
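
Note (not in the diff): the receive side mirrors the packing by deserializing the buffer and accepting either the new `(new_requests, py_request_objects)` tuple or a bare request list. A hypothetical `unpack_payloads` helper, assuming only what the hunk above shows:

```python
# Illustrative counterpart to the receive path above: deserialize the buffer
# and unpack either a (requests, metadata) tuple or a bare request list.
# unpack_payloads is a made-up helper name, not part of the PR.
from typing import Optional

import dill
import numpy as np


def unpack_payloads(buf: np.ndarray) -> tuple[list, Optional[tuple[str, dict]]]:
    data = dill.loads(buf.tobytes())  # nosec B301 - trusted intra-job traffic
    if isinstance(data, tuple):
        new_requests, py_request_objects = data
    else:
        new_requests, py_request_objects = data, None
    return new_requests, py_request_objects


# Round trip: serialize, then unpack on the "next pipeline stage".
buf = np.array(bytearray(dill.dumps((["req0"], ("ctx", {0: 42})))))
reqs, meta = unpack_payloads(buf)
assert reqs == ["req0"] and meta == ("ctx", {0: 42})
```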
@@ -1230,14 +1247,13 @@
         else:
             py_request_objects = None
 
-        if self.dist.rank == 0 and not self.dist.has_pp:
+        if self.dist.rank == 0:
             # Preserve original `new_requests` on rank 0 since it may contain
             # Python-only objects (e.g., custom logits processors) not serializable by pybind.
-            _ = self._broadcast_new_requests(new_requests)
+            _ = self._broadcast_new_requests(new_requests, py_request_objects)
         else:
-            new_requests = self._broadcast_new_requests(new_requests)
-            py_request_objects = self.dist.broadcast(py_request_objects, root=0)
+            new_requests, py_request_objects = self._broadcast_new_requests(
+                new_requests, py_request_objects)
 
         if py_request_objects and (self.dist.tp_size > 1
                                    or self.dist.has_pp) and self.dist.rank > 0:
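
Note (not in the diff): the context cuts off at the guard above, so the following is only an illustration of the `(attribute_name, {request_id: object})` format described in the new docstring. `FakeRequest` and `attach_py_objects` are made up, and re-attaching via `setattr` is an assumption, not the PR's code.

```python
# Illustrative only: how (attribute_name, {request_id: object}) metadata could
# be re-attached to the matching requests on non-zero ranks.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeRequest:  # stand-in for a request object exposing a request id
    request_id: int
    logits_processor: Optional[Any] = None


def attach_py_objects(new_requests: list,
                      py_request_objects: tuple[str, dict]) -> None:
    """Re-attach the broadcast Python-only objects to their requests by id."""
    attr_name, objs_by_id = py_request_objects
    for req in new_requests:
        obj = objs_by_id.get(req.request_id)
        if obj is not None:
            setattr(req, attr_name, obj)


reqs = [FakeRequest(0), FakeRequest(1)]
attach_py_objects(reqs, ("logits_processor", {1: "my_custom_lp"}))
assert reqs[1].logits_processor == "my_custom_lp"
```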