mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
fix: change pp broadcast pattern for LPs (#4130)
Signed-off-by: Erin Ho <14718778+hchings@users.noreply.github.com>
This commit is contained in:
parent
91bf5e6a8e
commit
cdf5ae1547
@ -1165,20 +1165,32 @@ class PyExecutor:
|
||||
req_id)
|
||||
|
||||
@nvtx_range("_broadcast_new_requests")
|
||||
def _broadcast_new_requests(self, new_requests):
|
||||
def _broadcast_new_requests(
|
||||
self,
|
||||
new_requests: List[ExecutorRequest],
|
||||
py_request_objects: tuple[str, dict] = None
|
||||
) -> tuple[List[ExecutorRequest], Optional[tuple[str, dict]]]:
|
||||
"""Broadcasts new_requests and optional Python-only metadata (`py_request_objects`) across pipeline stages.
|
||||
`py_request_objects` is a tuple of (attribute_name, {request_id: object}).
|
||||
"""
|
||||
payloads = (new_requests, py_request_objects
|
||||
) if py_request_objects is not None else new_requests
|
||||
|
||||
if not self.dist.has_pp:
|
||||
return self.dist.broadcast(new_requests, root=0)
|
||||
result = self.dist.broadcast(payloads, root=0)
|
||||
return result if isinstance(result, tuple) else (result, None)
|
||||
|
||||
# broadcast within first tp group before send/recv chain to other tp groups
|
||||
if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
|
||||
new_requests = self.dist.tp_broadcast(new_requests, root=0)
|
||||
payloads = self.dist.tp_broadcast(payloads, root=0)
|
||||
|
||||
# tag = [0, num_micro_batches - 1] used for new_tokens send/recv
|
||||
tag = self.num_micro_batches
|
||||
|
||||
# 1. send metadata: len(num_requests) and serialized buffer size
|
||||
new_requests = payloads[0] if isinstance(payloads, tuple) else payloads
|
||||
if self.dist.is_first_pp_rank and len(new_requests) > 0:
|
||||
buf = np.array(bytearray(dill.dumps(new_requests)))
|
||||
buf = np.array(bytearray(dill.dumps(payloads)))
|
||||
buf_size = len(buf)
|
||||
else:
|
||||
buf, buf_size = None, 0
|
||||
@ -1202,10 +1214,15 @@ class PyExecutor:
|
||||
self.dist.send(buf, self.dist.next_pp_rank, tag)
|
||||
|
||||
if not self.dist.is_first_pp_rank:
|
||||
new_requests = dill.loads(buf.tobytes()) # nosec B301
|
||||
buf_data = dill.loads(buf.tobytes()) # nosec B301
|
||||
if isinstance(buf_data, tuple):
|
||||
new_requests, py_request_objects = buf_data
|
||||
else:
|
||||
new_requests = buf_data
|
||||
|
||||
assert len(new_requests) == num_new_requests
|
||||
|
||||
return new_requests
|
||||
return new_requests, py_request_objects
|
||||
|
||||
@nvtx_range("_fetch_new_requests")
|
||||
def _fetch_new_requests(self):
|
||||
@ -1230,14 +1247,13 @@ class PyExecutor:
|
||||
else:
|
||||
py_request_objects = None
|
||||
|
||||
if self.dist.rank == 0 and not self.dist.has_pp:
|
||||
if self.dist.rank == 0:
|
||||
# Preserve original `new_requests` on rank 0 since it may contain
|
||||
# Python-only objects (e.g., custom logits processors) not serializable by pybind.
|
||||
_ = self._broadcast_new_requests(new_requests)
|
||||
_ = self._broadcast_new_requests(new_requests, py_request_objects)
|
||||
else:
|
||||
new_requests = self._broadcast_new_requests(new_requests)
|
||||
|
||||
py_request_objects = self.dist.broadcast(py_request_objects, root=0)
|
||||
new_requests, py_request_objects = self._broadcast_new_requests(
|
||||
new_requests, py_request_objects)
|
||||
|
||||
if py_request_objects and (self.dist.tp_size > 1
|
||||
or self.dist.has_pp) and self.dist.rank > 0:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user