[None][feat] Fully non-blocking pipeline parallelism executor loop. (#10349)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Yuxian Qiu 2026-02-10 15:43:28 +08:00 committed by GitHub
parent c233692485
commit 5f4df89109
9 changed files with 450 additions and 162 deletions

View File

@ -1,5 +1,5 @@
from copy import copy, deepcopy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import torch
@ -232,6 +232,29 @@ class LogProbStorage:
class PyResult:
"""PyResult reimplements some features of `bindings.executor.Result` in Python"""
@dataclass
class Diff:
"""
Diff tracks incremental changes to a PyResult.
It is used to sync a PyResult to other ranks incrementally:
call `get_diff` on one rank and `apply_diff` on the others.
"""
exclude_last_generation_logits: bool | None = None
context_logits_list: list[torch.Tensor] = field(default_factory=list)
generation_logits_list: list[torch.Tensor] = field(default_factory=list)
reset_log_probs: tuple[list[TokenLogprobs],
list[float] | None] | None = None
log_probs_list: list[tuple[list[TokenLogprobs], list[float]
| None]] = field(default_factory=list)
mm_embeddings: dict[str, Any] | None = None
mrope_position_ids: dict[str, Any] | None = None
mrope_position_deltas: dict[str, Any] | None = None
additional_context_outputs_list: list[tuple[str, torch.Tensor]] = field(
default_factory=list)
additional_generation_outputs_list: list[tuple[str,
torch.Tensor]] = field(
default_factory=list)
def __init__(self,
*,
prompt_len: int,
@ -277,28 +300,72 @@ class PyResult:
name: []
for name in additional_outputs
} if additional_outputs else None
self.diff = PyResult.Diff()
def reset_diff(self):
self.diff = PyResult.Diff()
def get_diff(self) -> Diff:
for i, context_logits in enumerate(self.diff.context_logits_list):
self.diff.context_logits_list[i] = context_logits.to("cpu")
for i, generation_logits in enumerate(self.diff.generation_logits_list):
self.diff.generation_logits_list[i] = generation_logits.to("cpu")
return self.diff
def apply_diff(self, diff: Diff):
if diff.exclude_last_generation_logits is not None:
self._exclude_last_generation_logits = diff.exclude_last_generation_logits
if len(diff.context_logits_list) > 0:
for context_logits in diff.context_logits_list:
self._context_logits.append(context_logits)
if len(diff.generation_logits_list) > 0:
for generation_logits in diff.generation_logits_list:
self._generation_logits.append(generation_logits)
if diff.reset_log_probs is not None:
self._log_probs.set_log_probs(*diff.reset_log_probs)
if len(diff.log_probs_list) > 0:
for log_probs, cum_log_probs in diff.log_probs_list:
self._log_probs.append(log_probs, cum_log_probs)
if diff.mm_embeddings is not None:
self._mm_embeddings = diff.mm_embeddings
if diff.mrope_position_ids is not None:
self._mrope_position_ids = diff.mrope_position_ids
self._mrope_position_deltas = diff.mrope_position_deltas
if len(diff.additional_context_outputs_list) > 0:
for name, additional_context_outputs in diff.additional_context_outputs_list:
self._additional_context_outputs[name].append(
additional_context_outputs)
if len(diff.additional_generation_outputs_list) > 0:
for name, additional_generation_outputs in diff.additional_generation_outputs_list:
self._additional_generation_outputs[name].append(
additional_generation_outputs)
def set_exclude_last_generation_logits(
self, exclude_last_generation_logits: bool):
self._exclude_last_generation_logits = exclude_last_generation_logits
self.diff.exclude_last_generation_logits = exclude_last_generation_logits
def append_context_logits(self, context_logits: torch.Tensor):
if self._context_logits:
self._context_logits.append(context_logits)
self.diff.context_logits_list.append(context_logits)
def append_generation_logits(self, generation_logits: torch.Tensor):
if self._generation_logits:
self._generation_logits.append(generation_logits)
self.diff.generation_logits_list.append(generation_logits)
def append_log_probs(self,
log_probs: list[TokenLogprobs],
cum_log_probs: Optional[list[float]] = None):
if self._log_probs:
self._log_probs.append(log_probs, cum_log_probs)
self.diff.log_probs_list.append((log_probs, cum_log_probs))
def append_mm_embeddings(self, mm_embeddings: torch.Tensor):
self._mm_embeddings = SharedTensorContainer.from_tensor(
mm_embeddings).dump_to_dict()
self.diff.mm_embeddings = self._mm_embeddings
def set_mrope_position(
self,
@ -309,6 +376,8 @@ class PyResult:
mrope_position_ids).dump_to_dict())
self._mrope_position_deltas = (SharedTensorContainer.from_tensor(
mrope_position_deltas).dump_to_dict())
self.diff.mrope_position_ids = self._mrope_position_ids
self.diff.mrope_position_deltas = self._mrope_position_deltas
def transfer_remaining_device_logits(self):
"""Finalize any remaining generation logits transfers (for chunked mode)"""
@ -319,11 +388,15 @@ class PyResult:
self, name: str, additional_context_outputs: torch.Tensor):
self._additional_context_outputs[name].append(
additional_context_outputs.to("cpu", non_blocking=True))
self.diff.additional_context_outputs_list.append(
(name, self._additional_context_outputs[name][-1]))
def append_additional_generation_outputs(
self, name: str, additional_generation_outputs: torch.Tensor):
self._additional_generation_outputs[name].append(
additional_generation_outputs.to("cpu", non_blocking=True))
self.diff.additional_generation_outputs_list.append(
(name, self._additional_generation_outputs[name][-1]))
def set_log_probs(self, log_probs: list[TokenLogprobs],
cum_log_probs: list[float]):
@ -334,6 +407,8 @@ class PyResult:
"""
if self._log_probs:
self._log_probs.set_log_probs(log_probs, cum_log_probs)
self.diff.reset_log_probs = (log_probs, cum_log_probs)
self.diff.log_probs_list.clear()
@property
def context_logits(self) -> torch.Tensor | None:
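A minimal sketch of the intended `get_diff`/`apply_diff` round trip between PP ranks; `send_object` and `recv_object` are hypothetical stand-ins for the executor's `dist.isend_object`/`dist.recv_object` calls, and the real flow lives in `_ring_broadcast_sample_state` further down in this commit.
def send_py_result_diffs(requests, send_object):
    # Producer rank: snapshot each request's accumulated diff (logits already
    # moved to CPU by get_diff) and clear it so the next diff starts empty.
    diffs = []
    for request in requests:
        diffs.append(request.py_result.get_diff())
        request.py_result.reset_diff()
    send_object(diffs)
def recv_py_result_diffs(requests, recv_object):
    # Consumer rank: replay the received diffs onto the local PyResult copies.
    for request, diff in zip(requests, recv_object()):
        request.py_result.apply_diff(diff)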

View File

@ -2,17 +2,17 @@ import dataclasses
import datetime
import functools
import os
import pickle # nosec B403
import threading
import time
import traceback
from collections import deque
from contextlib import contextmanager
from enum import IntEnum
from queue import Queue
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import torch
from tensorrt_llm._torch.expert_statistic import ExpertStatistic
from tensorrt_llm.llmapi import DisaggScheduleStyle
from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
@ -21,10 +21,9 @@ try:
except ImportError:
from cuda import cudart
from tensorrt_llm._torch.pyexecutor.resource_manager import (
ResourceManagerType, request_context)
from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
mpi_disabled, nvtx_range, trace_func)
mpi_comm, mpi_disabled, nvtx_range,
set_thread_local_mpi_comm, trace_func)
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
FinishReason, InflightBatchingStats,
IterationStats, KvCacheStats,
@ -40,6 +39,7 @@ from tensorrt_llm.runtime.generation import CUASSERT
from tensorrt_llm.tools.layer_wise_benchmarks import get_calibrator
from ..distributed import Distributed
from ..expert_statistic import ExpertStatistic
from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.decoder_layer import DecoderLayer
from ..speculative.drafter import Drafter
@ -58,7 +58,8 @@ from .model_engine import ModelEngine
from .request_utils import (RequestBroadcaster, attach_py_objects_to_requests,
get_from_waiting_queue, merge_requests,
schedule_attention_dp_requests)
from .resource_manager import ResourceManager
from .resource_manager import (ResourceManager, ResourceManagerType,
request_context)
from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
SampleStateTensors, TRTLLMSampler)
from .scheduler import (RequestScheduler, ScheduledRequests,
@ -72,10 +73,15 @@ PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP"
# Set to a path to save detailed tracing of PyTorch operations.
PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"
# Unique tag base to avoid collisions with token/logits comms
TERMINATION_COMM_TAG_BASE = 20000
PP_COMM_TAG_SCHEDULE_RESULT = 21000
PP_COMM_TAG_SAMPLE_STATE_BASE = 21001
class PPCommTag(IntEnum):
"""
Unique tags for pipeline parallelism communication.
"""
TERMINATION = 20000
SCHEDULE_RESULT = 20001
EXECUTED_BATCH_NUM = 20002
SAMPLE_STATE = 20003
@functools.cache
@ -240,6 +246,12 @@ class AsyncTransferManager:
class PyExecutor:
# Minimum number of async micro batches for async PP execution.
# This is a trade-off between memory usage and performance.
# If the number of micro batches is too small, the executor will spend too much time in synchronization.
# If the number of micro batches is too large, the executor will consume too much host memory (no additional GPU memory is required).
# 1024 in-flight micro batches can avoid synchronization in most cases and keep host memory usage low.
MIN_ASYNC_MICRO_BATCH_NUM = 1024
def __init__(self,
resource_manager,
@ -371,18 +383,20 @@ class PyExecutor:
os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0))
# list of requests in each PP micro batch
self.num_micro_batches = self.dist.pp_size
self.num_micro_batches = max(self.dist.pp_size,
self.MIN_ASYNC_MICRO_BATCH_NUM)
self.micro_batches: List[BatchStatePP
| None] = [None] * self.num_micro_batches
self.send_handles = [None] * self.num_micro_batches
# schedule handle for PP to propagate the first PP rank's schedule result
self.send_schedule_handler = None
self.send_schedule_handles = [None] * self.num_micro_batches
self.send_expected_batch_num_handles = [None] * self.num_micro_batches
self.unhandled_batch_counter = 0
self.pp_scheduler_max_retry_count = int(
os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
self.pp_multi_stream_sample = os.environ.get(
"TRTLLM_PP_MULTI_STREAM_SAMPLE", "1") == "1"
self.sample_stream = torch.cuda.Stream()
self.start_sample_event = torch.cuda.Event()
self.finish_sample_event = torch.cuda.Event()
if (self.dist.pp_size > 1 and self.pp_multi_stream_sample
and isinstance(self.sampler, TRTLLMSampler)):
@ -479,6 +493,13 @@ class PyExecutor:
if self.dist.pp_size > 1:
self.event_loop = self._executor_loop_pp
# `TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE` controls whether to broadcast the sample state asynchronously.
# If true, the executor loop broadcasts and handles sample states asynchronously for best performance.
# If false, the executor loop broadcasts and handles each sample state in a pre-defined iteration.
# Disabling it is intended only for debugging;
# some tests disable it to get deterministic behavior.
self.pp_async_broadcast_sample_state = os.environ.get(
"TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE", "1") == "1"
else:
self.event_loop = self._executor_loop if self.disable_overlap_scheduler else self._executor_loop_overlap
if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
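For deterministic runs, the flag can be cleared through the `altered_env` test helper added later in this commit; a sketch, assuming the `LLM` class and `llama_model_path` from the test harness below and a 2-rank PP setup.
# Sketch: force the synchronous (deterministic) broadcast path in a test.
with altered_env(TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE="0"), \
        LLM(model=llama_model_path, pipeline_parallel_size=2) as llm:
    outputs = llm.generate(["Hello"])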
@ -560,7 +581,22 @@ class PyExecutor:
def start_worker(self):
with self.worker_lock:
if self.worker_started == False:
if not self.worker_started:
if self.dist.pp_size > 1:
self.executed_batch_queue: Queue[BatchStatePP] = Queue(
maxsize=self.num_micro_batches)
self.executed_batch_response_queue: Queue[
BatchStatePP] = Queue(maxsize=-1)
broadcast_sample_state_loop = self._broadcast_sample_state_loop
if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
broadcast_sample_state_loop = trace_func(
broadcast_sample_state_loop)
self.broadcast_sample_state_handler = threading.Thread(
target=broadcast_sample_state_loop,
daemon=True,
name="broadcast_sample_state_handler",
)
self.broadcast_sample_state_handler.start()
self.worker_thread = threading.Thread(
target=self._event_loop_wrapper, daemon=True)
self.worker_thread.start()
@ -658,6 +694,9 @@ class PyExecutor:
logger.error("Hang detected, shutting down immediately.")
return
self.worker_thread.join()
if self.dist.pp_size > 1:
self.executed_batch_queue.put(None)
self.broadcast_sample_state_handler.join()
self.worker_started = False
for manager in self.resource_manager.resource_managers.values():
if manager:
@ -1064,41 +1103,63 @@ class PyExecutor:
def _executor_loop_cleanup(self):
for h in self.send_handles:
if h is not None:
h.wait()
for i in range(self.num_micro_batches):
self.wait_on_pp_send_handles(self.send_handles, i)
self.wait_on_pp_send_handles(self.send_schedule_handles, i)
self.wait_on_pp_send_handles(self.send_expected_batch_num_handles,
i)
with self.response_cv:
self.is_shutdown = True
self.response_cv.notify_all()
self.shutdown_event.set()
def _pp_schedule_and_propagate(self):
def _pp_schedule_and_propagate(self, microbatch_id: int):
"""The first PP rank schedules the requests and propagates the result to all other PP ranks."""
# The first PP rank schedules the requests; other ranks receive the schedule result from the previous PP rank.
if self.dist.is_first_pp_rank:
# For TP/CP cases, the first rank schedules the requests.
# For DP cases, the first PP rank schedules the requests.
scheduled_batch = None
serializable_schedule = None
is_dp_broadcast = self.dist.tp_size > 1 and self.enable_attention_dp
if self.dist.rank == 0 or (self.dist.is_first_pp_rank
and is_dp_broadcast):
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
serializable_schedule = SerializableSchedulerOutput.from_scheduler_result(
scheduled_batch, fitting_disagg_gen_init_requests,
num_fitting_reqs)
else:
# Broadcast within first tp+cp group before send/recv chain to other tp+cp groups
if self.dist.is_first_pp_rank:
if self.dist.tp_size > 1 and not self.enable_attention_dp:
with nvtx_range("tp_broadcast_schedule"):
serializable_schedule = self.dist.tp_broadcast(
serializable_schedule, root=0)
if self.dist.cp_size > 1:
with nvtx_range("cp_broadcast_schedule"):
serializable_schedule = self.dist.cp_broadcast(
serializable_schedule, root=0)
# Other ranks receive the schedule result from the previous PP rank.
if not self.dist.is_first_pp_rank:
with nvtx_range("recv_schedule_from_prev_pp"):
serializable_schedule = self.dist.recv_object(
self.dist.prev_pp_rank, PP_COMM_TAG_SCHEDULE_RESULT)
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result(
self.active_requests)
self.dist.prev_pp_rank, PPCommTag.SCHEDULE_RESULT)
# Every rank except the last PP rank propagates the schedule result to the next PP rank.
if not self.dist.is_last_pp_rank:
if self.send_schedule_handler is not None:
with nvtx_range("wait_send_schedule_handler"):
self.send_schedule_handler.wait()
self.wait_on_pp_send_handles(self.send_schedule_handles,
microbatch_id)
with nvtx_range("send_schedule_to_next_pp"):
self.send_schedule_handler = self.dist.isend_object(
serializable_schedule, self.dist.next_pp_rank,
PP_COMM_TAG_SCHEDULE_RESULT)
self.send_schedule_handles[
microbatch_id] = self.dist.isend_object(
serializable_schedule, self.dist.next_pp_rank,
PPCommTag.SCHEDULE_RESULT)
if scheduled_batch is None:
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result(
self.active_requests)
return scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs
def _pp_retry_until_can_schedule(self, scheduled_batch):
@ -1177,8 +1238,8 @@ class PyExecutor:
# Stage 0: first PP rank schedules requests and propagates the result to all other PP ranks.
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._pp_schedule_and_propagate(
)
if not self.dist.is_first_pp_rank:
microbatch_id)
if self.dist.rank != 0:
# Retry until current rank can run first PP's schedule result.
self._pp_retry_until_can_schedule(scheduled_batch)
# Run scheduler locally because scheduler may change llm requests' state.
@ -1236,7 +1297,7 @@ class PyExecutor:
# Return the first token to the client
self._handle_first_token_response(scheduled_batch)
# Stage 1: Async forward (all ranks) and decoding pass (last rank only)
# Stage 1.1: Async forward (all ranks) and decoding pass (last rank only)
if not self.dist.is_last_pp_rank:
with torch.cuda.nvtx.range(
f"_forward_step_inter_pp pp_rank {self.dist.pp_rank}"
@ -1269,9 +1330,9 @@ class PyExecutor:
name: tensor.clone()
for name, tensor in batch_outputs.items()
}
self.start_sample_event.record()
self.sample_stream.wait_stream(
torch.cuda.current_stream())
with torch.cuda.stream(self.sample_stream):
self.start_sample_event.wait()
sample_state = self._sample_async(
scheduled_batch, batch_outputs_copy)
self.finish_sample_event.record()
@ -1302,103 +1363,222 @@ class PyExecutor:
self.micro_batches[microbatch_id] = batch_state
# sync sampler for previous microbatch to start new sample state comm chain.
prev_microbatch_id = (microbatch_id -
1) % self.num_micro_batches
previous_batch = self.micro_batches[prev_microbatch_id]
if previous_batch is not None:
# Stage 1.2: Sync sampler for previous microbatch to start new sample state comm chain.
# For last PP rank, we must synchronize the previous batch
# since we need to broadcast its sample state soon afterwards in the same iteration.
# For other PP ranks, we can delay the synchronization if the current batch cannot be queued.
previous_batch = self.previous_batch
if can_queue:
self.previous_batch = batch_state
if (self.dist.is_last_pp_rank
or can_queue) and previous_batch is not None:
with nvtx_range("sync_previous_sampler_event"):
previous_batch.sample_state.sampler_event.synchronize()
# Stage 2: Communicate sample state for previous batch between ranks
# Stage 2: Enqueue the executed batch's sample state so the background thread can ring-broadcast it asynchronously.
# send/recv chain: (pp_size - 1) -> 0 -> 1 -> ... -> (pp_size - 2)
# intermediate ranks: send/recv sample state for next microbatch to allow overlap
offset = -1 if self.dist.is_last_pp_rank else 1
prev_microbatch_id = (microbatch_id +
offset) % self.num_micro_batches
previous_batch = self.micro_batches[prev_microbatch_id]
tag = PP_COMM_TAG_SAMPLE_STATE_BASE + prev_microbatch_id
if previous_batch is not None:
sample_state = previous_batch.sample_state
if not self.dist.is_last_pp_rank:
# Receive tokens from previous pp rank (w.r.t model forward direction)
with nvtx_range("recv_sample_state"):
sample_state.host = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=tag,
offset = -1 if self.dist.is_last_pp_rank else (
1 - self.dist.pp_size)
executed_microbatch_id = (microbatch_id +
offset) % self.num_micro_batches
executed_batch = self.micro_batches[executed_microbatch_id]
if executed_batch is not None:
self.executed_batch_queue.put(executed_batch)
self.unhandled_batch_counter += 1
self.micro_batches[executed_microbatch_id] = None
def fetch_executed_batches() -> list[BatchStatePP]:
executed_batches = []
if self.pp_async_broadcast_sample_state:
# Wait for at least one batch to finish if no new request is available.
must_get = not can_queue
else:
must_get = True
while not self.executed_batch_response_queue.empty() or (
must_get and self.unhandled_batch_counter > 0):
with nvtx_range("get_executed_batch"):
executed_batches.append(
self.executed_batch_response_queue.get())
must_get = False
return executed_batches
def ring_broadcast_executed_batch_num(
executed_batch_num: int) -> int:
if self.dist.is_first_pp_rank and self.dist.tp_size * self.dist.cp_size > 1:
with nvtx_range("tp_cp_broadcast_executed_batch_num"):
executed_batch_num = self.dist.tp_cp_broadcast(
executed_batch_num,
root=0,
)
# Send tokens to next pp rank (w.r.t model forward direction)
# Second last rank does not need to since last rank has original decoded tokens
if not self.dist.is_second_last_pp_rank:
self.wait_on_pp_send_handles(prev_microbatch_id)
with nvtx_range("send_sample_state"):
self.send_handles[
prev_microbatch_id] = self.dist.isend_object(
sample_state.host,
if not self.dist.is_first_pp_rank:
with nvtx_range("recv_expected_batch_num"):
executed_batch_num = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=PPCommTag.EXECUTED_BATCH_NUM,
)
if not self.dist.is_last_pp_rank:
self.wait_on_pp_send_handles(
self.send_expected_batch_num_handles, microbatch_id)
with nvtx_range("send_expected_batch_num"):
self.send_expected_batch_num_handles[
microbatch_id] = self.dist.isend_object(
executed_batch_num,
dest=self.dist.next_pp_rank,
tag=tag)
tag=PPCommTag.EXECUTED_BATCH_NUM,
)
return executed_batch_num
# Stage 3: Finalize previous batch that finished sample state communication
# In last pp rank, stage 2 and 3 process different previous batches
prev_microbatch_id = (microbatch_id +
1) % self.num_micro_batches
previous_batch = self.micro_batches[prev_microbatch_id]
finished_requests = []
if previous_batch is not None:
with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
sample_state = previous_batch.sample_state
sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs
self._update_requests(previous_batch.sample_state)
def handle_executed_batches(executed_batch_num: int):
if self.dist.rank != 0:
dequeue_counter = 0
while dequeue_counter < executed_batch_num:
with nvtx_range("get_executed_batch"):
executed_batch = self.executed_batch_response_queue.get(
)
self._handle_executed_batch(executed_batch)
dequeue_counter += 1
else:
for executed_batch in executed_batches:
self._handle_executed_batch(executed_batch)
self.unhandled_batch_counter -= executed_batch_num
if self.kv_cache_transceiver:
self._send_kv_async(
previous_batch.finished_ctx_reqs)
self._handle_canceled_requests()
executed_batch_num = 0
self._handle_logits_communication(
previous_batch, prev_microbatch_id)
# Stage 3.1: The first rank determines the number of executed batches.
if self.dist.rank == 0:
executed_batches = fetch_executed_batches()
executed_batch_num = len(executed_batches)
finished_requests = self._handle_responses()
previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
attn_metadata = getattr(self.model_engine,
'attn_metadata', None)
kv_cache_dtype_byte_size = getattr(
self.model_engine, 'kv_cache_dtype_byte_size', None)
self.resource_manager.update_resources(
previous_scheduled_batch, attn_metadata,
kv_cache_dtype_byte_size)
# Stage 3.2: Broadcast the number of executed batches to other ranks.
executed_batch_num = ring_broadcast_executed_batch_num(
executed_batch_num)
self._remove_inflight_ids(previous_batch)
# Stage 3.3: Handle executed batches.
handle_executed_batches(executed_batch_num)
self.wait_on_pp_send_handles(prev_microbatch_id)
self.micro_batches[prev_microbatch_id] = None
if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
):
self._check_kv_transfer_timeout()
if self._disagg_pp_termination_handler is not None:
self._disagg_pp_termination_handler.terminate_pending_requests(
)
# march forward in microbatch slots
# Stage 4: March forward in microbatch slots
microbatch_id = (microbatch_id + 1) % self.num_micro_batches
if self.enable_iter_perf_stats and previous_batch is not None:
sample_state = previous_batch.sample_state
sample_state.scheduled_requests.context_requests = previous_batch.scheduled_ctx_reqs
self._process_iter_stats(finished_requests,
self.active_requests,
previous_batch, microbatch_id)
self.iter_counter += 1
# Stage 5: Handle remaining executed batches in the queue.
while self.unhandled_batch_counter > 0:
with nvtx_range("get_executed_batch"):
executed_batch = self.executed_batch_response_queue.get()
self._handle_executed_batch(executed_batch)
self.unhandled_batch_counter -= 1
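The loop above and the `_broadcast_sample_state_loop` thread below form a producer/consumer pair over two queues (`executed_batch_queue` and `executed_batch_response_queue`); a minimal generic sketch of that shape, with simplified names that are not the executor's actual types.
import queue
import threading
work_q = queue.Queue(maxsize=8)  # ~ executed_batch_queue
done_q = queue.Queue()           # ~ executed_batch_response_queue
def broadcaster():
    # Background thread: take an executed batch, do the (slow) communication,
    # then hand it back to the main loop for bookkeeping.
    while (item := work_q.get()) is not None:
        done_q.put(item)  # stands in for the ring broadcast of the sample state
threading.Thread(target=broadcaster, daemon=True).start()
work_q.put("batch-0")  # executor loop enqueues an executed batch
print(done_q.get())    # ... and later drains the handled batch
work_q.put(None)       # shutdown sentinel, mirroring the shutdown handling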
def _broadcast_sample_state_loop(self):
logger.debug(
f"Starting broadcast sample state loop for pp_rank {self.dist.pp_rank}"
)
torch.cuda.set_device(self.device_id)
# Ensure the CUDA context is created; otherwise some MPI calls will fail.
CUASSERT(cudart.cudaSetDevice(self.device_id))
# Acquiring pkl5.Intracomm's send/recv locks from both the executor loop thread
# and this thread can cause a performance drop and even deadlock.
# We create a new MPI comm to avoid these issues.
logger.info(
"Create new MPI comm for broadcast sample state thread to avoid deadlock."
)
new_mpi_comm = mpi_comm().Dup()
set_thread_local_mpi_comm(new_mpi_comm)
while True:
executed_batch = self.executed_batch_queue.get()
if executed_batch is None:
break
self._ring_broadcast_sample_state(executed_batch)
set_thread_local_mpi_comm(None)
new_mpi_comm.Free()
def _ring_broadcast_sample_state(
self,
executed_batch: Optional[BatchStatePP],
) -> None:
if executed_batch is None:
return
tag = PPCommTag.SAMPLE_STATE
microbatch_id = executed_batch.microbatch_id
sample_state = executed_batch.sample_state
requests = sample_state.scheduled_requests.all_requests()
if not self.dist.is_last_pp_rank:
# Receive tokens from previous pp rank (w.r.t model forward direction)
with nvtx_range("recv_sample_state"):
sample_state.host, py_result_diffs = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=tag,
)
for request, py_result_diff in zip(requests, py_result_diffs):
request.py_result.apply_diff(py_result_diff)
self.executed_batch_response_queue.put(executed_batch)
# Send tokens to next pp rank (w.r.t model forward direction)
# The second-to-last rank does not need to send, since the last rank already has the decoded tokens
if not self.dist.is_second_last_pp_rank:
py_result_diffs = []
for request in requests:
diff = request.py_result.get_diff()
py_result_diffs.append(diff)
request.py_result.reset_diff()
self.wait_on_pp_send_handles(self.send_handles, microbatch_id)
with nvtx_range("send_sample_state"):
self.send_handles[microbatch_id] = self.dist.isend_object(
(sample_state.host, py_result_diffs),
dest=self.dist.next_pp_rank,
tag=tag,
)
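To make the ring order from the Stage 2 comment concrete ((pp_size - 1) -> 0 -> 1 -> ... -> (pp_size - 2)), a small illustrative helper, not part of the executor itself.
def ring_order(pp_size: int) -> list[int]:
    # Order in which PP ranks observe a sample state: the last rank produces it,
    # then it hops through ranks 0 .. pp_size - 2; the second-to-last rank does
    # not forward it further.
    return [pp_size - 1] + list(range(pp_size - 1))
assert ring_order(4) == [3, 0, 1, 2]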
def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
finished_requests = []
if executed_batch is not None:
with torch.cuda.nvtx.range("_handle_executed_batch_pp"):
sample_state = executed_batch.sample_state
sample_state.scheduled_requests.context_requests = executed_batch.finished_ctx_reqs
self._update_requests(executed_batch.sample_state)
if self.kv_cache_transceiver:
self._send_kv_async(executed_batch.finished_ctx_reqs)
self._handle_canceled_requests()
finished_requests = self._handle_responses()
previous_scheduled_batch = executed_batch.sample_state.scheduled_requests
attn_metadata = getattr(self.model_engine, 'attn_metadata',
None)
kv_cache_dtype_byte_size = getattr(self.model_engine,
'kv_cache_dtype_byte_size',
None)
self.resource_manager.update_resources(
previous_scheduled_batch, attn_metadata,
kv_cache_dtype_byte_size)
self._remove_inflight_ids(executed_batch)
if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
):
self._check_kv_transfer_timeout()
if self._disagg_pp_termination_handler is not None:
self._disagg_pp_termination_handler.terminate_pending_requests()
if self.enable_iter_perf_stats and executed_batch is not None:
sample_state = executed_batch.sample_state
sample_state.scheduled_requests.context_requests = executed_batch.scheduled_ctx_reqs
self._process_iter_stats(
finished_requests,
self.active_requests,
executed_batch,
executed_batch.microbatch_id % self.dist.pp_size,
)
@nvtx_range("wait_on_pp_send_handles")
def wait_on_pp_send_handles(self, microbatch_id):
if self.send_handles[microbatch_id] is not None:
self.send_handles[microbatch_id].wait()
self.send_handles[microbatch_id] = None
def wait_on_pp_send_handles(self, send_handles, microbatch_id):
if send_handles[microbatch_id] is not None:
send_handles[microbatch_id].wait()
send_handles[microbatch_id] = None
def _can_queue(self, scheduled_batch):
@ -3022,41 +3202,6 @@ class PyExecutor:
self._terminate_request(request)
return requests_to_terminate
def _handle_logits_communication(self, previous_batch, prev_microbatch_id):
"""Handle logits communication between pipeline parallel ranks.
If logits were requested, the last PP rank sends to the first PP rank (who sends responses)
the logits of the requests that have finished.
Args:
previous_batch: The previous batch state
prev_microbatch_id: The microbatch ID for the previous batch
"""
# NOTE: If the rank processing the logits ever becomes the same as
# the rank sending the responses, this code can be removed.
finished_reqs = [
r for r in
previous_batch.sample_state.scheduled_requests.all_requests()
if r.state == LlmRequestState.GENERATION_COMPLETE and (
r.py_return_context_logits or r.py_return_generation_logits
or r.py_additional_outputs is not None)
]
if self.dist.is_first_pp_rank and len(finished_reqs):
finished_reqs_py_results = [r.py_result for r in finished_reqs]
finished_reqs_py_results = self.dist.recv_object(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id,
)
for req, py_result in zip(finished_reqs, finished_reqs_py_results):
req.py_result = py_result
elif self.dist.is_last_pp_rank and len(finished_reqs):
self.wait_on_pp_send_handles(prev_microbatch_id)
self.send_handles[prev_microbatch_id] = self.dist.isend_object(
[r.py_result for r in finished_reqs],
dest=self.dist.next_pp_rank,
tag=prev_microbatch_id)
def _await_any_response(self,
timeout: Optional[float] = None
) -> List[LlmResponse]:
@ -3198,7 +3343,7 @@ class DisaggPPTerminationHandler:
self._pending_termination = {}
self._terminating_iteration = 0
self._send_handle = None
self._comm_tag = TERMINATION_COMM_TAG_BASE
self._comm_tag = PPCommTag.TERMINATION
def terminate(self, request: LlmRequest):
self._pending_termination[request.py_request_id] = request

View File

@ -673,7 +673,8 @@ class RequestBroadcaster:
# Broadcast within first PP stage before send/recv chain to other PP stages.
# This needs to cover both TP and CP ranks within the first PP stage.
if self.dist.is_first_pp_rank:
payloads = self.dist.tp_cp_broadcast(payloads, root=0)
with nvtx_range("tp_broadcast_requests"):
payloads = self.dist.tp_cp_broadcast(payloads, root=0)
# Tag for communication
tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts

View File

@ -8,7 +8,8 @@ from typing import Dict, List
import torch
from torch.nn import functional as F
from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor,
torch_dtype_to_str)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.math_utils import ceil_div, pad_up
from tensorrt_llm.quantization.utils import fp4_utils
@ -416,6 +417,29 @@ def relu2(x: torch.Tensor) -> torch.Tensor:
return torch.square(F.relu(x))
def tensor_to_str(x: torch.Tensor, num_elements: int = 10) -> str:
# Passing num_elements=-1 prints the whole tensor
if num_elements < 0:
num_elements = torch.numel(x)
if x.dtype in (torch.int32, torch.int64):
float_x = x.to(dtype=float)
else:
float_x = x
return ("Tensor("
f"shape={tuple(x.shape)}, "
f"dtype={torch_dtype_to_str(x.dtype)}, "
f"device={x.device}, "
f"stats=("
f"abs_mean={float_x.abs().mean().item():.3f}, "
f"mean={float_x.mean().item():.3f}, "
f"std={float_x.std().item():.3f}, "
f"max={x.max().item():.3f}, "
f"min={x.min().item():.3f}"
"), "
f"values={x.flatten()[:num_elements].tolist()}"
")")
@maybe_compile
def maybe_compiled_copy_(dst, src):
dst.copy_(src)

View File

@ -23,6 +23,7 @@ import socket
import struct
import sys
import tempfile
import threading
import trace
import traceback
import weakref
@ -509,7 +510,17 @@ def set_mpi_comm(new_comm):
comm = new_comm
thread_local_comm = threading.local()
def set_thread_local_mpi_comm(new_comm):
thread_local_comm.value = new_comm
def mpi_comm():
if hasattr(thread_local_comm,
"value") and thread_local_comm.value is not None:
return thread_local_comm.value
return comm
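A sketch of how a background thread is expected to use the thread-local override, mirroring `_broadcast_sample_state_loop`; the worker body is a placeholder.
def _comm_worker_thread():
    # Give this thread its own duplicated communicator so it never contends for
    # the pkl5 send/recv locks held by the main executor loop thread.
    local_comm = mpi_comm().Dup()
    set_thread_local_mpi_comm(local_comm)
    try:
        pass  # mpi_comm() now returns local_comm on this thread
    finally:
        set_thread_local_mpi_comm(None)
        local_comm.Free()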

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple
import yaml
from mpi4py.MPI import COMM_WORLD, Comm
from mpi4py.util import pkl5
from .._utils import global_mpi_rank, global_mpi_size
@ -327,7 +328,7 @@ def split_world_comm(
f"global_rank: {global_rank}, instance_idx: {instance_idx}, sub_rank: {sub_rank}, is_leader: {is_leader}"
)
return is_leader, instance_idx, sub_comm
return is_leader, instance_idx, pkl5.Intracomm(sub_comm)
def parse_metadata_server_config_file(

View File

@ -170,8 +170,14 @@ class MpiPoolSession(MpiSession):
def _start_mpi_pool(self):
assert not self.mpi_pool, 'MPI session already started'
env = {
key: value
for key, value in os.environ.items()
if key.startswith("TRTLLM") or key.startswith("TLLM")
}
self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers,
path=sys.path)
path=sys.path,
env=env)
def __del__(self):
self.shutdown_abort()

View File

@ -1,4 +1,5 @@
import asyncio
import contextlib
import datetime
import gc
import json
@ -57,7 +58,8 @@ from llmapi.lora_test_utils import (
check_llama_7b_multi_lora_from_request_test_harness,
check_llama_7b_multi_unique_lora_adapters_from_request)
from utils.llm_data import llm_models_root
from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_pre_hopper, skip_single_gpu
from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_pre_hopper, skip_single_gpu, altered_env
# isort: on
# The unittests are based on the tiny-llama, which is fast to build and run.
@ -2171,11 +2173,16 @@ def llm_get_stats_test_harness(tp_size: int = 1,
llm_args_extra["fast_build"] = True
LLM_CLASS = LLM
with LLM_CLASS(model=llama_model_path,
kv_cache_config=global_kvcache_config,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
**llm_args_extra) as llm:
# Since we need to check PP's internal state, we disable the async broadcast
# to get deterministic behavior.
env_ctx = altered_env(TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE="0") \
if pp_size > 1 else contextlib.nullcontext()
with env_ctx, LLM_CLASS(model=llama_model_path,
kv_cache_config=global_kvcache_config,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
**llm_args_extra) as llm:
max_tokens = 5
sampling_params = SamplingParams(max_tokens=max_tokens,

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import faulthandler
import math
import os
@ -391,6 +392,23 @@ def run_session(session: Session,
return outputs
@contextlib.contextmanager
def altered_env(**kwargs):
old = {}
for k, v in kwargs.items():
if k in os.environ:
old[k] = os.environ[k]
os.environ[k] = v
try:
yield
finally:
for k in kwargs:
if k not in old:
os.environ.pop(k)
else:
os.environ[k] = old[k]
def similarity_score(a, b):
"similar compare a and b "
return SequenceMatcher(None, a, b).ratio()