diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 0d2ae0a0c3..89b6f97185 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1165,20 +1165,32 @@ class PyExecutor: req_id) @nvtx_range("_broadcast_new_requests") - def _broadcast_new_requests(self, new_requests): + def _broadcast_new_requests( + self, + new_requests: List[ExecutorRequest], + py_request_objects: tuple[str, dict] = None + ) -> tuple[List[ExecutorRequest], Optional[tuple[str, dict]]]: + """Broadcasts new_requests and optional Python-only metadata (`py_request_objects`) across pipeline stages. + `py_request_objects` is a tuple of (attribute_name, {request_id: object}). + """ + payloads = (new_requests, py_request_objects + ) if py_request_objects is not None else new_requests + if not self.dist.has_pp: - return self.dist.broadcast(new_requests, root=0) + result = self.dist.broadcast(payloads, root=0) + return result if isinstance(result, tuple) else (result, None) # broadcast within first tp group before send/recv chain to other tp groups if self.dist.tp_size > 1 and self.dist.is_first_pp_rank: - new_requests = self.dist.tp_broadcast(new_requests, root=0) + payloads = self.dist.tp_broadcast(payloads, root=0) # tag = [0, num_micro_batches - 1] used for new_tokens send/recv tag = self.num_micro_batches # 1. send metadata: len(num_requests) and serialized buffer size + new_requests = payloads[0] if isinstance(payloads, tuple) else payloads if self.dist.is_first_pp_rank and len(new_requests) > 0: - buf = np.array(bytearray(dill.dumps(new_requests))) + buf = np.array(bytearray(dill.dumps(payloads))) buf_size = len(buf) else: buf, buf_size = None, 0 @@ -1202,10 +1214,15 @@ class PyExecutor: self.dist.send(buf, self.dist.next_pp_rank, tag) if not self.dist.is_first_pp_rank: - new_requests = dill.loads(buf.tobytes()) # nosec B301 + buf_data = dill.loads(buf.tobytes()) # nosec B301 + if isinstance(buf_data, tuple): + new_requests, py_request_objects = buf_data + else: + new_requests = buf_data + assert len(new_requests) == num_new_requests - return new_requests + return new_requests, py_request_objects @nvtx_range("_fetch_new_requests") def _fetch_new_requests(self): @@ -1230,14 +1247,13 @@ class PyExecutor: else: py_request_objects = None - if self.dist.rank == 0 and not self.dist.has_pp: + if self.dist.rank == 0: # Preserve original `new_requests` on rank 0 since it may contain # Python-only objects (e.g., custom logits processors) not serializable by pybind. - _ = self._broadcast_new_requests(new_requests) + _ = self._broadcast_new_requests(new_requests, py_request_objects) else: - new_requests = self._broadcast_new_requests(new_requests) - - py_request_objects = self.dist.broadcast(py_request_objects, root=0) + new_requests, py_request_objects = self._broadcast_new_requests( + new_requests, py_request_objects) if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp) and self.dist.rank > 0: