Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-05 02:31:33 +08:00)
[None][fix] Avoid Double update for previous batch (#9888)
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
commit d43be7b65e (parent 944c304bbb)
@@ -170,6 +170,7 @@ void initBindings(nb::module_& m)
         .def_prop_ro("context_phase_params", &GenLlmReq::getContextPhaseParams)
         .def_prop_ro("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
         .def_prop_ro("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
+        .def_prop_ro("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
         .def_prop_ro("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
         .def_prop_ro("is_context_finished", &GenLlmReq::isContextFinished)
         .def_prop_ro("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
@@ -175,6 +175,7 @@ void initBindings(pybind11::module_& m)
         .def_property_readonly("context_phase_params", &GenLlmReq::getContextPhaseParams)
         .def_property_readonly("is_context_only_request", &GenLlmReq::isContextOnlyRequest)
         .def_property_readonly("is_generation_only_request", &GenLlmReq::isGenerationOnlyRequest)
+        .def_property_readonly("is_generation_to_complete_state", &GenLlmReq::isGenerationToCompleteState)
         .def_property_readonly("is_generation_complete_state", &GenLlmReq::isGenerationCompleteState)
         .def_property_readonly("is_context_finished", &GenLlmReq::isContextFinished)
         .def_property_readonly("is_disagg_generation_init_state", &GenLlmReq::isDisaggGenerationInitState)
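
Both binding hunks expose GenLlmReq::isGenerationToCompleteState to Python as a read-only property (nanobind and pybind11 variants of the same change). A minimal usage sketch in plain Python, not part of the commit; the helper name and the iterable of request objects are hypothetical:

def count_to_complete(active_requests):
    # Requests whose generation is about to complete but whose state has not
    # been consumed yet; relies only on the newly exposed read-only property.
    return len([
        req for req in active_requests if req.is_generation_to_complete_state
    ])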
@@ -1161,7 +1161,7 @@ class PyExecutor:
                         f'{len(scheduled_batch.generation_requests)} generation requests'
                     )

-                can_queue = self._can_queue(scheduled_batch)
+                can_queue, _ = self._can_queue(scheduled_batch)
                 if not can_queue:
                     logger.debug(
                         f"microbatch {microbatch_id} cannot be queued, skipping"
@@ -1359,13 +1359,17 @@ class PyExecutor:

     def _can_queue(self, scheduled_batch):

+        # can_queue_this_rank is for case that the batch is not empty on this rank, but empty on other ranks
+        # For bs == 1, we cannot pad dummy request to make the batch non-empty since it will cause the batch size to be 2.
+        # 1 for dummy request, 1 for the to complete but haven't updated request.
         if self.enable_attention_dp:
             tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
             can_queue = 0 not in tp_batch_sizes
+            can_queue_this_rank = scheduled_batch.batch_size > 0
         else:
-            can_queue = scheduled_batch.batch_size > 0
+            can_queue = can_queue_this_rank = scheduled_batch.batch_size > 0

-        return can_queue
+        return can_queue, can_queue_this_rank

     def _prepare_and_schedule_batch(self):
         new_requests = self._fetch_and_activate_new_requests()
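
For reference, a standalone sketch of the two-flag contract _can_queue now returns; this is plain Python, not the repository code, and the explicit list of per-rank batch sizes stands in for dist.tp_allgather:

def can_queue_flags(local_batch_size, all_rank_batch_sizes):
    # The whole attention-DP group may only launch if every rank scheduled work.
    can_queue = 0 not in all_rank_batch_sizes
    # Local view of this rank; used later to decide whether the previous-batch
    # update must be deferred rather than applied now.
    can_queue_this_rank = local_batch_size > 0
    return can_queue, can_queue_this_rank

# Example: rank 0 scheduled one request, rank 1 scheduled none.
assert can_queue_flags(1, [1, 0]) == (False, True)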
@@ -1494,7 +1498,7 @@ class PyExecutor:

                 finished_requests = []

-                can_queue = self._can_queue(scheduled_batch)
+                can_queue, _ = self._can_queue(scheduled_batch)
                 if can_queue:
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
@@ -1509,7 +1513,7 @@ class PyExecutor:

                 # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
                 if self.kv_connector_manager:
-                    can_queue = self._can_queue(scheduled_batch)
+                    can_queue, _ = self._can_queue(scheduled_batch)

                 if can_queue:
                     # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers.
@@ -1711,7 +1715,8 @@ class PyExecutor:

                 self._pause_requests(scheduled_batch.paused_requests)

-                can_queue = self._can_queue(scheduled_batch)
+                can_queue, can_queue_this_rank = self._can_queue(
+                    scheduled_batch)
                 if can_queue:
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
@@ -1741,8 +1746,13 @@ class PyExecutor:

                 # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
                 if self.kv_connector_manager:
-                    can_queue = self._can_queue(scheduled_batch)
+                    can_queue, can_queue_this_rank = self._can_queue(
+                        scheduled_batch)

+                # If the batch is not empty on this rank, but empty on other ranks,
+                # we need to delay the update of the previous batch's sample state,
+                # and let the later iteration to update it.
+                should_process_previous_batch = can_queue or not can_queue_this_rank
                 if can_queue:

                     # The generation requests that are do not have batch_idx,
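
The new should_process_previous_batch flag is the core of the fix: the previous batch's sample state must be updated exactly once, so a rank that still has work while its peers are empty defers the update to a later iteration instead of applying it here. A small sketch of the three reachable cases (illustrative only; with the allgather above, can_queue == True implies can_queue_this_rank == True):

def should_process(can_queue, can_queue_this_rank):
    return can_queue or not can_queue_this_rank

assert should_process(True, True) is True    # normal iteration: update now
assert should_process(False, False) is True  # this rank is empty too: flush now
assert should_process(False, True) is False  # peers empty, this rank not: defer,
                                             # avoiding a double update later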
@@ -1792,10 +1802,10 @@ class PyExecutor:
                         scheduled_batch, previous_tensors_device,
                         num_accepted_tokens_device)

-                    if self.previous_batch is not None:
+                    if self.previous_batch is not None and should_process_previous_batch:
                         self._update_requests(self.previous_batch.sample_state)

-                    if self.drafter is not None and self.use_spec_decode:
+                    if self.drafter is not None and self.use_spec_decode and should_process_previous_batch:
                         # Cleanup previous draft resources used in the draft model
                         self.drafter.cleanup_previous_draft_resources()

@@ -1822,8 +1832,10 @@ class PyExecutor:
                         ctx_transmission_reqs = self._send_kv_async(
                             scheduled_batch.all_requests())

-                if self.previous_batch is not None:
+                if self.previous_batch is not None and should_process_previous_batch:
                     self._process_previous_batch()
+                else:
+                    self._enqueue_responses([])

                 if can_queue:
                     if self.enable_iter_perf_stats:
@@ -1835,6 +1847,9 @@ class PyExecutor:
                             iter_start_time=iter_start_time,
                             iter_stats=iter_stats,
                             ctx_transmission_reqs=ctx_transmission_reqs)
+                elif not can_queue_this_rank:
+                    # If the batch is empty on this rank, we need to clear the previous batch.
+                    self.previous_batch = None

                 if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
                 ):
@@ -2194,10 +2209,10 @@ class PyExecutor:
         if self.kv_cache_transceiver is None:
             num_active_request = len(self.active_requests)
         else:
-            num_active_request = sum([
-                0 if req.is_disagg_generation_init_state
-                or req.is_disagg_generation_transmission_in_progress else 1
-                for req in self.active_requests
+            num_active_request = len([
+                req for req in self.active_requests
+                if not (req.is_disagg_generation_init_state
+                        or req.is_disagg_generation_transmission_in_progress)
             ])

         if self.expected_num_active_requests - num_active_request > 0 and num_active_request == 0:
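
The sum-to-len rewrite above is behavior-preserving: counting 1 for every request that is neither in the disagg-generation init state nor mid-transmission equals the length of the filtered list. A quick stand-in check, with dict literals in place of real request objects:

def skip(req):
    # Stand-in for: req.is_disagg_generation_init_state or
    #               req.is_disagg_generation_transmission_in_progress
    return req["skip"]

reqs = [{"skip": True}, {"skip": False}, {"skip": False}]
old_style = sum([0 if skip(r) else 1 for r in reqs])
new_style = len([r for r in reqs if not skip(r)])
assert old_style == new_style == 2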
@@ -2393,7 +2408,7 @@ class PyExecutor:

     def _forward_step(
             self,
-            scheduled_requests,
+            scheduled_requests: ScheduledRequests,
             new_tensors_device: Optional[SampleStateTensors] = None,
             num_accepted_tokens_device: Optional[torch.Tensor] = None):
         ExpertStatistic.set_iter(self.iter_counter)