diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 419627e978..5eb12721c4 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -859,8 +859,7 @@ class PyExecutor:
                 # Send tokens to next pp rank (w.r.t model forward direction)
                 # Second last rank does not need to since last rank has original decoded tokens
                 if not self.dist.is_second_last_pp_rank:
-                    if self.send_handles[prev_microbatch_id] is not None:
-                        self.send_handles[prev_microbatch_id].wait()
+                    self.wait_on_pp_send_handles(prev_microbatch_id)
                     self.send_handles[
                         prev_microbatch_id] = self.dist.isend_object(
                             sample_state.host,
@@ -892,6 +891,8 @@ class PyExecutor:
                         self.resource_manager.update_resources(
                             previous_scheduled_batch)
                         self._remove_inflight_ids(previous_scheduled_batch)
+
+                    self.wait_on_pp_send_handles(prev_microbatch_id)
                     self.micro_batches[prev_microbatch_id] = None
 
                 if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
@@ -905,6 +906,11 @@ class PyExecutor:
                     self.active_requests,
                     previous_batch)
 
+    def wait_on_pp_send_handles(self, microbatch_id):
+        if self.send_handles[microbatch_id] is not None:
+            self.send_handles[microbatch_id].wait()
+            self.send_handles[microbatch_id] = None
+
     def _prepare_and_schedule_batch(self):
         new_requests = self._fetch_and_activate_new_requests()
         if self.should_stop_processing:
@@ -1821,8 +1827,7 @@ class PyExecutor:
                     req.py_result = py_result
 
         elif self.dist.is_last_pp_rank and len(finished_reqs):
-            if self.send_handles[prev_microbatch_id] is not None:
-                self.send_handles[prev_microbatch_id].wait()
+            self.wait_on_pp_send_handles(prev_microbatch_id)
            self.send_handles[prev_microbatch_id] = self.dist.isend_object(
                [r.py_result for r in finished_reqs],
                dest=self.dist.next_pp_rank,
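
For context on the pattern this diff consolidates, below is a minimal, self-contained sketch of the send-handle lifecycle: each microbatch slot holds at most one in-flight async send handle, and the new helper both drains and clears the slot, so a later wait on the same slot (e.g. when the microbatch is retired) is a safe no-op. FakeHandle, FakeDist, and SendSlots are hypothetical stand-ins for illustration only, not the real tensorrt_llm distributed API; only the wait-then-clear logic mirrors the helper added in the diff.

# Illustrative sketch; FakeHandle/FakeDist/SendSlots are NOT real
# tensorrt_llm classes.

class FakeHandle:
    """Stands in for the handle returned by dist.isend_object()."""

    def __init__(self, payload):
        self.payload = payload
        self.completed = False

    def wait(self):
        # Block until the async send finishes (trivially here).
        self.completed = True


class FakeDist:
    """Stands in for the distributed communicator."""

    def isend_object(self, obj, dest, tag):
        # A real implementation returns an async handle immediately.
        return FakeHandle(obj)


class SendSlots:
    """One in-flight send handle per microbatch slot."""

    def __init__(self, num_micro_batches, dist):
        self.send_handles = [None] * num_micro_batches
        self.dist = dist

    def wait_on_pp_send_handles(self, microbatch_id):
        # Mirrors the helper added in the diff: drain the pending send,
        # then clear the slot so a second call is a safe no-op.
        if self.send_handles[microbatch_id] is not None:
            self.send_handles[microbatch_id].wait()
            self.send_handles[microbatch_id] = None

    def send(self, microbatch_id, obj, dest):
        # Never overwrite a live handle: wait for the previous send on
        # this slot before issuing a new one (as both call sites do).
        self.wait_on_pp_send_handles(microbatch_id)
        self.send_handles[microbatch_id] = self.dist.isend_object(
            obj, dest=dest, tag=microbatch_id)


slots = SendSlots(num_micro_batches=2, dist=FakeDist())
slots.send(0, {"tokens": [1, 2, 3]}, dest=1)
slots.wait_on_pp_send_handles(0)  # drains and clears slot 0
slots.wait_on_pp_send_handles(0)  # no-op: slot already cleared

Clearing the handle after waiting is what makes the second wait site (added before the microbatch slot is reset to None) safe even when an earlier path has already drained the same slot.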