refine broadcast new requests method (#3198)

Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
This commit is contained in:
QI JUN 2025-04-02 08:05:20 +08:00 committed by GitHub
parent a5f32f46fd
commit 8fe2e5865e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1025,10 +1025,7 @@ class PyExecutor:
self.request_queue, timeout,
self.max_num_active_requests - len(self.active_requests))
if self.dist.has_pp:
new_requests = self._broadcast_new_requests_pp(new_requests)
else:
new_requests = self.dist.broadcast(new_requests, root=0)
new_requests = self._broadcast_new_requests(new_requests)
if self.enable_iter_perf_stats and self.dist.rank == 0:
now = time.time()
@ -1041,8 +1038,11 @@ class PyExecutor:
return new_requests
@nvtx_range("_broadcast_new_requests_pp")
def _broadcast_new_requests_pp(self, new_requests):
@nvtx_range("_broadcast_new_requests")
def _broadcast_new_requests(self, new_requests):
if not self.dist.has_pp:
return self.dist.broadcast(new_requests, root=0)
# broadcast within first tp group before send/recv chain to other tp groups
if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
new_requests = self.dist.tp_broadcast(new_requests, root=0)
@ -1099,10 +1099,7 @@ class PyExecutor:
self.request_queue, timeout,
total_max_num_active_requests - total_num_active_requests)
if self.dist.has_pp:
new_requests = self._broadcast_new_requests_pp(new_requests)
else:
new_requests = self.dist.broadcast(new_requests, root=0)
new_requests = self._broadcast_new_requests(new_requests)
num_new_requests_all_ranks = len(new_requests)
self.expected_num_active_requests = max(