fix: Merge PP overlap and non-overlap executor loop (#3878)

Signed-off-by: Anurag Mukkara <134339030+amukkara@users.noreply.github.com>
Anurag Mukkara 2025-05-13 15:04:36 -07:00 committed by GitHub
parent f408de2d99
commit b0a03a289c
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 149 deletions


@@ -412,9 +412,10 @@ def create_py_executor_instance(dist,
             lora_config.lora_target_modules,
             lora_config.trtllm_modules_to_hf_modules)
-    num_micro_batches = 1
-    if mapping.has_pp:
-        num_micro_batches = mapping.pp_size + pytorch_backend_config.enable_overlap_scheduler
+    if mapping.has_pp():
+        num_micro_batches = mapping.pp_size
+    else:
+        num_micro_batches = 2 if pytorch_backend_config.enable_overlap_scheduler else 1
     resources["seq_slot_manager"] = SeqSlotManager(
         executor_config.max_batch_size * num_micro_batches)
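
For reference, the sizing rule this hunk introduces can be read as a small standalone function: with pipeline parallelism each PP stage keeps its own micro batch in flight, and without PP the overlap scheduler keeps two batches in flight. The sketch below is illustrative only; `compute_num_micro_batches` and the sample numbers are hypothetical, while `pp_size`, `enable_overlap_scheduler`, and the `max_batch_size * num_micro_batches` sizing mirror the diff.

```python
# Hypothetical standalone restatement of the micro-batch sizing rule above.
# `compute_num_micro_batches` and the sample numbers are made up for
# illustration; `pp_size` and `enable_overlap_scheduler` mirror the fields
# referenced in the diff.
def compute_num_micro_batches(pp_size: int, enable_overlap_scheduler: bool) -> int:
    if pp_size > 1:
        # Pipeline parallelism: keep one micro batch in flight per PP stage.
        return pp_size
    # Single stage: the overlap scheduler keeps two batches in flight
    # (current forward plus the previous batch being finalized), else one.
    return 2 if enable_overlap_scheduler else 1


# The sequence-slot pool is then sized for the worst-case number of
# concurrently in-flight requests.
max_batch_size = 32
num_micro_batches = compute_num_micro_batches(pp_size=4, enable_overlap_scheduler=True)
print(max_batch_size * num_micro_batches)  # 128 sequence slots
```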


@@ -234,7 +234,7 @@ class PyExecutor:
         self.num_scheduled_requests: int = 0
 
         # list of requests in each PP micro batch
-        self.num_micro_batches = self.dist.pp_size + enable_overlap_scheduler
+        self.num_micro_batches = self.dist.pp_size
         self.micro_batches: List[BatchStatePP
                                  | None] = [None] * self.num_micro_batches
         self.send_handles = [None] * self.num_micro_batches
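
The `micro_batches` and `send_handles` lists above act as a fixed-size ring with one slot per micro batch. A minimal sketch of that ring, with `MicroBatchRing` as a hypothetical stand-in for the executor's inline bookkeeping:

```python
# Minimal sketch (not the project's API) of the fixed-size ring implied by
# `micro_batches` and `send_handles`: one slot per micro batch for in-flight
# state plus one for the matching async send handle, advanced modulo the
# ring size every iteration.
from typing import List, Optional


class MicroBatchRing:

    def __init__(self, num_micro_batches: int):
        self.slots: List[Optional[object]] = [None] * num_micro_batches
        self.send_handles: List[Optional[object]] = [None] * num_micro_batches
        self.current = 0

    def advance(self) -> int:
        # Mirrors (microbatch_id + 1) % self.num_micro_batches in the executor.
        self.current = (self.current + 1) % len(self.slots)
        return self.current


ring = MicroBatchRing(num_micro_batches=4)
assert [ring.advance() for _ in range(5)] == [1, 2, 3, 0, 1]
```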
@@ -256,7 +256,7 @@ class PyExecutor:
         self.kv_cache_transceiver = kv_cache_transceiver
 
         if self.dist.pp_size > 1:
-            self.event_loop = self._executor_loop_pp_overlap if enable_overlap_scheduler else self._executor_loop_pp
+            self.event_loop = self._executor_loop_pp
         else:
             self.event_loop = self._executor_loop_overlap if enable_overlap_scheduler else self._executor_loop
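
After this change the PP branch no longer switches on `enable_overlap_scheduler`; only the non-PP branch does. A hedged sketch of the resulting dispatch, using placeholder loop functions rather than the real `PyExecutor` methods:

```python
# Hedged sketch of the dispatch after this change: the PP branch always uses
# the single merged loop, while the non-PP branch still switches on the
# overlap scheduler flag. The loop functions are placeholders, not the real
# PyExecutor methods.
def pick_event_loop(pp_size: int, enable_overlap_scheduler: bool):

    def executor_loop_pp():
        return "merged PP loop"

    def executor_loop_overlap():
        return "non-PP overlap loop"

    def executor_loop():
        return "non-PP plain loop"

    if pp_size > 1:
        return executor_loop_pp
    return executor_loop_overlap if enable_overlap_scheduler else executor_loop


print(pick_event_loop(4, True)())    # merged PP loop
print(pick_event_loop(1, False)())   # non-PP plain loop
```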
@@ -649,110 +649,6 @@ class PyExecutor:
         self.shutdown_event.set()
 
     def _executor_loop_pp(self):
-        torch.cuda.set_device(self.device_id)
-        got_finish_signal = False
-        num_dummy_request = 0
-        microbatch_id = 0
-        with self._profiler() as profile_step:
-            iter_start_time = time.time()
-            iter_stats = None
-            while not got_finish_signal or len(self.active_requests) > 0:
-                profile_step()
-                if self.enable_iter_perf_stats:
-                    iter_start_time = time.time()
-                new_requests = self._fetch_new_requests()
-                got_finish_signal = self._merge_requests(
-                    new_requests) or got_finish_signal
-                if got_finish_signal and len(self.active_requests) == 0:
-                    break
-                if self.enable_iter_perf_stats:
-                    iter_stats = self._get_init_iter_stats(
-                        len(new_requests),
-                        self.new_active_requests_queue_latency_ms)
-                if not got_finish_signal:
-                    num_dummy_request = self._get_num_dummy_request()
-                    if num_dummy_request > 0:
-                        self._merge_dummy_request(num_dummy_request)
-                scheduled_batch, _, _ = self._schedule()
-                self.num_scheduled_requests = scheduled_batch.batch_size
-                logger.debug(
-                    f'has {len(self.active_requests)} active_request, '
-                    f'scheduled {len(scheduled_batch.context_requests)} context requests and '
-                    f'{len(scheduled_batch.generation_requests)} generation requests'
-                )
-                if self.enable_attention_dp:
-                    tp_batch_sizes = self.dist.tp_allgather(
-                        scheduled_batch.batch_size)
-                    can_queue = 0 not in tp_batch_sizes
-                else:
-                    can_queue = scheduled_batch.batch_size > 0
-                if not can_queue:
-                    assert len(self.inflight_req_ids) > 0, (
-                        "fail to schedule any pending request, probably run out of resource"
-                    )
-                if not can_queue:
-                    self.micro_batches[microbatch_id] = None
-                else:
-                    # TODO: add pause_requests together with inflight_req_ids and handle draft_tokens
-                    self._add_inflight_ids(scheduled_batch)
-                    self.resource_manager.prepare_resources(scheduled_batch)
-                    # Stage 1: Forward + (decoding) pass ([should be] async)
-                    if self.dist.is_last_pp_rank:
-                        decoder_state = self._forward_step_last_pp(
-                            scheduled_batch, microbatch_id)
-                    else:
-                        decoder_state = self._forward_step_inter_pp(
-                            scheduled_batch)
-                    if self.enable_iter_perf_stats:
-                        iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
-                            'num_ctx_tokens']
-                    batch_state = BatchStatePP(
-                        decoder_state=decoder_state,
-                        iter_start_time=iter_start_time,
-                        iter_stats=iter_stats,
-                        microbatch_id=microbatch_id,
-                    )
-                    if num_dummy_request > 0:
-                        self._finish_dummy_request(scheduled_batch)
-                    self.micro_batches[microbatch_id] = batch_state
-                # Stage 2: Handle previous batch that only processed forward_step
-                # marching forward in the microbatch slots
-                prev_microbatch_id = (microbatch_id +
-                                      1) % self.num_micro_batches
-                previous_batch = self.micro_batches[prev_microbatch_id]
-                finished_requests = []
-                if previous_batch is not None:
-                    if not self.dist.is_last_pp_rank:
-                        self._handle_previous_batch_inter_pp(previous_batch)
-                    self._update_requests(previous_batch.decoder_state)
-                    self._handle_cancelled_requests()
-                    finished_requests = self._handle_responses()
-                    previous_scheduled_batch = previous_batch.decoder_state.scheduled_requests
-                    self.resource_manager.update_resources(
-                        previous_scheduled_batch)
-                    self._remove_inflight_ids(previous_scheduled_batch)
-                microbatch_id = prev_microbatch_id
-                self._gather_dp_requests_num()
-                if self.enable_iter_perf_stats and previous_batch is not None:
-                    self._process_iter_stats(finished_requests,
-                                             self.active_requests,
-                                             previous_batch)
-        self._executor_loop_cleanup()
-
-    def _executor_loop_pp_overlap(self):
         torch.cuda.set_device(self.device_id)
         got_finish_signal = False
         num_dummy_request = 0
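
The deleted loop above follows the same two-stage shape as the merged loop that replaces it: each iteration launches the current micro batch's forward step, then finalizes the micro batch submitted earlier, walking the slots modulo `num_micro_batches`. A simplified, self-contained sketch of that pattern (all names and types below are stand-ins, not the project's API):

```python
# Simplified sketch of the two-stage pipelined pattern visible in the removed
# loop and retained by the merged one: stage 1 "forwards" the current micro
# batch, stage 2 finalizes the oldest in-flight micro batch one slot ahead in
# the ring. BatchState and run_pipelined_loop are illustrative stand-ins.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BatchState:
    microbatch_id: int
    payload: str


def run_pipelined_loop(work_items: List[str], num_micro_batches: int) -> List[str]:
    micro_batches: List[Optional[BatchState]] = [None] * num_micro_batches
    finished: List[str] = []
    microbatch_id = 0
    # Feed real work, then enough empty iterations to drain in-flight batches.
    for item in work_items + [None] * num_micro_batches:
        # Stage 1: launch work for the current micro batch (placeholder for
        # the asynchronous forward/decoding step).
        micro_batches[microbatch_id] = (
            BatchState(microbatch_id, f"forwarded({item})") if item is not None else None
        )
        # Stage 2: finalize the micro batch submitted num_micro_batches
        # iterations ago, i.e. the next slot in the ring.
        prev_id = (microbatch_id + 1) % num_micro_batches
        previous = micro_batches[prev_id]
        if previous is not None:
            finished.append(previous.payload)
            micro_batches[prev_id] = None
        microbatch_id = prev_id
    return finished


print(run_pipelined_loop(["a", "b", "c"], num_micro_batches=2))
# ['forwarded(a)', 'forwarded(b)', 'forwarded(c)']
```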
@@ -1189,46 +1085,6 @@ class PyExecutor:
             scheduled_requests=scheduled_batch,
             new_tensors_host={"new_tokens_host": new_tokens_host})
 
-    @nvtx_range("_forward_step_last_pp")
-    def _forward_step_last_pp(self, scheduled_batch,
-                              microbatch_id) -> DecoderState:
-        batch_outputs = self._forward_step(scheduled_batch)
-        decoder_state = self._decode_async(scheduled_batch, batch_outputs)
-        self._update_request_states(scheduled_batch)
-        if self.send_handles[microbatch_id] is not None:
-            self.send_handles[microbatch_id].Wait()
-        decoder_state.decoder_event.synchronize()
-        self.send_handles[microbatch_id] = self.dist.isend_tensor_list(
-            decoder_state.new_tensors_host.values(),
-            dest=self.dist.next_pp_rank,
-            tag=microbatch_id)
-        return decoder_state
-
-    @nvtx_range("_handle_previous_batch_inter_pp")
-    def _handle_previous_batch_inter_pp(
-            self, previous_batch_state: BatchStatePP) -> None:
-        new_tokens_host = previous_batch_state.decoder_state.new_tensors_host
-        prev_microbatch_id = previous_batch_state.microbatch_id
-        # Receive tokens from prev pp rank w.r.t model forward direction
-        self.dist.recv_tensor_list(
-            new_tokens_host.values(),
-            src=self.dist.prev_pp_rank,
-            tag=prev_microbatch_id  # not necessary and may discard
-        )
-        # Send tokens to next pp rank w.r.t model forward direction
-        # Second last rank not need since last rank has original decoded tokens
-        if not self.dist.is_second_last_pp_rank:
-            if self.send_handles[prev_microbatch_id] is not None:
-                self.send_handles[prev_microbatch_id].Wait()
-            self.send_handles[prev_microbatch_id] = self.dist.isend_tensor_list(
-                new_tokens_host.values(),
-                dest=self.dist.next_pp_rank,
-                tag=prev_microbatch_id)
-
     def _update_new_active_requests_queue_latency(self, new_requests):
         if self.enable_iter_perf_stats and self.dist.rank == 0:
             now = time.time()
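
The removed helpers above illustrate the send-handle discipline that the surviving overlap loop also relies on: before a micro batch's communication slot is reused, the previous asynchronous send for that slot is waited on so its host buffer can safely be overwritten. A rough sketch of that pattern using futures in place of the `dist.isend_tensor_list` handles (nothing here is the project's real API):

```python
# Illustrative sketch of the send-handle discipline: wait on the previous
# asynchronous send for a micro-batch slot before issuing a new one on that
# slot. Futures stand in for distributed isend handles; fake_isend and
# relay_tokens are hypothetical.
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Optional
import time


def fake_isend(payload: str, pool: ThreadPoolExecutor) -> Future:
    # Pretend to ship `payload` to the next pipeline rank asynchronously.
    return pool.submit(lambda: time.sleep(0.01) or payload)


def relay_tokens(batches: List[str], num_micro_batches: int) -> None:
    send_handles: List[Optional[Future]] = [None] * num_micro_batches
    with ThreadPoolExecutor(max_workers=num_micro_batches) as pool:
        for i, payload in enumerate(batches):
            slot = i % num_micro_batches
            # Wait for the previous send on this slot before reusing it,
            # mirroring `send_handles[microbatch_id].Wait()` in the diff.
            if send_handles[slot] is not None:
                send_handles[slot].result()
            send_handles[slot] = fake_isend(payload, pool)
        # Drain outstanding sends before shutting down.
        for handle in send_handles:
            if handle is not None:
                handle.result()


relay_tokens([f"tokens-{i}" for i in range(6)], num_micro_batches=2)
```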