diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 26d3f27080..b322044434 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -412,9 +412,10 @@ def create_py_executor_instance(dist,
             lora_config.lora_target_modules,
             lora_config.trtllm_modules_to_hf_modules)

-    num_micro_batches = 1
-    if mapping.has_pp:
-        num_micro_batches = mapping.pp_size + pytorch_backend_config.enable_overlap_scheduler
+    if mapping.has_pp():
+        num_micro_batches = mapping.pp_size
+    else:
+        num_micro_batches = 2 if pytorch_backend_config.enable_overlap_scheduler else 1

     resources["seq_slot_manager"] = SeqSlotManager(
         executor_config.max_batch_size * num_micro_batches)
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 0fd5b70ee4..72d5b60d30 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -234,7 +234,7 @@ class PyExecutor:
         self.num_scheduled_requests: int = 0

         # list of requests in each PP micro batch
-        self.num_micro_batches = self.dist.pp_size + enable_overlap_scheduler
+        self.num_micro_batches = self.dist.pp_size
         self.micro_batches: List[BatchStatePP | None] = [None] * self.num_micro_batches
         self.send_handles = [None] * self.num_micro_batches
@@ -256,7 +256,7 @@ class PyExecutor:
         self.kv_cache_transceiver = kv_cache_transceiver

         if self.dist.pp_size > 1:
-            self.event_loop = self._executor_loop_pp_overlap if enable_overlap_scheduler else self._executor_loop_pp
+            self.event_loop = self._executor_loop_pp
         else:
             self.event_loop = self._executor_loop_overlap if enable_overlap_scheduler else self._executor_loop
@@ -649,110 +649,6 @@ class PyExecutor:
         self.shutdown_event.set()

     def _executor_loop_pp(self):
-        torch.cuda.set_device(self.device_id)
-        got_finish_signal = False
-        num_dummy_request = 0
-        microbatch_id = 0
-        with self._profiler() as profile_step:
-            iter_start_time = time.time()
-            iter_stats = None
-            while not got_finish_signal or len(self.active_requests) > 0:
-                profile_step()
-                if self.enable_iter_perf_stats:
-                    iter_start_time = time.time()
-                new_requests = self._fetch_new_requests()
-                got_finish_signal = self._merge_requests(
-                    new_requests) or got_finish_signal
-                if got_finish_signal and len(self.active_requests) == 0:
-                    break
-
-                if self.enable_iter_perf_stats:
-                    iter_stats = self._get_init_iter_stats(
-                        len(new_requests),
-                        self.new_active_requests_queue_latency_ms)
-
-                if not got_finish_signal:
-                    num_dummy_request = self._get_num_dummy_request()
-                    if num_dummy_request > 0:
-                        self._merge_dummy_request(num_dummy_request)
-                scheduled_batch, _, _ = self._schedule()
-
-                self.num_scheduled_requests = scheduled_batch.batch_size
-                logger.debug(
-                    f'has {len(self.active_requests)} active_request, '
-                    f'scheduled {len(scheduled_batch.context_requests)} context requests and '
-                    f'{len(scheduled_batch.generation_requests)} generation requests'
-                )
-
-                if self.enable_attention_dp:
-                    tp_batch_sizes = self.dist.tp_allgather(
-                        scheduled_batch.batch_size)
-                    can_queue = 0 not in tp_batch_sizes
-                else:
-                    can_queue = scheduled_batch.batch_size > 0
-                if not can_queue:
-                    assert len(self.inflight_req_ids) > 0, (
-                        "fail to schedule any pending request, probably run out of resource"
-                    )
-
-                if not can_queue:
-                    self.micro_batches[microbatch_id] = None
-                else:
-                    # TODO: add pause_requests together with inflight_req_ids and handle draft_tokens
-                    self._add_inflight_ids(scheduled_batch)
-                    self.resource_manager.prepare_resources(scheduled_batch)
-
-                    # Stage 1: Forward + (decoding) pass ([should be] async)
-                    if self.dist.is_last_pp_rank:
-                        decoder_state = self._forward_step_last_pp(
-                            scheduled_batch, microbatch_id)
-                    else:
-                        decoder_state = self._forward_step_inter_pp(
-                            scheduled_batch)
-
-                    if self.enable_iter_perf_stats:
-                        iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
-                            'num_ctx_tokens']
-
-                    batch_state = BatchStatePP(
-                        decoder_state=decoder_state,
-                        iter_start_time=iter_start_time,
-                        iter_stats=iter_stats,
-                        microbatch_id=microbatch_id,
-                    )
-
-                    if num_dummy_request > 0:
-                        self._finish_dummy_request(scheduled_batch)
-                    self.micro_batches[microbatch_id] = batch_state
-
-                # Stage 2: Handle previous batch that only processed forward_step
-                # marching forward in the microbatch slots
-                prev_microbatch_id = (microbatch_id +
-                                      1) % self.num_micro_batches
-                previous_batch = self.micro_batches[prev_microbatch_id]
-                finished_requests = []
-                if previous_batch is not None:
-                    if not self.dist.is_last_pp_rank:
-                        self._handle_previous_batch_inter_pp(previous_batch)
-
-                    self._update_requests(previous_batch.decoder_state)
-                    self._handle_cancelled_requests()
-                    finished_requests = self._handle_responses()
-                    previous_scheduled_batch = previous_batch.decoder_state.scheduled_requests
-                    self.resource_manager.update_resources(
-                        previous_scheduled_batch)
-                    self._remove_inflight_ids(previous_scheduled_batch)
-
-                microbatch_id = prev_microbatch_id
-                self._gather_dp_requests_num()
-
-                if self.enable_iter_perf_stats and previous_batch is not None:
-                    self._process_iter_stats(finished_requests,
-                                             self.active_requests,
-                                             previous_batch)
-        self._executor_loop_cleanup()
-
-    def _executor_loop_pp_overlap(self):
         torch.cuda.set_device(self.device_id)
         got_finish_signal = False
         num_dummy_request = 0
         microbatch_id = 0
@@ -1189,46 +1085,6 @@ class PyExecutor:
             scheduled_requests=scheduled_batch,
             new_tensors_host={"new_tokens_host": new_tokens_host})

-    @nvtx_range("_forward_step_last_pp")
-    def _forward_step_last_pp(self, scheduled_batch,
-                              microbatch_id) -> DecoderState:
-        batch_outputs = self._forward_step(scheduled_batch)
-        decoder_state = self._decode_async(scheduled_batch, batch_outputs)
-        self._update_request_states(scheduled_batch)
-
-        if self.send_handles[microbatch_id] is not None:
-            self.send_handles[microbatch_id].Wait()
-        decoder_state.decoder_event.synchronize()
-
-        self.send_handles[microbatch_id] = self.dist.isend_tensor_list(
-            decoder_state.new_tensors_host.values(),
-            dest=self.dist.next_pp_rank,
-            tag=microbatch_id)
-
-        return decoder_state
-
-    @nvtx_range("_handle_previous_batch_inter_pp")
-    def _handle_previous_batch_inter_pp(
-            self, previous_batch_state: BatchStatePP) -> None:
-        new_tokens_host = previous_batch_state.decoder_state.new_tensors_host
-        prev_microbatch_id = previous_batch_state.microbatch_id
-        # Receive tokens from prev pp rank w.r.t model forward direction
-        self.dist.recv_tensor_list(
-            new_tokens_host.values(),
-            src=self.dist.prev_pp_rank,
-            tag=prev_microbatch_id  # not necessary and may discard
-        )
-
-        # Send tokens to next pp rank w.r.t model forward direction
-        # Second last rank not need since last rank has original decoded tokens
-        if not self.dist.is_second_last_pp_rank:
-            if self.send_handles[prev_microbatch_id] is not None:
-                self.send_handles[prev_microbatch_id].Wait()
-            self.send_handles[prev_microbatch_id] = self.dist.isend_tensor_list(
-                new_tokens_host.values(),
-                dest=self.dist.next_pp_rank,
-                tag=prev_microbatch_id)
-
     def _update_new_active_requests_queue_latency(self, new_requests):
         if self.enable_iter_perf_stats and self.dist.rank == 0:
             now = time.time()
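For reference, below is a minimal sketch (not part of the diff) of the micro-batch sizing rule that create_py_executor_instance follows after this change. The helper name compute_num_micro_batches is hypothetical, and it assumes mapping.has_pp() is equivalent to pp_size > 1.

def compute_num_micro_batches(pp_size: int, enable_overlap_scheduler: bool) -> int:
    # Illustrative stand-in for the logic added in _util.py above; names are assumptions.
    has_pp = pp_size > 1  # assumed equivalent of mapping.has_pp()
    if has_pp:
        # With pipeline parallelism, keep one in-flight micro batch per PP stage.
        return pp_size
    # Without PP, the overlap scheduler keeps two batches in flight, otherwise one.
    return 2 if enable_overlap_scheduler else 1


# Example: 4 PP stages -> 4 micro batches; no PP with overlap scheduler -> 2.
assert compute_num_micro_batches(4, enable_overlap_scheduler=True) == 4
assert compute_num_micro_batches(1, enable_overlap_scheduler=True) == 2
assert compute_num_micro_batches(1, enable_overlap_scheduler=False) == 1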