[None][fix] Avoid Double update for previous batch (#9888)

Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
This commit is contained in:
Yi Zhang 2026-01-23 02:15:06 +08:00 committed by GitHub
parent 944c304bbb
commit d43be7b65e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 32 additions and 15 deletions

View File

@ -170,6 +170,7 @@ void initBindings(nb::module_& m)
.def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams)
.def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
.def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
.def_prop_ro("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
.def_prop_ro("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
.def_prop_ro("is_context_finished", &GenLlmReq::isContextFinished)
.def_prop_ro("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)

View File

@ -175,6 +175,7 @@ void initBindings(pybind11::module_& m)
.def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
.def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
.def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
.def_property_readonly("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
.def_property_readonly("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
.def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
.def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)

View File

@ -1161,7 +1161,7 @@ class PyExecutor:
f'{len(scheduled_batch.generation_requests)} generation requests'
)
can_queue = self._can_queue(scheduled_batch)
can_queue, _ = self._can_queue(scheduled_batch)
if not can_queue:
logger.debug(
f"microbatch {microbatch_id} cannot be queued, skipping"
@ -1359,13 +1359,17 @@ class PyExecutor:
def _can_queue(self, scheduled_batch):
# can_queue_this_rank is for case that the batch is not empty on this rank, but empty on other ranks
# For bs == 1, we cannot pad dummy request to make the batch non-empty since it will cause the batch size to be 2.
# 1 for dummy request, 1 for the to complete but haven't updated request.
if self.enable_attention_dp:
tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
can_queue = 0 not in tp_batch_sizes
can_queue_this_rank = scheduled_batch.batch_size > 0
else:
can_queue = scheduled_batch.batch_size > 0
can_queue = can_queue_this_rank = scheduled_batch.batch_size > 0
return can_queue
return can_queue, can_queue_this_rank
def _prepare_and_schedule_batch(self):
new_requests = self._fetch_and_activate_new_requests()
@ -1494,7 +1498,7 @@ class PyExecutor:
finished_requests = []
can_queue = self._can_queue(scheduled_batch)
can_queue, _ = self._can_queue(scheduled_batch)
if can_queue:
if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
@ -1509,7 +1513,7 @@ class PyExecutor:
# if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
if self.kv_connector_manager:
can_queue = self._can_queue(scheduled_batch)
can_queue, _ = self._can_queue(scheduled_batch)
if can_queue:
# init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers.
@ -1711,7 +1715,8 @@ class PyExecutor:
self._pause_requests(scheduled_batch.paused_requests)
can_queue = self._can_queue(scheduled_batch)
can_queue, can_queue_this_rank = self._can_queue(
scheduled_batch)
if can_queue:
if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
@ -1741,8 +1746,13 @@ class PyExecutor:
# if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
if self.kv_connector_manager:
can_queue = self._can_queue(scheduled_batch)
can_queue, can_queue_this_rank = self._can_queue(
scheduled_batch)
# If the batch is not empty on this rank, but empty on other ranks,
# we need to delay the update of the previous batch's sample state,
# and let the later iteration to update it.
should_process_previous_batch = can_queue or not can_queue_this_rank
if can_queue:
# The generation requests that are do not have batch_idx,
@ -1792,10 +1802,10 @@ class PyExecutor:
scheduled_batch, previous_tensors_device,
num_accepted_tokens_device)
if self.previous_batch is not None:
if self.previous_batch is not None and should_process_previous_batch:
self._update_requests(self.previous_batch.sample_state)
if self.drafter is not None and self.use_spec_decode:
if self.drafter is not None and self.use_spec_decode and should_process_previous_batch:
# Cleanup previous draft resources used in the draft model
self.drafter.cleanup_previous_draft_resources()
@ -1822,8 +1832,10 @@ class PyExecutor:
ctx_transmission_reqs = self._send_kv_async(
scheduled_batch.all_requests())
if self.previous_batch is not None:
if self.previous_batch is not None and should_process_previous_batch:
self._process_previous_batch()
else:
self._enqueue_responses([])
if can_queue:
if self.enable_iter_perf_stats:
@ -1835,6 +1847,9 @@ class PyExecutor:
iter_start_time=iter_start_time,
iter_stats=iter_stats,
ctx_transmission_reqs=ctx_transmission_reqs)
elif not can_queue_this_rank:
# If the batch is empty on this rank, we need to clear the previous batch.
self.previous_batch = None
if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
):
@ -2194,10 +2209,10 @@ class PyExecutor:
if self.kv_cache_transceiver is None:
num_active_request = len(self.active_requests)
else:
num_active_request = sum([
0 if req.is_disagg_generation_init_state
or req.is_disagg_generation_transmission_in_progress else 1
for req in self.active_requests
num_active_request = len([
req for req in self.active_requests
if not (req.is_disagg_generation_init_state
or req.is_disagg_generation_transmission_in_progress)
])
if self.expected_num_active_requests - num_active_request > 0 and num_active_request == 0:
@ -2393,7 +2408,7 @@ class PyExecutor:
def _forward_step(
self,
scheduled_requests,
scheduled_requests: ScheduledRequests,
new_tensors_device: Optional[SampleStateTensors] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None):
ExpertStatistic.set_iter(self.iter_counter)