Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
fix: Merge PP overlap and non-overlap executor loop (#3878)
Signed-off-by: Anurag Mukkara <134339030+amukkara@users.noreply.github.com>
parent: f408de2d99
commit: b0a03a289c
@@ -412,9 +412,10 @@ def create_py_executor_instance(dist,
                                     lora_config.lora_target_modules,
                                     lora_config.trtllm_modules_to_hf_modules)

-    num_micro_batches = 1
-    if mapping.has_pp:
-        num_micro_batches = mapping.pp_size + pytorch_backend_config.enable_overlap_scheduler
+    if mapping.has_pp():
+        num_micro_batches = mapping.pp_size
+    else:
+        num_micro_batches = 2 if pytorch_backend_config.enable_overlap_scheduler else 1

     resources["seq_slot_manager"] = SeqSlotManager(
         executor_config.max_batch_size * num_micro_batches)
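Note (not part of the diff): under the new sizing rule, pipeline parallelism reserves one micro-batch slot per PP stage, while the non-PP path only needs a second slot when the overlap scheduler is enabled. A minimal sketch of that rule, with a hypothetical helper name and plain parameters standing in for the real mapping/config objects:

    # Hypothetical helper mirroring the micro-batch sizing rule in the hunk above.
    def compute_num_micro_batches(has_pp: bool, pp_size: int,
                                  enable_overlap_scheduler: bool) -> int:
        if has_pp:
            # One in-flight micro batch per pipeline stage.
            return pp_size
        # Without PP, overlap scheduling prepares iteration i+1 while
        # iteration i finishes, so it needs exactly one extra slot.
        return 2 if enable_overlap_scheduler else 1

    # Example: 4 PP stages -> 4 slots; single stage with overlap -> 2 slots.
    assert compute_num_micro_batches(True, 4, True) == 4
    assert compute_num_micro_batches(False, 1, True) == 2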
@@ -234,7 +234,7 @@ class PyExecutor:
         self.num_scheduled_requests: int = 0

         # list of requests in each PP micro batch
-        self.num_micro_batches = self.dist.pp_size + enable_overlap_scheduler
+        self.num_micro_batches = self.dist.pp_size
         self.micro_batches: List[BatchStatePP
                                  | None] = [None] * self.num_micro_batches
         self.send_handles = [None] * self.num_micro_batches
@@ -256,7 +256,7 @@ class PyExecutor:

         self.kv_cache_transceiver = kv_cache_transceiver
         if self.dist.pp_size > 1:
-            self.event_loop = self._executor_loop_pp_overlap if enable_overlap_scheduler else self._executor_loop_pp
+            self.event_loop = self._executor_loop_pp
         else:
             self.event_loop = self._executor_loop_overlap if enable_overlap_scheduler else self._executor_loop

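Note (not part of the diff): after this change every PP rank runs the single merged _executor_loop_pp, and only the non-PP branch still switches on enable_overlap_scheduler. A rough sketch of the selection, assuming a PyExecutor-like object that exposes the loop methods named in the hunk above:

    # Simplified event-loop selection; `executor` is assumed to expose the
    # same attributes as PyExecutor in the diff.
    def select_event_loop(executor, enable_overlap_scheduler: bool):
        if executor.dist.pp_size > 1:
            # One merged loop covers both overlap and non-overlap PP.
            return executor._executor_loop_pp
        return (executor._executor_loop_overlap
                if enable_overlap_scheduler else executor._executor_loop)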
@@ -649,110 +649,6 @@ class PyExecutor:
         self.shutdown_event.set()

-    def _executor_loop_pp(self):
-        torch.cuda.set_device(self.device_id)
-        got_finish_signal = False
-        num_dummy_request = 0
-        microbatch_id = 0
-        with self._profiler() as profile_step:
-            iter_start_time = time.time()
-            iter_stats = None
-            while not got_finish_signal or len(self.active_requests) > 0:
-                profile_step()
-                if self.enable_iter_perf_stats:
-                    iter_start_time = time.time()
-                new_requests = self._fetch_new_requests()
-                got_finish_signal = self._merge_requests(
-                    new_requests) or got_finish_signal
-                if got_finish_signal and len(self.active_requests) == 0:
-                    break
-
-                if self.enable_iter_perf_stats:
-                    iter_stats = self._get_init_iter_stats(
-                        len(new_requests),
-                        self.new_active_requests_queue_latency_ms)
-
-                if not got_finish_signal:
-                    num_dummy_request = self._get_num_dummy_request()
-                    if num_dummy_request > 0:
-                        self._merge_dummy_request(num_dummy_request)
-                scheduled_batch, _, _ = self._schedule()
-
-                self.num_scheduled_requests = scheduled_batch.batch_size
-                logger.debug(
-                    f'has {len(self.active_requests)} active_request, '
-                    f'scheduled {len(scheduled_batch.context_requests)} context requests and '
-                    f'{len(scheduled_batch.generation_requests)} generation requests'
-                )
-
-                if self.enable_attention_dp:
-                    tp_batch_sizes = self.dist.tp_allgather(
-                        scheduled_batch.batch_size)
-                    can_queue = 0 not in tp_batch_sizes
-                else:
-                    can_queue = scheduled_batch.batch_size > 0
-                if not can_queue:
-                    assert len(self.inflight_req_ids) > 0, (
-                        "fail to schedule any pending request, probably run out of resource"
-                    )
-
-                if not can_queue:
-                    self.micro_batches[microbatch_id] = None
-                else:
-                    # TODO: add pause_requests together with inflight_req_ids and handle draft_tokens
-                    self._add_inflight_ids(scheduled_batch)
-                    self.resource_manager.prepare_resources(scheduled_batch)
-
-                    # Stage 1: Forward + (decoding) pass ([should be] async)
-                    if self.dist.is_last_pp_rank:
-                        decoder_state = self._forward_step_last_pp(
-                            scheduled_batch, microbatch_id)
-                    else:
-                        decoder_state = self._forward_step_inter_pp(
-                            scheduled_batch)
-
-                    if self.enable_iter_perf_stats:
-                        iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
-                            'num_ctx_tokens']
-
-                    batch_state = BatchStatePP(
-                        decoder_state=decoder_state,
-                        iter_start_time=iter_start_time,
-                        iter_stats=iter_stats,
-                        microbatch_id=microbatch_id,
-                    )
-
-                    if num_dummy_request > 0:
-                        self._finish_dummy_request(scheduled_batch)
-                    self.micro_batches[microbatch_id] = batch_state
-
-                # Stage 2: Handle previous batch that only processed forward_step
-                # marching forward in the microbatch slots
-                prev_microbatch_id = (microbatch_id +
-                                      1) % self.num_micro_batches
-                previous_batch = self.micro_batches[prev_microbatch_id]
-                finished_requests = []
-                if previous_batch is not None:
-                    if not self.dist.is_last_pp_rank:
-                        self._handle_previous_batch_inter_pp(previous_batch)
-
-                    self._update_requests(previous_batch.decoder_state)
-                    self._handle_cancelled_requests()
-                    finished_requests = self._handle_responses()
-                    previous_scheduled_batch = previous_batch.decoder_state.scheduled_requests
-                    self.resource_manager.update_resources(
-                        previous_scheduled_batch)
-                    self._remove_inflight_ids(previous_scheduled_batch)
-
-                microbatch_id = prev_microbatch_id
-                self._gather_dp_requests_num()
-
-                if self.enable_iter_perf_stats and previous_batch is not None:
-                    self._process_iter_stats(finished_requests,
-                                             self.active_requests,
-                                             previous_batch)
-        self._executor_loop_cleanup()
-
     def _executor_loop_pp_overlap(self):
         torch.cuda.set_device(self.device_id)
         got_finish_signal = False
         num_dummy_request = 0
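Note (not part of the diff): both the removed loop and its surviving overlap counterpart treat self.micro_batches as a ring of slots: stage 1 parks the just-forwarded batch in slot microbatch_id, stage 2 finalizes whatever sits in the following slot, and the index then advances to that slot. A toy, self-contained illustration of the rotation, with a print standing in for the real forward/finalize work:

    from typing import List, Optional

    def run_ring(num_micro_batches: int, num_iterations: int) -> None:
        # Ring of in-flight micro-batch slots, as in self.micro_batches.
        slots: List[Optional[str]] = [None] * num_micro_batches
        microbatch_id = 0
        for step in range(num_iterations):
            # Stage 1: launch forward for this iteration and park its state.
            slots[microbatch_id] = f"batch-{step}"
            # Stage 2: the oldest in-flight batch lives in the next slot.
            prev_microbatch_id = (microbatch_id + 1) % num_micro_batches
            previous_batch = slots[prev_microbatch_id]
            if previous_batch is not None:
                print(f"finalizing {previous_batch}")
            microbatch_id = prev_microbatch_id

    # With 4 slots, batch-0 is finalized 3 iterations after it was launched.
    run_ring(num_micro_batches=4, num_iterations=6)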
@@ -1189,46 +1085,6 @@ class PyExecutor:
             scheduled_requests=scheduled_batch,
             new_tensors_host={"new_tokens_host": new_tokens_host})

-    @nvtx_range("_forward_step_last_pp")
-    def _forward_step_last_pp(self, scheduled_batch,
-                              microbatch_id) -> DecoderState:
-        batch_outputs = self._forward_step(scheduled_batch)
-        decoder_state = self._decode_async(scheduled_batch, batch_outputs)
-        self._update_request_states(scheduled_batch)
-
-        if self.send_handles[microbatch_id] is not None:
-            self.send_handles[microbatch_id].Wait()
-        decoder_state.decoder_event.synchronize()
-
-        self.send_handles[microbatch_id] = self.dist.isend_tensor_list(
-            decoder_state.new_tensors_host.values(),
-            dest=self.dist.next_pp_rank,
-            tag=microbatch_id)
-
-        return decoder_state
-
-    @nvtx_range("_handle_previous_batch_inter_pp")
-    def _handle_previous_batch_inter_pp(
-            self, previous_batch_state: BatchStatePP) -> None:
-        new_tokens_host = previous_batch_state.decoder_state.new_tensors_host
-        prev_microbatch_id = previous_batch_state.microbatch_id
-        # Receive tokens from prev pp rank w.r.t model forward direction
-        self.dist.recv_tensor_list(
-            new_tokens_host.values(),
-            src=self.dist.prev_pp_rank,
-            tag=prev_microbatch_id  # not necessary and may discard
-        )
-
-        # Send tokens to next pp rank w.r.t model forward direction
-        # Second last rank not need since last rank has original decoded tokens
-        if not self.dist.is_second_last_pp_rank:
-            if self.send_handles[prev_microbatch_id] is not None:
-                self.send_handles[prev_microbatch_id].Wait()
-            self.send_handles[prev_microbatch_id] = self.dist.isend_tensor_list(
-                new_tokens_host.values(),
-                dest=self.dist.next_pp_rank,
-                tag=prev_microbatch_id)
-
     def _update_new_active_requests_queue_latency(self, new_requests):
         if self.enable_iter_perf_stats and self.dist.rank == 0:
             now = time.time()
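Note (not part of the diff): the removed helpers relay decoded tokens between PP ranks with per-micro-batch tags, and they always wait on the previous asynchronous send for a slot before reusing that slot's host buffer. A hedged sketch of the same discipline using plain mpi4py; the real code goes through TensorRT-LLM's self.dist.isend_tensor_list / recv_tensor_list wrappers, not mpi4py directly:

    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    num_slots = 2                                  # e.g. pp_size micro-batch slots
    send_handles = [None] * num_slots
    new_tokens_host = np.zeros(8, dtype=np.int64)  # stand-in host buffer

    def forward_tokens(slot: int, dest: int) -> None:
        # Drain the previous async send for this slot before touching its
        # buffer, mirroring send_handles[microbatch_id].Wait() above.
        if send_handles[slot] is not None:
            send_handles[slot].Wait()
        send_handles[slot] = comm.Isend(new_tokens_host, dest=dest, tag=slot)

    def receive_tokens(slot: int, src: int) -> None:
        # Blocking receive into the same host buffer, tagged by slot id.
        comm.Recv(new_tokens_host, source=src, tag=slot)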