From 5f4df89109a92c2dc8be2225c68b6cdb940d6869 Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Tue, 10 Feb 2026 15:43:28 +0800 Subject: [PATCH] [None][feat] Fully non-blocking pipeline parallelism executor loop. (#10349) Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/llm_request.py | 77 ++- tensorrt_llm/_torch/pyexecutor/py_executor.py | 447 ++++++++++++------ .../_torch/pyexecutor/request_utils.py | 3 +- tensorrt_llm/_torch/utils.py | 26 +- tensorrt_llm/_utils.py | 11 + tensorrt_llm/llmapi/disagg_utils.py | 3 +- tensorrt_llm/llmapi/mpi_session.py | 8 +- tests/unittest/llmapi/test_llm.py | 19 +- tests/unittest/utils/util.py | 18 + 9 files changed, 450 insertions(+), 162 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 96522f3f9b..490db48b08 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -1,5 +1,5 @@ from copy import copy, deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch @@ -232,6 +232,29 @@ class LogProbStorage: class PyResult: """PyResult reimplements some features of `bindings.executor.Result` in Python""" + @dataclass + class Diff: + """ + Diff is used to track the changes of the PyResult. + It is designed to incrementally sync the PyResult to other ranks + by `get_diff` on one rank and `apply_diff` on other ranks. + """ + exclude_last_generation_logits: bool | None = None + context_logits_list: list[torch.Tensor] = field(default_factory=list) + generation_logits_list: list[torch.Tensor] = field(default_factory=list) + reset_log_probs: tuple[list[TokenLogprobs], + list[float] | None] | None = None + log_probs_list: list[tuple[list[TokenLogprobs], list[float] + | None]] = field(default_factory=list) + mm_embeddings: dict[str, Any] | None = None + mrope_position_ids: dict[str, Any] | None = None + mrope_position_deltas: dict[str, Any] | None = None + additional_context_outputs_list: list[tuple[str, torch.Tensor]] = field( + default_factory=list) + additional_generation_outputs_list: list[tuple[str, + torch.Tensor]] = field( + default_factory=list) + def __init__(self, *, prompt_len: int, @@ -277,28 +300,72 @@ class PyResult: name: [] for name in additional_outputs } if additional_outputs else None + self.diff = PyResult.Diff() + + def reset_diff(self): + self.diff = PyResult.Diff() + + def get_diff(self) -> Diff: + for i, context_logits in enumerate(self.diff.context_logits_list): + self.diff.context_logits_list[i] = context_logits.to("cpu") + for i, generation_logits in enumerate(self.diff.generation_logits_list): + self.diff.generation_logits_list[i] = generation_logits.to("cpu") + return self.diff + + def apply_diff(self, diff: Diff): + if diff.exclude_last_generation_logits is not None: + self._exclude_last_generation_logits = diff.exclude_last_generation_logits + if len(diff.context_logits_list) > 0: + for context_logits in diff.context_logits_list: + self._context_logits.append(context_logits) + if len(diff.generation_logits_list) > 0: + for generation_logits in diff.generation_logits_list: + self._generation_logits.append(generation_logits) + if diff.reset_log_probs is not None: + self._log_probs.set_log_probs(*diff.reset_log_probs) + if len(diff.log_probs_list) > 0: + for log_probs, cum_log_probs in diff.log_probs_list: + 
self._log_probs.append(log_probs, cum_log_probs) + if diff.mm_embeddings is not None: + self._mm_embeddings = diff.mm_embeddings + if diff.mrope_position_ids is not None: + self._mrope_position_ids = diff.mrope_position_ids + self._mrope_position_deltas = diff.mrope_position_deltas + if len(diff.additional_context_outputs_list) > 0: + for name, additional_context_outputs in diff.additional_context_outputs_list: + self._additional_context_outputs[name].append( + additional_context_outputs) + if len(diff.additional_generation_outputs_list) > 0: + for name, additional_generation_outputs in diff.additional_generation_outputs_list: + self._additional_generation_outputs[name].append( + additional_generation_outputs) def set_exclude_last_generation_logits( self, exclude_last_generation_logits: bool): self._exclude_last_generation_logits = exclude_last_generation_logits + self.diff.exclude_last_generation_logits = exclude_last_generation_logits def append_context_logits(self, context_logits: torch.Tensor): if self._context_logits: self._context_logits.append(context_logits) + self.diff.context_logits_list.append(context_logits) def append_generation_logits(self, generation_logits: torch.Tensor): if self._generation_logits: self._generation_logits.append(generation_logits) + self.diff.generation_logits_list.append(generation_logits) def append_log_probs(self, log_probs: list[TokenLogprobs], cum_log_probs: Optional[list[float]] = None): if self._log_probs: self._log_probs.append(log_probs, cum_log_probs) + self.diff.log_probs_list.append((log_probs, cum_log_probs)) def append_mm_embeddings(self, mm_embeddings: torch.Tensor): self._mm_embeddings = SharedTensorContainer.from_tensor( mm_embeddings).dump_to_dict() + self.diff.mm_embeddings = self._mm_embeddings def set_mrope_position( self, @@ -309,6 +376,8 @@ class PyResult: mrope_position_ids).dump_to_dict()) self._mrope_position_deltas = (SharedTensorContainer.from_tensor( mrope_position_deltas).dump_to_dict()) + self.diff.mrope_position_ids = self._mrope_position_ids + self.diff.mrope_position_deltas = self._mrope_position_deltas def transfer_remaining_device_logits(self): """Finalize any remaining generation logits transfers (for chunked mode)""" @@ -319,11 +388,15 @@ class PyResult: self, name: str, additional_context_outputs: torch.Tensor): self._additional_context_outputs[name].append( additional_context_outputs.to("cpu", non_blocking=True)) + self.diff.additional_context_outputs_list.append( + (name, self._additional_context_outputs[name][-1])) def append_additional_generation_outputs( self, name: str, additional_generation_outputs: torch.Tensor): self._additional_generation_outputs[name].append( additional_generation_outputs.to("cpu", non_blocking=True)) + self.diff.additional_generation_outputs_list.append( + (name, self._additional_generation_outputs[name][-1])) def set_log_probs(self, log_probs: list[TokenLogprobs], cum_log_probs: list[float]): @@ -334,6 +407,8 @@ class PyResult: """ if self._log_probs: self._log_probs.set_log_probs(log_probs, cum_log_probs) + self.diff.reset_log_probs = (log_probs, cum_log_probs) + self.diff.log_probs_list.clear() @property def context_logits(self) -> torch.Tensor | None: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 1f03b86565..00bfeb3304 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -2,17 +2,17 @@ import dataclasses import datetime import functools import os -import 
pickle # nosec B403 import threading import time import traceback from collections import deque from contextlib import contextmanager +from enum import IntEnum +from queue import Queue from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import torch -from tensorrt_llm._torch.expert_statistic import ExpertStatistic from tensorrt_llm.llmapi import DisaggScheduleStyle from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds @@ -21,10 +21,9 @@ try: except ImportError: from cuda import cudart -from tensorrt_llm._torch.pyexecutor.resource_manager import ( - ResourceManagerType, request_context) from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled, - mpi_disabled, nvtx_range, trace_func) + mpi_comm, mpi_disabled, nvtx_range, + set_thread_local_mpi_comm, trace_func) from tensorrt_llm.bindings.executor import (DisServingRequestStats, FinishReason, InflightBatchingStats, IterationStats, KvCacheStats, @@ -40,6 +39,7 @@ from tensorrt_llm.runtime.generation import CUASSERT from tensorrt_llm.tools.layer_wise_benchmarks import get_calibrator from ..distributed import Distributed +from ..expert_statistic import ExpertStatistic from ..models.modeling_utils import DecoderModelForCausalLM from ..modules.decoder_layer import DecoderLayer from ..speculative.drafter import Drafter @@ -58,7 +58,8 @@ from .model_engine import ModelEngine from .request_utils import (RequestBroadcaster, attach_py_objects_to_requests, get_from_waiting_queue, merge_requests, schedule_attention_dp_requests) -from .resource_manager import ResourceManager +from .resource_manager import (ResourceManager, ResourceManagerType, + request_context) from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState, SampleStateTensors, TRTLLMSampler) from .scheduler import (RequestScheduler, ScheduledRequests, @@ -72,10 +73,15 @@ PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP" # Set to a path to save detailed tracing of PyTorch operations. PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE" -# Unique tag base to avoid collisions with token/logits comms -TERMINATION_COMM_TAG_BASE = 20000 -PP_COMM_TAG_SCHEDULE_RESULT = 21000 -PP_COMM_TAG_SAMPLE_STATE_BASE = 21001 + +class PPCommTag(IntEnum): + """ + Unique tags for pipeline parallelism communication. + """ + TERMINATION = 20000 + SCHEDULE_RESULT = 20001 + EXECUTED_BATCH_NUM = 20002 + SAMPLE_STATE = 20003 @functools.cache @@ -240,6 +246,12 @@ class AsyncTransferManager: class PyExecutor: + # Minimum number of async micro batches for async PP execution. + # This is a trade-off between memory usage and performance. + # If the number of micro batches is too small, the executor will spend too much time in synchronization. + # If the number of micro batches is too large, the executor will spend too much host memory (No additional GPU memory is required). + # 1024 in-flight micro batches can avoid synchronization in most cases and keep host memory usage low. 
+ MIN_ASYNC_MICRO_BATCH_NUM = 1024 def __init__(self, resource_manager, @@ -371,18 +383,20 @@ class PyExecutor: os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0)) # list of requests in each PP micro batch - self.num_micro_batches = self.dist.pp_size + self.num_micro_batches = max(self.dist.pp_size, + self.MIN_ASYNC_MICRO_BATCH_NUM) self.micro_batches: List[BatchStatePP | None] = [None] * self.num_micro_batches self.send_handles = [None] * self.num_micro_batches # schedule handle for PP to propagate the first PP rank's schedule result - self.send_schedule_handler = None + self.send_schedule_handles = [None] * self.num_micro_batches + self.send_expected_batch_num_handles = [None] * self.num_micro_batches + self.unhandled_batch_counter = 0 self.pp_scheduler_max_retry_count = int( os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10)) self.pp_multi_stream_sample = os.environ.get( "TRTLLM_PP_MULTI_STREAM_SAMPLE", "1") == "1" self.sample_stream = torch.cuda.Stream() - self.start_sample_event = torch.cuda.Event() self.finish_sample_event = torch.cuda.Event() if (self.dist.pp_size > 1 and self.pp_multi_stream_sample and isinstance(self.sampler, TRTLLMSampler)): @@ -479,6 +493,13 @@ class PyExecutor: if self.dist.pp_size > 1: self.event_loop = self._executor_loop_pp + # `TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE` controls whether to broadcast the sample state asynchronously. + # If true, the executor loop can broadcast and handle sample states asynchronously to achieve best perf. + # If false, the executor loop can only broadcast and handle each sample state in a pre-defined iteration. + # It is only for debugging purposes. + # Some tests can disable it to get a deterministic behavior. + self.pp_async_broadcast_sample_state = os.environ.get( + "TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE", "1") == "1" else: self.event_loop = self._executor_loop if self.disable_overlap_scheduler else self._executor_loop_overlap if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): @@ -560,7 +581,22 @@ class PyExecutor: def start_worker(self): with self.worker_lock: - if self.worker_started == False: + if not self.worker_started: + if self.dist.pp_size > 1: + self.executed_batch_queue: Queue[BatchStatePP] = Queue( + maxsize=self.num_micro_batches) + self.executed_batch_response_queue: Queue[ + BatchStatePP] = Queue(maxsize=-1) + broadcast_sample_state_loop = self._broadcast_sample_state_loop + if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): + broadcast_sample_state_loop = trace_func( + broadcast_sample_state_loop) + self.broadcast_sample_state_handler = threading.Thread( + target=broadcast_sample_state_loop, + daemon=True, + name="broadcast_sample_state_handler", + ) + self.broadcast_sample_state_handler.start() self.worker_thread = threading.Thread( target=self._event_loop_wrapper, daemon=True) self.worker_thread.start() @@ -658,6 +694,9 @@ class PyExecutor: logger.error("Hang detected, shutting down immediately.") return self.worker_thread.join() + if self.dist.pp_size > 1: + self.executed_batch_queue.put(None) + self.broadcast_sample_state_handler.join() self.worker_started = False for manager in self.resource_manager.resource_managers.values(): if manager: @@ -1064,41 +1103,63 @@ class PyExecutor: def _executor_loop_cleanup(self): - for h in self.send_handles: - if h is not None: - h.wait() + for i in range(self.num_micro_batches): + self.wait_on_pp_send_handles(self.send_handles, i) + self.wait_on_pp_send_handles(self.send_schedule_handles, i) + self.wait_on_pp_send_handles(self.send_expected_batch_num_handles, + i) with 
self.response_cv: self.is_shutdown = True self.response_cv.notify_all() self.shutdown_event.set() - def _pp_schedule_and_propagate(self): + def _pp_schedule_and_propagate(self, microbatch_id: int): """The first PP rank schedules the requests and propagates the result to all other PP ranks.""" - # The first PP rank schedules the requests, other ranks receive the schedule result from the previous PP rank. - if self.dist.is_first_pp_rank: + # For TP/CP cases, the first rank schedules the requests. + # For DP cases, the first PP rank schedules the requests. + scheduled_batch = None + serializable_schedule = None + is_dp_broadcast = self.dist.tp_size > 1 and self.enable_attention_dp + if self.dist.rank == 0 or (self.dist.is_first_pp_rank + and is_dp_broadcast): scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( ) serializable_schedule = SerializableSchedulerOutput.from_scheduler_result( scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs) - else: + + # Broadcast within first tp+cp group before send/recv chain to other tp+cp groups + if self.dist.is_first_pp_rank: + if self.dist.tp_size > 1 and not self.enable_attention_dp: + with nvtx_range("tp_broadcast_schedule"): + serializable_schedule = self.dist.tp_broadcast( + serializable_schedule, root=0) + if self.dist.cp_size > 1: + with nvtx_range("cp_broadcast_schedule"): + serializable_schedule = self.dist.cp_broadcast( + serializable_schedule, root=0) + + # Other ranks receive the schedule result from the previous PP rank. + if not self.dist.is_first_pp_rank: with nvtx_range("recv_schedule_from_prev_pp"): serializable_schedule = self.dist.recv_object( - self.dist.prev_pp_rank, PP_COMM_TAG_SCHEDULE_RESULT) - scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result( - self.active_requests) + self.dist.prev_pp_rank, PPCommTag.SCHEDULE_RESULT) # Propagate the schedule result to the next PP rank except the last PP rank. if not self.dist.is_last_pp_rank: - if self.send_schedule_handler is not None: - with nvtx_range("wait_send_schedule_handler"): - self.send_schedule_handler.wait() + self.wait_on_pp_send_handles(self.send_schedule_handles, + microbatch_id) with nvtx_range("send_schedule_to_next_pp"): - self.send_schedule_handler = self.dist.isend_object( - serializable_schedule, self.dist.next_pp_rank, - PP_COMM_TAG_SCHEDULE_RESULT) + self.send_schedule_handles[ + microbatch_id] = self.dist.isend_object( + serializable_schedule, self.dist.next_pp_rank, + PPCommTag.SCHEDULE_RESULT) + + if scheduled_batch is None: + scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result( + self.active_requests) return scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs def _pp_retry_until_can_schedule(self, scheduled_batch): @@ -1177,8 +1238,8 @@ class PyExecutor: # Stage 0: first PP rank schedules requests and propagates the result to all other PP ranks. scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._pp_schedule_and_propagate( - ) - if not self.dist.is_first_pp_rank: + microbatch_id) + if self.dist.rank != 0: # Retry until current rank can run first PP's schedule result. self._pp_retry_until_can_schedule(scheduled_batch) # Run scheduler locally because scheduler may change llm requests' state. 
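The schedule propagation implemented above follows a simple relay: rank 0 (or the first PP rank in the attention-DP case) produces the schedule, the first PP stage shares it over TP/CP broadcast, and every later PP rank receives it from its predecessor and forwards it to its successor. The standalone mpi4py sketch below illustrates only that relay pattern; the tag value mirrors PPCommTag.SCHEDULE_RESULT, and the payload, function name, and script are stand-ins rather than the real Distributed / SerializableSchedulerOutput code.

# Minimal, hypothetical sketch of the schedule-relay pattern; not part of the patch.
from mpi4py import MPI

SCHEDULE_TAG = 20001  # mirrors PPCommTag.SCHEDULE_RESULT

def relay_schedule(comm: MPI.Comm):
    rank, size = comm.Get_rank(), comm.Get_size()
    if rank == 0:
        # Stand-in for the SerializableSchedulerOutput built by the scheduler on rank 0.
        schedule = {"iteration": 0, "request_ids": [1, 2, 3]}
    else:
        # Every other rank receives the schedule from its predecessor in the chain.
        schedule = comm.recv(source=rank - 1, tag=SCHEDULE_TAG)
    if rank != size - 1:
        # Forward to the successor; a non-blocking send lets the forward overlap local work.
        handle = comm.isend(schedule, dest=rank + 1, tag=SCHEDULE_TAG)
        handle.wait()
    return schedule

if __name__ == "__main__":
    print(relay_schedule(MPI.COMM_WORLD))

Run under mpirun with one process per PP rank, every rank ends up with the same schedule while only rank 0 ever computed it.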
@@ -1236,7 +1297,7 @@ class PyExecutor: # Return the first token to the client self._handle_first_token_response(scheduled_batch) - # Stage 1: Async forward (all ranks) and decoding pass (last rank only) + # Stage 1.1: Async forward (all ranks) and decoding pass (last rank only) if not self.dist.is_last_pp_rank: with torch.cuda.nvtx.range( f"_forward_step_inter_pp pp_rank {self.dist.pp_rank}" @@ -1269,9 +1330,9 @@ class PyExecutor: name: tensor.clone() for name, tensor in batch_outputs.items() } - self.start_sample_event.record() + self.sample_stream.wait_stream( + torch.cuda.current_stream()) with torch.cuda.stream(self.sample_stream): - self.start_sample_event.wait() sample_state = self._sample_async( scheduled_batch, batch_outputs_copy) self.finish_sample_event.record() @@ -1302,103 +1363,222 @@ class PyExecutor: self.micro_batches[microbatch_id] = batch_state - # sync sampler for previous microbatch to start new sample state comm chain. - prev_microbatch_id = (microbatch_id - - 1) % self.num_micro_batches - previous_batch = self.micro_batches[prev_microbatch_id] - if previous_batch is not None: + # Stage 1.2: Sync sampler for previous microbatch to start new sample state comm chain. + # For last PP rank, we must synchronize the previous batch + # since we need to broadcast its sample state soon afterwards in the same iteration. + # For other PP ranks, we can delay the synchronization if the current batch cannot be queued. + previous_batch = self.previous_batch + if can_queue: + self.previous_batch = batch_state + if (self.dist.is_last_pp_rank + or can_queue) and previous_batch is not None: with nvtx_range("sync_previous_sampler_event"): previous_batch.sample_state.sampler_event.synchronize() - # Stage 2: Communicate sample state for previous batch between ranks + # Stage 2: Enqueue sample state for executed batch to ring broadcast it in background thread asynchronously. # send/recv chain: (pp_size - 1) -> 0 -> 1 -> ... -> (pp_size - 2) # intermediate ranks: send/recv sample state for next microbatch to allow overlap - offset = -1 if self.dist.is_last_pp_rank else 1 - prev_microbatch_id = (microbatch_id + - offset) % self.num_micro_batches - previous_batch = self.micro_batches[prev_microbatch_id] - tag = PP_COMM_TAG_SAMPLE_STATE_BASE + prev_microbatch_id - if previous_batch is not None: - sample_state = previous_batch.sample_state - if not self.dist.is_last_pp_rank: - # Receive tokens from previous pp rank (w.r.t model forward direction) - with nvtx_range("recv_sample_state"): - sample_state.host = self.dist.recv_object( - src=self.dist.prev_pp_rank, - tag=tag, + offset = -1 if self.dist.is_last_pp_rank else ( + 1 - self.dist.pp_size) + executed_microbatch_id = (microbatch_id + + offset) % self.num_micro_batches + executed_batch = self.micro_batches[executed_microbatch_id] + if executed_batch is not None: + self.executed_batch_queue.put(executed_batch) + self.unhandled_batch_counter += 1 + self.micro_batches[executed_microbatch_id] = None + + def fetch_executed_batches() -> list[BatchStatePP]: + executed_batches = [] + if self.pp_async_broadcast_sample_state: + # Wait for at least one batch to finish if no new request is available. 
+ must_get = not can_queue + else: + must_get = True + while not self.executed_batch_response_queue.empty() or ( + must_get and self.unhandled_batch_counter > 0): + with nvtx_range("get_executed_batch"): + executed_batches.append( + self.executed_batch_response_queue.get()) + must_get = False + return executed_batches + + def ring_broadcast_executed_batch_num( + executed_batch_num: int) -> int: + if self.dist.is_first_pp_rank and self.dist.tp_size * self.dist.cp_size > 1: + with nvtx_range("tp_cp_broadcast_executed_batch_num"): + executed_batch_num = self.dist.tp_cp_broadcast( + executed_batch_num, + root=0, ) - - # Send tokens to next pp rank (w.r.t model forward direction) - # Second last rank does not need to since last rank has original decoded tokens - if not self.dist.is_second_last_pp_rank: - self.wait_on_pp_send_handles(prev_microbatch_id) - with nvtx_range("send_sample_state"): - self.send_handles[ - prev_microbatch_id] = self.dist.isend_object( - sample_state.host, + if not self.dist.is_first_pp_rank: + with nvtx_range("recv_expected_batch_num"): + executed_batch_num = self.dist.recv_object( + src=self.dist.prev_pp_rank, + tag=PPCommTag.EXECUTED_BATCH_NUM, + ) + if not self.dist.is_last_pp_rank: + self.wait_on_pp_send_handles( + self.send_expected_batch_num_handles, microbatch_id) + with nvtx_range("send_expected_batch_num"): + self.send_expected_batch_num_handles[ + microbatch_id] = self.dist.isend_object( + executed_batch_num, dest=self.dist.next_pp_rank, - tag=tag) + tag=PPCommTag.EXECUTED_BATCH_NUM, + ) + return executed_batch_num - # Stage 3: Finalize previous batch that finished sample state communication - # In last pp rank, stage 2 and 3 process different previous batches - prev_microbatch_id = (microbatch_id + - 1) % self.num_micro_batches - previous_batch = self.micro_batches[prev_microbatch_id] - finished_requests = [] - if previous_batch is not None: - with torch.cuda.nvtx.range("_handle_previous_batch_pp"): - sample_state = previous_batch.sample_state - sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs - self._update_requests(previous_batch.sample_state) + def handle_executed_batches(executed_batch_num: int): + if self.dist.rank != 0: + dequeue_counter = 0 + while dequeue_counter < executed_batch_num: + with nvtx_range("get_executed_batch"): + executed_batch = self.executed_batch_response_queue.get( + ) + self._handle_executed_batch(executed_batch) + dequeue_counter += 1 + else: + for executed_batch in executed_batches: + self._handle_executed_batch(executed_batch) + self.unhandled_batch_counter -= executed_batch_num - if self.kv_cache_transceiver: - self._send_kv_async( - previous_batch.finished_ctx_reqs) - self._handle_canceled_requests() + executed_batch_num = 0 - self._handle_logits_communication( - previous_batch, prev_microbatch_id) + # Stage 3.1: The first rank determines the number of executed batches. + if self.dist.rank == 0: + executed_batches = fetch_executed_batches() + executed_batch_num = len(executed_batches) - finished_requests = self._handle_responses() - previous_scheduled_batch = previous_batch.sample_state.scheduled_requests - attn_metadata = getattr(self.model_engine, - 'attn_metadata', None) - kv_cache_dtype_byte_size = getattr( - self.model_engine, 'kv_cache_dtype_byte_size', None) - self.resource_manager.update_resources( - previous_scheduled_batch, attn_metadata, - kv_cache_dtype_byte_size) + # Stage 3.2: Broadcast the number of executed batches to other ranks. 
+ executed_batch_num = ring_broadcast_executed_batch_num( + executed_batch_num) - self._remove_inflight_ids(previous_batch) + # Stage 3.3: Handle executed batches. + handle_executed_batches(executed_batch_num) - self.wait_on_pp_send_handles(prev_microbatch_id) - self.micro_batches[prev_microbatch_id] = None - - if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests( - ): - self._check_kv_transfer_timeout() - - if self._disagg_pp_termination_handler is not None: - self._disagg_pp_termination_handler.terminate_pending_requests( - ) - - # march forward in microbatch slots + # Stage 4: March forward in microbatch slots microbatch_id = (microbatch_id + 1) % self.num_micro_batches - - if self.enable_iter_perf_stats and previous_batch is not None: - sample_state = previous_batch.sample_state - sample_state.scheduled_requests.context_requests = previous_batch.scheduled_ctx_reqs - self._process_iter_stats(finished_requests, - self.active_requests, - previous_batch, microbatch_id) - self.iter_counter += 1 + # Stage 5: Handle remaining executed batches in the queue. + while self.unhandled_batch_counter > 0: + with nvtx_range("get_executed_batch"): + executed_batch = self.executed_batch_response_queue.get() + self._handle_executed_batch(executed_batch) + self.unhandled_batch_counter -= 1 + + def _broadcast_sample_state_loop(self): + logger.debug( + f"Starting broadcast sample state loop for pp_rank {self.dist.pp_rank}" + ) + torch.cuda.set_device(self.device_id) + # ensure the context is created, otherwise, some MPI calls will fail. + CUASSERT(cudart.cudaSetDevice(self.device_id)) + # Acquiring pkl5.Intracomm's send/recv locks from both executor loop thread + # and this thread will cause perf drop and even deadlock. + # We create new MPI comm to avoid these issues. + logger.info( + "Create new MPI comm for broadcast sample state thread to avoid deadlock." 
+ ) + new_mpi_comm = mpi_comm().Dup() + set_thread_local_mpi_comm(new_mpi_comm) + while True: + executed_batch = self.executed_batch_queue.get() + if executed_batch is None: + break + self._ring_broadcast_sample_state(executed_batch) + set_thread_local_mpi_comm(None) + new_mpi_comm.Free() + + def _ring_broadcast_sample_state( + self, + executed_batch: Optional[BatchStatePP], + ) -> None: + if executed_batch is None: + return + + tag = PPCommTag.SAMPLE_STATE + microbatch_id = executed_batch.microbatch_id + sample_state = executed_batch.sample_state + requests = sample_state.scheduled_requests.all_requests() + + if not self.dist.is_last_pp_rank: + # Receive tokens from previous pp rank (w.r.t model forward direction) + with nvtx_range("recv_sample_state"): + sample_state.host, py_result_diffs = self.dist.recv_object( + src=self.dist.prev_pp_rank, + tag=tag, + ) + + for request, py_result_diff in zip(requests, py_result_diffs): + request.py_result.apply_diff(py_result_diff) + + self.executed_batch_response_queue.put(executed_batch) + + # Send tokens to next pp rank (w.r.t model forward direction) + # Second last rank does not need to since last rank has original decoded tokens + if not self.dist.is_second_last_pp_rank: + py_result_diffs = [] + for request in requests: + diff = request.py_result.get_diff() + py_result_diffs.append(diff) + request.py_result.reset_diff() + self.wait_on_pp_send_handles(self.send_handles, microbatch_id) + with nvtx_range("send_sample_state"): + self.send_handles[microbatch_id] = self.dist.isend_object( + (sample_state.host, py_result_diffs), + dest=self.dist.next_pp_rank, + tag=tag, + ) + + def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]): + finished_requests = [] + if executed_batch is not None: + with torch.cuda.nvtx.range("_handle_executed_batch_pp"): + sample_state = executed_batch.sample_state + sample_state.scheduled_requests.context_requests = executed_batch.finished_ctx_reqs + self._update_requests(executed_batch.sample_state) + + if self.kv_cache_transceiver: + self._send_kv_async(executed_batch.finished_ctx_reqs) + self._handle_canceled_requests() + + finished_requests = self._handle_responses() + previous_scheduled_batch = executed_batch.sample_state.scheduled_requests + attn_metadata = getattr(self.model_engine, 'attn_metadata', + None) + kv_cache_dtype_byte_size = getattr(self.model_engine, + 'kv_cache_dtype_byte_size', + None) + self.resource_manager.update_resources( + previous_scheduled_batch, attn_metadata, + kv_cache_dtype_byte_size) + + self._remove_inflight_ids(executed_batch) + + if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests( + ): + self._check_kv_transfer_timeout() + + if self._disagg_pp_termination_handler is not None: + self._disagg_pp_termination_handler.terminate_pending_requests() + + if self.enable_iter_perf_stats and executed_batch is not None: + sample_state = executed_batch.sample_state + sample_state.scheduled_requests.context_requests = executed_batch.scheduled_ctx_reqs + self._process_iter_stats( + finished_requests, + self.active_requests, + executed_batch, + executed_batch.microbatch_id % self.dist.pp_size, + ) + @nvtx_range("wait_on_pp_send_handles") - def wait_on_pp_send_handles(self, microbatch_id): - if self.send_handles[microbatch_id] is not None: - self.send_handles[microbatch_id].wait() - self.send_handles[microbatch_id] = None + def wait_on_pp_send_handles(self, send_handles, microbatch_id): + if send_handles[microbatch_id] is not None: + 
send_handles[microbatch_id].wait() + send_handles[microbatch_id] = None def _can_queue(self, scheduled_batch): @@ -3022,41 +3202,6 @@ class PyExecutor: self._terminate_request(request) return requests_to_terminate - def _handle_logits_communication(self, previous_batch, prev_microbatch_id): - """Handle logits communication between pipeline parallel ranks. - - If logits were requested, the last PP rank sends to the first PP rank (who sends responses) - the logits of the requests that have finished. - - Args: - previous_batch: The previous batch state - prev_microbatch_id: The microbatch ID for the previous batch - """ - # NOTE: If the rank processing the logits ever becomes the same as - # the rank sending the responses, this code can be removed. - finished_reqs = [ - r for r in - previous_batch.sample_state.scheduled_requests.all_requests() - if r.state == LlmRequestState.GENERATION_COMPLETE and ( - r.py_return_context_logits or r.py_return_generation_logits - or r.py_additional_outputs is not None) - ] - if self.dist.is_first_pp_rank and len(finished_reqs): - finished_reqs_py_results = [r.py_result for r in finished_reqs] - finished_reqs_py_results = self.dist.recv_object( - src=self.dist.prev_pp_rank, - tag=prev_microbatch_id, - ) - for req, py_result in zip(finished_reqs, finished_reqs_py_results): - req.py_result = py_result - - elif self.dist.is_last_pp_rank and len(finished_reqs): - self.wait_on_pp_send_handles(prev_microbatch_id) - self.send_handles[prev_microbatch_id] = self.dist.isend_object( - [r.py_result for r in finished_reqs], - dest=self.dist.next_pp_rank, - tag=prev_microbatch_id) - def _await_any_response(self, timeout: Optional[float] = None ) -> List[LlmResponse]: @@ -3198,7 +3343,7 @@ class DisaggPPTerminationHandler: self._pending_termination = {} self._terminating_iteration = 0 self._send_handle = None - self._comm_tag = TERMINATION_COMM_TAG_BASE + self._comm_tag = PPCommTag.TERMINATION def terminate(self, request: LlmRequest): self._pending_termination[request.py_request_id] = request diff --git a/tensorrt_llm/_torch/pyexecutor/request_utils.py b/tensorrt_llm/_torch/pyexecutor/request_utils.py index cae311a395..42018b982a 100644 --- a/tensorrt_llm/_torch/pyexecutor/request_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/request_utils.py @@ -673,7 +673,8 @@ class RequestBroadcaster: # Broadcast within first PP stage before send/recv chain to other PP stages. # This needs to cover both TP and CP ranks within the first PP stage. 
if self.dist.is_first_pp_rank: - payloads = self.dist.tp_cp_broadcast(payloads, root=0) + with nvtx_range("tp_broadcast_requests"): + payloads = self.dist.tp_cp_broadcast(payloads, root=0) # Tag for communication tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index 876e4077a0..e182eee32c 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -8,7 +8,8 @@ from typing import Dict, List import torch from torch.nn import functional as F -from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor +from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor, + torch_dtype_to_str) from tensorrt_llm.mapping import Mapping from tensorrt_llm.math_utils import ceil_div, pad_up from tensorrt_llm.quantization.utils import fp4_utils @@ -416,6 +417,29 @@ def relu2(x: torch.Tensor) -> torch.Tensor: return torch.square(F.relu(x)) +def tensor_to_str(x: torch.Tensor, num_elements: int = 10) -> str: + # Pass num_elements=-1 will print the whole tensor + if num_elements < 0: + num_elements = torch.numel(x) + if x.dtype in (torch.int32, torch.int64): + float_x = x.to(dtype=float) + else: + float_x = x + return ("Tensor(" + f"shape={tuple(x.shape)}, " + f"dtype={torch_dtype_to_str(x.dtype)}, " + f"device={x.device}, " + f"stats=(" + f"abs_mean={float_x.abs().mean().item():.3f}, " + f"mean={float_x.mean().item():.3f}, " + f"std={float_x.std().item():.3f}, " + f"max={x.max().item():.3f}, " + f"min={x.min().item():.3f}" + "), " + f"values={x.flatten()[:num_elements].tolist()}" + ")") + + @maybe_compile def maybe_compiled_copy_(dst, src): dst.copy_(src) diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 8180487fa3..fb35c78d08 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -23,6 +23,7 @@ import socket import struct import sys import tempfile +import threading import trace import traceback import weakref @@ -509,7 +510,17 @@ def set_mpi_comm(new_comm): comm = new_comm +thread_local_comm = threading.local() + + +def set_thread_local_mpi_comm(new_comm): + thread_local_comm.value = new_comm + + def mpi_comm(): + if hasattr(thread_local_comm, + "value") and thread_local_comm.value is not None: + return thread_local_comm.value return comm diff --git a/tensorrt_llm/llmapi/disagg_utils.py b/tensorrt_llm/llmapi/disagg_utils.py index ae6b7135bf..8512771e1f 100644 --- a/tensorrt_llm/llmapi/disagg_utils.py +++ b/tensorrt_llm/llmapi/disagg_utils.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple import yaml from mpi4py.MPI import COMM_WORLD, Comm +from mpi4py.util import pkl5 from .._utils import global_mpi_rank, global_mpi_size @@ -327,7 +328,7 @@ def split_world_comm( f"global_rank: {global_rank}, instance_idx: {instance_idx}, sub_rank: {sub_rank}, is_leader: {is_leader}" ) - return is_leader, instance_idx, sub_comm + return is_leader, instance_idx, pkl5.Intracomm(sub_comm) def parse_metadata_server_config_file( diff --git a/tensorrt_llm/llmapi/mpi_session.py b/tensorrt_llm/llmapi/mpi_session.py index d32e5a7b7a..441d046d44 100644 --- a/tensorrt_llm/llmapi/mpi_session.py +++ b/tensorrt_llm/llmapi/mpi_session.py @@ -170,8 +170,14 @@ class MpiPoolSession(MpiSession): def _start_mpi_pool(self): assert not self.mpi_pool, 'MPI session already started' + env = { + key: value + for key, value in os.environ.items() + if key.startswith("TRTLLM") or key.startswith("TLLM") + } self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers, - 
path=sys.path) + path=sys.path, + env=env) def __del__(self): self.shutdown_abort() diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index cc120ab1a1..36f0cf63f2 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import datetime import gc import json @@ -57,7 +58,8 @@ from llmapi.lora_test_utils import ( check_llama_7b_multi_lora_from_request_test_harness, check_llama_7b_multi_unique_lora_adapters_from_request) from utils.llm_data import llm_models_root -from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_pre_hopper, skip_single_gpu +from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_pre_hopper, skip_single_gpu, altered_env + # isort: on # The unittests are based on the tiny-llama, which is fast to build and run. @@ -2171,11 +2173,16 @@ def llm_get_stats_test_harness(tp_size: int = 1, llm_args_extra["fast_build"] = True LLM_CLASS = LLM - with LLM_CLASS(model=llama_model_path, - kv_cache_config=global_kvcache_config, - tensor_parallel_size=tp_size, - pipeline_parallel_size=pp_size, - **llm_args_extra) as llm: + # Since we need to check pp's internal states, we disable the async broadcast + # to get a deterministic behavior. + env_ctx = altered_env(TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE="0") \ + if pp_size > 1 else contextlib.nullcontext() + + with env_ctx, LLM_CLASS(model=llama_model_path, + kv_cache_config=global_kvcache_config, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + **llm_args_extra) as llm: max_tokens = 5 sampling_params = SamplingParams(max_tokens=max_tokens, diff --git a/tests/unittest/utils/util.py b/tests/unittest/utils/util.py index 828bf08ea5..2c720a7328 100644 --- a/tests/unittest/utils/util.py +++ b/tests/unittest/utils/util.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import faulthandler import math import os @@ -391,6 +392,23 @@ def run_session(session: Session, return outputs +@contextlib.contextmanager +def altered_env(**kwargs): + old = {} + for k, v in kwargs.items(): + if k in os.environ: + old[k] = os.environ[k] + os.environ[k] = v + try: + yield + finally: + for k in kwargs: + if k not in old: + os.environ.pop(k) + else: + os.environ[k] = old[k] + + def similarity_score(a, b): "similar compare a and b " return SequenceMatcher(None, a, b).ratio()
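The PyResult.Diff addition in llm_request.py follows a record-and-replay pattern: every mutating call also records itself in a diff object, only the diff crosses ranks, the receiver replays it with apply_diff, and the sender clears its record with reset_diff. The following self-contained sketch shows that pattern on a toy Accumulator class (hypothetical names, not the real PyResult):

# Hypothetical, simplified illustration of the diff-and-replay pattern used by PyResult.Diff.
from dataclasses import dataclass, field

class Accumulator:
    @dataclass
    class Diff:
        appended: list[int] = field(default_factory=list)
        reset_to: list[int] | None = None

    def __init__(self):
        self._values: list[int] = []
        self.diff = Accumulator.Diff()

    def append(self, value: int):
        self._values.append(value)
        self.diff.appended.append(value)   # record the mutation

    def set_values(self, values: list[int]):
        self._values = list(values)
        self.diff.reset_to = list(values)  # a reset supersedes earlier appends
        self.diff.appended.clear()

    def get_diff(self) -> "Accumulator.Diff":
        return self.diff

    def reset_diff(self):
        self.diff = Accumulator.Diff()

    def apply_diff(self, diff: "Accumulator.Diff"):
        if diff.reset_to is not None:
            self._values = list(diff.reset_to)
        self._values.extend(diff.appended)

# One rank mutates and extracts the diff; another rank replays it.
producer, consumer = Accumulator(), Accumulator()
producer.append(1)
producer.append(2)
consumer.apply_diff(producer.get_diff())
producer.reset_diff()
assert consumer._values == [1, 2]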
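The new _broadcast_sample_state_loop thread and the thread-local communicator in _utils.py combine two ideas: a sentinel-terminated producer/consumer queue between the executor loop and a background thread, and a Dup()'ed MPI communicator installed per thread so the background traffic never contends for the main communicator's send/recv locks. A hypothetical standalone sketch of that combination (illustrative names, not the real PyExecutor members):

# Sketch only: sentinel-terminated worker thread with its own duplicated MPI communicator.
import threading
from queue import Queue

from mpi4py import MPI

_thread_local = threading.local()

def current_comm() -> MPI.Comm:
    # Mirrors mpi_comm(): prefer the communicator installed for this thread, if any.
    comm = getattr(_thread_local, "comm", None)
    return comm if comm is not None else MPI.COMM_WORLD

def broadcast_loop(work_queue: Queue) -> None:
    # Duplicate the communicator so this thread never shares locks with the
    # main thread (the deadlock the patch comment warns about).
    _thread_local.comm = MPI.COMM_WORLD.Dup()
    try:
        while True:
            item = work_queue.get()
            if item is None:  # sentinel posted at shutdown
                break
            # In a real multi-rank run every rank's background thread joins this
            # collective; with a single rank it returns immediately.
            current_comm().bcast(item, root=0)
    finally:
        _thread_local.comm.Free()
        _thread_local.comm = None

work_queue: Queue = Queue(maxsize=16)
worker = threading.Thread(target=broadcast_loop, args=(work_queue,), daemon=True)
worker.start()
work_queue.put({"tokens": [1, 2, 3]})
work_queue.put(None)  # ask the worker to exit, as PyExecutor.shutdown does
worker.join()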
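tensor_to_str in tensorrt_llm/_torch/utils.py is a debugging helper that summarizes a tensor's shape, dtype, device, basic statistics, and its first few values. A quick usage example (the expected output is paraphrased in the comments, not captured from a run):

import torch
from tensorrt_llm._torch.utils import tensor_to_str

x = torch.arange(12, dtype=torch.int32).reshape(3, 4)
print(tensor_to_str(x))                   # shape, dtype, device, abs_mean/mean/std/max/min, first 10 values
print(tensor_to_str(x, num_elements=-1))  # num_elements=-1 prints every element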
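The altered_env helper added to tests/unittest/utils/util.py scopes environment overrides to a with-block and restores the previous state on exit, which is how llm_get_stats_test_harness pins TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE=0 for pp_size > 1. A minimal usage example (the body of the with-block is a placeholder):

import os

from utils.util import altered_env  # import path as used by the unit tests

with altered_env(TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE="0"):
    # Inside the block the override is visible to the executor being constructed.
    assert os.environ["TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE"] == "0"
    # ... build the LLM with pipeline_parallel_size > 1 and run the stats checks here ...
# On exit the variable is restored to its previous value, or removed if it was unset.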