From d43be7b65e79ca9e22fccd877785ed3c7cdbd85c Mon Sep 17 00:00:00 2001
From: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Date: Fri, 23 Jan 2026 02:15:06 +0800
Subject: [PATCH] [None][fix] Avoid Double update for previous batch (#9888)

Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
---
 .../nanobind/batch_manager/bindings.cpp       |  1 +
 .../pybind/batch_manager/bindings.cpp         |  1 +
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 45 ++++++++++++-------
 3 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
index 50ccb5c284..72a94944d5 100644
--- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp
@@ -170,6 +170,7 @@ void initBindings(nb::module_& m)
         .def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams)
         .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
         .def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
+        .def_prop_ro("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
         .def_prop_ro("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
         .def_prop_ro("is_context_finished", &GenLlmReq::isContextFinished)
         .def_prop_ro("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
index 5c29f61e2e..657b7c6f36 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -175,6 +175,7 @@ void initBindings(pybind11::module_& m)
         .def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
         .def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
         .def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
+        .def_property_readonly("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
         .def_property_readonly("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
         .def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
         .def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index aef8d34941..39365a8714 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1161,7 +1161,7 @@ class PyExecutor:
                     f'{len(scheduled_batch.generation_requests)} generation requests'
                 )
 
-                can_queue = self._can_queue(scheduled_batch)
+                can_queue, _ = self._can_queue(scheduled_batch)
                 if not can_queue:
                     logger.debug(
                         f"microbatch {microbatch_id} cannot be queued, skipping"
@@ -1359,13 +1359,17 @@ class PyExecutor:
 
     def _can_queue(self, scheduled_batch):
 
+        # can_queue_this_rank covers the case where the batch is non-empty on this rank but empty on some other rank.
+        # For bs == 1 we cannot pad a dummy request to make the batch non-empty, since that would push the batch size to 2:
+        # 1 for the dummy request and 1 for the request that is about to complete but has not been updated yet.
         if self.enable_attention_dp:
             tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
             can_queue = 0 not in tp_batch_sizes
+            can_queue_this_rank = scheduled_batch.batch_size > 0
         else:
-            can_queue = scheduled_batch.batch_size > 0
+            can_queue = can_queue_this_rank = scheduled_batch.batch_size > 0
 
-        return can_queue
+        return can_queue, can_queue_this_rank
 
     def _prepare_and_schedule_batch(self):
         new_requests = self._fetch_and_activate_new_requests()
@@ -1494,7 +1498,7 @@ class PyExecutor:
 
                 finished_requests = []
 
-                can_queue = self._can_queue(scheduled_batch)
+                can_queue, _ = self._can_queue(scheduled_batch)
                 if can_queue:
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
@@ -1509,7 +1513,7 @@ class PyExecutor:
 
                 # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
                 if self.kv_connector_manager:
-                    can_queue = self._can_queue(scheduled_batch)
+                    can_queue, _ = self._can_queue(scheduled_batch)
 
                 if can_queue:
                     # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers.
@@ -1711,7 +1715,8 @@ class PyExecutor:
 
                 self._pause_requests(scheduled_batch.paused_requests)
 
-                can_queue = self._can_queue(scheduled_batch)
+                can_queue, can_queue_this_rank = self._can_queue(
+                    scheduled_batch)
                 if can_queue:
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
@@ -1741,8 +1746,13 @@ class PyExecutor:
 
                 # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
                 if self.kv_connector_manager:
-                    can_queue = self._can_queue(scheduled_batch)
+                    can_queue, can_queue_this_rank = self._can_queue(
+                        scheduled_batch)
+
+                # If the batch is not empty on this rank but empty on some other rank,
+                # delay the update of the previous batch's sample state and let a
+                # later iteration update it.
+                should_process_previous_batch = can_queue or not can_queue_this_rank
 
                 if can_queue:
                     # The generation requests that are do not have batch_idx,
@@ -1792,10 +1802,10 @@ class PyExecutor:
                         scheduled_batch, previous_tensors_device,
                         num_accepted_tokens_device)
 
-                if self.previous_batch is not None:
+                if self.previous_batch is not None and should_process_previous_batch:
                     self._update_requests(self.previous_batch.sample_state)
 
-                if self.drafter is not None and self.use_spec_decode:
+                if self.drafter is not None and self.use_spec_decode and should_process_previous_batch:
                     # Cleanup previous draft resources used in the draft model
                     self.drafter.cleanup_previous_draft_resources()
 
@@ -1822,8 +1832,10 @@ class PyExecutor:
                     ctx_transmission_reqs = self._send_kv_async(
                         scheduled_batch.all_requests())
 
-                if self.previous_batch is not None:
+                if self.previous_batch is not None and should_process_previous_batch:
                     self._process_previous_batch()
+                else:
+                    self._enqueue_responses([])
 
                 if can_queue:
                     if self.enable_iter_perf_stats:
@@ -1835,6 +1847,9 @@ class PyExecutor:
                             iter_start_time=iter_start_time,
                             iter_stats=iter_stats,
                             ctx_transmission_reqs=ctx_transmission_reqs)
+                elif not can_queue_this_rank:
+                    # If the batch is empty on this rank, clear the previous batch.
+                    self.previous_batch = None
 
                 if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
                 ):
@@ -2194,10 +2209,10 @@ class PyExecutor:
         if self.kv_cache_transceiver is None:
             num_active_request = len(self.active_requests)
         else:
-            num_active_request = sum([
-                0 if req.is_disagg_generation_init_state
-                or req.is_disagg_generation_transmission_in_progress else 1
-                for req in self.active_requests
+            num_active_request = len([
+                req for req in self.active_requests
+                if not (req.is_disagg_generation_init_state
+                        or req.is_disagg_generation_transmission_in_progress)
             ])
 
         if self.expected_num_active_requests - num_active_request > 0 and num_active_request == 0:
@@ -2393,7 +2408,7 @@ class PyExecutor:
 
     def _forward_step(
             self,
-            scheduled_requests,
+            scheduled_requests: ScheduledRequests,
            new_tensors_device: Optional[SampleStateTensors] = None,
            num_accepted_tokens_device: Optional[torch.Tensor] = None):
        ExpertStatistic.set_iter(self.iter_counter)
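A minimal sketch of the control flow this patch introduces, not part of the
diff above. The names _can_queue, can_queue, can_queue_this_rank, and
should_process_previous_batch mirror the patch; Batch, compute_can_queue, and
run_iteration are simplified stand-ins for PyExecutor internals, not real
TensorRT-LLM APIs.

    from dataclasses import dataclass
    from typing import List, Optional, Tuple


    @dataclass
    class Batch:
        batch_size: int


    def compute_can_queue(batch: Batch, enable_attention_dp: bool,
                          all_rank_batch_sizes: List[int]) -> Tuple[bool, bool]:
        # Mirrors the patched _can_queue: with attention DP the batch is only
        # queued when every rank has work, but we also track whether this rank
        # itself has work.
        if enable_attention_dp:
            can_queue = 0 not in all_rank_batch_sizes
            can_queue_this_rank = batch.batch_size > 0
        else:
            can_queue = can_queue_this_rank = batch.batch_size > 0
        return can_queue, can_queue_this_rank


    def run_iteration(previous_batch: Optional[Batch], batch: Batch,
                      enable_attention_dp: bool,
                      all_rank_batch_sizes: List[int]) -> Optional[Batch]:
        can_queue, can_queue_this_rank = compute_can_queue(
            batch, enable_attention_dp, all_rank_batch_sizes)
        # Update the previous batch only when no later iteration will update it:
        # either the batch is queued on all ranks, or this rank has nothing pending.
        should_process_previous_batch = can_queue or not can_queue_this_rank
        if previous_batch is not None and should_process_previous_batch:
            print("update previous batch")  # stands in for _update_requests(...)
        if can_queue:
            return batch  # the queued batch becomes the next previous batch
        if not can_queue_this_rank:
            return None  # nothing pending on this rank: clear the previous batch
        return previous_batch  # keep it; a later iteration will update it


    # Rank with one pending request while another rank is empty: the previous
    # batch is kept for a later iteration instead of being updated twice.
    print(run_iteration(Batch(1), Batch(1), True, [1, 0]))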