mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
refine broadcast new requests method (#3198)
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
This commit is contained in:
parent
a5f32f46fd
commit
8fe2e5865e
@ -1025,10 +1025,7 @@ class PyExecutor:
|
||||
self.request_queue, timeout,
|
||||
self.max_num_active_requests - len(self.active_requests))
|
||||
|
||||
if self.dist.has_pp:
|
||||
new_requests = self._broadcast_new_requests_pp(new_requests)
|
||||
else:
|
||||
new_requests = self.dist.broadcast(new_requests, root=0)
|
||||
new_requests = self._broadcast_new_requests(new_requests)
|
||||
|
||||
if self.enable_iter_perf_stats and self.dist.rank == 0:
|
||||
now = time.time()
|
||||
@ -1041,8 +1038,11 @@ class PyExecutor:
|
||||
|
||||
return new_requests
|
||||
|
||||
@nvtx_range("_broadcast_new_requests_pp")
|
||||
def _broadcast_new_requests_pp(self, new_requests):
|
||||
@nvtx_range("_broadcast_new_requests")
|
||||
def _broadcast_new_requests(self, new_requests):
|
||||
if not self.dist.has_pp:
|
||||
return self.dist.broadcast(new_requests, root=0)
|
||||
|
||||
# broadcast within first tp group before send/recv chain to other tp groups
|
||||
if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
|
||||
new_requests = self.dist.tp_broadcast(new_requests, root=0)
|
||||
@ -1099,10 +1099,7 @@ class PyExecutor:
|
||||
self.request_queue, timeout,
|
||||
total_max_num_active_requests - total_num_active_requests)
|
||||
|
||||
if self.dist.has_pp:
|
||||
new_requests = self._broadcast_new_requests_pp(new_requests)
|
||||
else:
|
||||
new_requests = self.dist.broadcast(new_requests, root=0)
|
||||
new_requests = self._broadcast_new_requests(new_requests)
|
||||
|
||||
num_new_requests_all_ranks = len(new_requests)
|
||||
self.expected_num_active_requests = max(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user