[https://nvbugs/5472947][fix] wait on isend handles before reusing buffers (#7462)

Signed-off-by: Anurag Mukkara <134339030+amukkara@users.noreply.github.com>
This commit is contained in:
Anurag Mukkara 2025-09-03 13:20:02 +05:30 committed by GitHub
parent 79d93f9419
commit ae5136831f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -859,8 +859,7 @@ class PyExecutor:
# Send tokens to next pp rank (w.r.t model forward direction)
# Second last rank does not need to since last rank has original decoded tokens
if not self.dist.is_second_last_pp_rank:
if self.send_handles[prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].wait()
self.wait_on_pp_send_handles(prev_microbatch_id)
self.send_handles[
prev_microbatch_id] = self.dist.isend_object(
sample_state.host,
@ -892,6 +891,8 @@ class PyExecutor:
self.resource_manager.update_resources(
previous_scheduled_batch)
self._remove_inflight_ids(previous_scheduled_batch)
self.wait_on_pp_send_handles(prev_microbatch_id)
self.micro_batches[prev_microbatch_id] = None
if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
@ -905,6 +906,11 @@ class PyExecutor:
self.active_requests,
previous_batch)
def wait_on_pp_send_handles(self, microbatch_id):
if self.send_handles[microbatch_id] is not None:
self.send_handles[microbatch_id].wait()
self.send_handles[microbatch_id] = None
def _prepare_and_schedule_batch(self):
new_requests = self._fetch_and_activate_new_requests()
if self.should_stop_processing:
@ -1821,8 +1827,7 @@ class PyExecutor:
req.py_result = py_result
elif self.dist.is_last_pp_rank and len(finished_reqs):
if self.send_handles[prev_microbatch_id] is not None:
self.send_handles[prev_microbatch_id].wait()
self.wait_on_pp_send_handles(prev_microbatch_id)
self.send_handles[prev_microbatch_id] = self.dist.isend_object(
[r.py_result for r in finished_reqs],
dest=self.dist.next_pp_rank,