mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[TRTLLM-10560][fix] Fix the time of pause() for overlap scheduler (#10943)
Signed-off-by: Tailing Yuan <yuantailing@gmail.com>
This commit is contained in:
parent
4a206351bb
commit
5553391c5e
@ -244,7 +244,7 @@ class PyExecutor:
|
||||
max_num_sequences: int,
|
||||
drafter: Optional[Drafter] = None,
|
||||
disable_overlap_scheduler: bool = False,
|
||||
max_input_len: int = 2048,
|
||||
max_input_len: int = 0x7fffffff,
|
||||
max_batch_size: int = 8,
|
||||
max_beam_width: int = 1,
|
||||
max_draft_len: int = 0,
|
||||
@ -1503,6 +1503,7 @@ class PyExecutor:
|
||||
if scheduled_batch is None:
|
||||
break
|
||||
|
||||
self._terminate_requests(scheduled_batch.paused_requests)
|
||||
self._pause_requests(scheduled_batch.paused_requests)
|
||||
|
||||
finished_requests = []
|
||||
@ -1722,7 +1723,7 @@ class PyExecutor:
|
||||
else:
|
||||
can_forward = True
|
||||
|
||||
self._pause_requests(scheduled_batch.paused_requests)
|
||||
self._terminate_requests(scheduled_batch.paused_requests)
|
||||
|
||||
can_queue, can_queue_this_rank = self._can_queue(
|
||||
scheduled_batch)
|
||||
@ -1819,6 +1820,8 @@ class PyExecutor:
|
||||
# Cleanup previous draft resources used in the draft model
|
||||
self.drafter.cleanup_previous_draft_resources()
|
||||
|
||||
self._pause_requests(scheduled_batch.paused_requests)
|
||||
|
||||
if can_queue:
|
||||
guided_decoder_failed_requests = None
|
||||
if self.guided_decoder is not None:
|
||||
@ -2871,14 +2874,16 @@ class PyExecutor:
|
||||
self.responses.pop(id)
|
||||
return response
|
||||
|
||||
def _pause_requests(self, requests_to_pause):
|
||||
def _terminate_requests(self, requests_to_pause):
|
||||
# todo: support work with self.inflight_req_ids.
|
||||
# Currently, self.inflight_req_ids is not.
|
||||
max_input_len = self.max_input_len
|
||||
for req in requests_to_pause:
|
||||
req.pause(max_input_len)
|
||||
self._terminate_request(req)
|
||||
|
||||
def _pause_requests(self, requests_to_pause):
|
||||
for req in requests_to_pause:
|
||||
req.pause(self.max_input_len)
|
||||
|
||||
def _add_inflight_ids(self, scheduled_requests):
|
||||
"""Add request IDs of current requests to self.inflight_req_ids.
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user