Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][feat] Fully non-blocking pipeline parallelism executor loop. (#10349)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Parent: c233692485  Commit: 5f4df89109
@@ -1,5 +1,5 @@
from copy import copy, deepcopy
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import torch
@@ -232,6 +232,29 @@ class LogProbStorage:
class PyResult:
    """PyResult reimplements some features of `bindings.executor.Result` in Python"""

    @dataclass
    class Diff:
        """
        Diff is used to track the changes of the PyResult.
        It is designed to incrementally sync the PyResult to other ranks
        by `get_diff` on one rank and `apply_diff` on other ranks.
        """
        exclude_last_generation_logits: bool | None = None
        context_logits_list: list[torch.Tensor] = field(default_factory=list)
        generation_logits_list: list[torch.Tensor] = field(default_factory=list)
        reset_log_probs: tuple[list[TokenLogprobs],
                               list[float] | None] | None = None
        log_probs_list: list[tuple[list[TokenLogprobs], list[float]
                                   | None]] = field(default_factory=list)
        mm_embeddings: dict[str, Any] | None = None
        mrope_position_ids: dict[str, Any] | None = None
        mrope_position_deltas: dict[str, Any] | None = None
        additional_context_outputs_list: list[tuple[str, torch.Tensor]] = field(
            default_factory=list)
        additional_generation_outputs_list: list[tuple[str,
                                                       torch.Tensor]] = field(
                                                           default_factory=list)

    def __init__(self,
                 *,
                 prompt_len: int,
@@ -277,28 +300,72 @@ class PyResult:
            name: []
            for name in additional_outputs
        } if additional_outputs else None
        self.diff = PyResult.Diff()

    def reset_diff(self):
        self.diff = PyResult.Diff()

    def get_diff(self) -> Diff:
        for i, context_logits in enumerate(self.diff.context_logits_list):
            self.diff.context_logits_list[i] = context_logits.to("cpu")
        for i, generation_logits in enumerate(self.diff.generation_logits_list):
            self.diff.generation_logits_list[i] = generation_logits.to("cpu")
        return self.diff

    def apply_diff(self, diff: Diff):
        if diff.exclude_last_generation_logits is not None:
            self._exclude_last_generation_logits = diff.exclude_last_generation_logits
        if len(diff.context_logits_list) > 0:
            for context_logits in diff.context_logits_list:
                self._context_logits.append(context_logits)
        if len(diff.generation_logits_list) > 0:
            for generation_logits in diff.generation_logits_list:
                self._generation_logits.append(generation_logits)
        if diff.reset_log_probs is not None:
            self._log_probs.set_log_probs(*diff.reset_log_probs)
        if len(diff.log_probs_list) > 0:
            for log_probs, cum_log_probs in diff.log_probs_list:
                self._log_probs.append(log_probs, cum_log_probs)
        if diff.mm_embeddings is not None:
            self._mm_embeddings = diff.mm_embeddings
        if diff.mrope_position_ids is not None:
            self._mrope_position_ids = diff.mrope_position_ids
            self._mrope_position_deltas = diff.mrope_position_deltas
        if len(diff.additional_context_outputs_list) > 0:
            for name, additional_context_outputs in diff.additional_context_outputs_list:
                self._additional_context_outputs[name].append(
                    additional_context_outputs)
        if len(diff.additional_generation_outputs_list) > 0:
            for name, additional_generation_outputs in diff.additional_generation_outputs_list:
                self._additional_generation_outputs[name].append(
                    additional_generation_outputs)

    def set_exclude_last_generation_logits(
            self, exclude_last_generation_logits: bool):
        self._exclude_last_generation_logits = exclude_last_generation_logits
        self.diff.exclude_last_generation_logits = exclude_last_generation_logits

    def append_context_logits(self, context_logits: torch.Tensor):
        if self._context_logits:
            self._context_logits.append(context_logits)
            self.diff.context_logits_list.append(context_logits)

    def append_generation_logits(self, generation_logits: torch.Tensor):
        if self._generation_logits:
            self._generation_logits.append(generation_logits)
            self.diff.generation_logits_list.append(generation_logits)

    def append_log_probs(self,
                         log_probs: list[TokenLogprobs],
                         cum_log_probs: Optional[list[float]] = None):
        if self._log_probs:
            self._log_probs.append(log_probs, cum_log_probs)
            self.diff.log_probs_list.append((log_probs, cum_log_probs))

    def append_mm_embeddings(self, mm_embeddings: torch.Tensor):
        self._mm_embeddings = SharedTensorContainer.from_tensor(
            mm_embeddings).dump_to_dict()
        self.diff.mm_embeddings = self._mm_embeddings

    def set_mrope_position(
            self,
@@ -309,6 +376,8 @@ class PyResult:
            mrope_position_ids).dump_to_dict())
        self._mrope_position_deltas = (SharedTensorContainer.from_tensor(
            mrope_position_deltas).dump_to_dict())
        self.diff.mrope_position_ids = self._mrope_position_ids
        self.diff.mrope_position_deltas = self._mrope_position_deltas

    def transfer_remaining_device_logits(self):
        """Finalize any remaining generation logits transfers (for chunked mode)"""
@@ -319,11 +388,15 @@ class PyResult:
            self, name: str, additional_context_outputs: torch.Tensor):
        self._additional_context_outputs[name].append(
            additional_context_outputs.to("cpu", non_blocking=True))
        self.diff.additional_context_outputs_list.append(
            (name, self._additional_context_outputs[name][-1]))

    def append_additional_generation_outputs(
            self, name: str, additional_generation_outputs: torch.Tensor):
        self._additional_generation_outputs[name].append(
            additional_generation_outputs.to("cpu", non_blocking=True))
        self.diff.additional_generation_outputs_list.append(
            (name, self._additional_generation_outputs[name][-1]))

    def set_log_probs(self, log_probs: list[TokenLogprobs],
                      cum_log_probs: list[float]):
@@ -334,6 +407,8 @@ class PyResult:
        """
        if self._log_probs:
            self._log_probs.set_log_probs(log_probs, cum_log_probs)
            self.diff.reset_log_probs = (log_probs, cum_log_probs)
            self.diff.log_probs_list.clear()

    @property
    def context_logits(self) -> torch.Tensor | None:
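The incremental-sync contract described in the Diff docstring above can be illustrated with a self-contained toy (plain Python, not TensorRT-LLM code): one rank records changes as it appends, ships the diff, resets it, and a peer replays the diff to converge on the same state.

from dataclasses import dataclass, field


class ToyResult:

    @dataclass
    class Diff:
        appended: list[int] = field(default_factory=list)

    def __init__(self):
        self.values: list[int] = []
        self.diff = ToyResult.Diff()

    def append(self, v: int):
        self.values.append(v)
        self.diff.appended.append(v)       # record the change for peers

    def get_diff(self) -> "ToyResult.Diff":
        return self.diff

    def reset_diff(self):
        self.diff = ToyResult.Diff()

    def apply_diff(self, diff: "ToyResult.Diff"):
        self.values.extend(diff.appended)  # replay the change locally


producer, replica = ToyResult(), ToyResult()
producer.append(1)
producer.append(2)
diff = producer.get_diff()                 # would be sent to the next rank
producer.reset_diff()
replica.apply_diff(diff)
assert replica.values == producer.values
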
@@ -2,17 +2,17 @@ import dataclasses
import datetime
import functools
import os
import pickle  # nosec B403
import threading
import time
import traceback
from collections import deque
from contextlib import contextmanager
from enum import IntEnum
from queue import Queue
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch

from tensorrt_llm._torch.expert_statistic import ExpertStatistic
from tensorrt_llm.llmapi import DisaggScheduleStyle
from tensorrt_llm.serve.responses_utils import get_steady_clock_now_in_seconds
@@ -21,10 +21,9 @@ try:
except ImportError:
    from cuda import cudart

from tensorrt_llm._torch.pyexecutor.resource_manager import (
    ResourceManagerType, request_context)
from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
                                 mpi_disabled, nvtx_range, trace_func)
                                 mpi_comm, mpi_disabled, nvtx_range,
                                 set_thread_local_mpi_comm, trace_func)
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
                                            FinishReason, InflightBatchingStats,
                                            IterationStats, KvCacheStats,
@@ -40,6 +39,7 @@ from tensorrt_llm.runtime.generation import CUASSERT
from tensorrt_llm.tools.layer_wise_benchmarks import get_calibrator

from ..distributed import Distributed
from ..expert_statistic import ExpertStatistic
from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.decoder_layer import DecoderLayer
from ..speculative.drafter import Drafter
@@ -58,7 +58,8 @@ from .model_engine import ModelEngine
from .request_utils import (RequestBroadcaster, attach_py_objects_to_requests,
                            get_from_waiting_queue, merge_requests,
                            schedule_attention_dp_requests)
from .resource_manager import ResourceManager
from .resource_manager import (ResourceManager, ResourceManagerType,
                               request_context)
from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
                      SampleStateTensors, TRTLLMSampler)
from .scheduler import (RequestScheduler, ScheduledRequests,
@@ -72,10 +73,15 @@ PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP"
# Set to a path to save detailed tracing of PyTorch operations.
PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"

# Unique tag base to avoid collisions with token/logits comms
TERMINATION_COMM_TAG_BASE = 20000
PP_COMM_TAG_SCHEDULE_RESULT = 21000
PP_COMM_TAG_SAMPLE_STATE_BASE = 21001


class PPCommTag(IntEnum):
    """
    Unique tags for pipeline parallelism communication.
    """
    TERMINATION = 20000
    SCHEDULE_RESULT = 20001
    EXECUTED_BATCH_NUM = 20002
    SAMPLE_STATE = 20003


@functools.cache
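A brief standalone illustration of why the tags above are an IntEnum: members compare and serialize as plain ints, so they can be passed wherever an integer MPI tag is expected while keeping the call sites self-describing.

from enum import IntEnum


class PPCommTag(IntEnum):
    TERMINATION = 20000
    SCHEDULE_RESULT = 20001
    EXECUTED_BATCH_NUM = 20002
    SAMPLE_STATE = 20003


assert PPCommTag.SAMPLE_STATE == 20003      # compares equal to the raw int
assert isinstance(PPCommTag.SAMPLE_STATE, int)
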
@@ -240,6 +246,12 @@ class AsyncTransferManager:


class PyExecutor:
    # Minimum number of async micro batches for async PP execution.
    # This is a trade-off between memory usage and performance.
    # If the number of micro batches is too small, the executor will spend too much time in synchronization.
    # If the number of micro batches is too large, the executor will consume too much host memory (no additional GPU memory is required).
    # 1024 in-flight micro batches can avoid synchronization in most cases and keep host memory usage low.
    MIN_ASYNC_MICRO_BATCH_NUM = 1024

    def __init__(self,
                 resource_manager,
@@ -371,18 +383,20 @@ class PyExecutor:
            os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0))

        # list of requests in each PP micro batch
        self.num_micro_batches = self.dist.pp_size
        self.num_micro_batches = max(self.dist.pp_size,
                                     self.MIN_ASYNC_MICRO_BATCH_NUM)
        self.micro_batches: List[BatchStatePP
                                 | None] = [None] * self.num_micro_batches
        self.send_handles = [None] * self.num_micro_batches
        # schedule handle for PP to propagate the first PP rank's schedule result
        self.send_schedule_handler = None
        self.send_schedule_handles = [None] * self.num_micro_batches
        self.send_expected_batch_num_handles = [None] * self.num_micro_batches
        self.unhandled_batch_counter = 0
        self.pp_scheduler_max_retry_count = int(
            os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
        self.pp_multi_stream_sample = os.environ.get(
            "TRTLLM_PP_MULTI_STREAM_SAMPLE", "1") == "1"
        self.sample_stream = torch.cuda.Stream()
        self.start_sample_event = torch.cuda.Event()
        self.finish_sample_event = torch.cuda.Event()
        if (self.dist.pp_size > 1 and self.pp_multi_stream_sample
                and isinstance(self.sampler, TRTLLMSampler)):
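A tiny standalone sketch of the sizing rule above: the number of micro batch slots equals the pipeline size, but never less than MIN_ASYNC_MICRO_BATCH_NUM, so the loop rarely has to block waiting for a free slot.

MIN_ASYNC_MICRO_BATCH_NUM = 1024

for pp_size in (2, 4, 2048):
    num_micro_batches = max(pp_size, MIN_ASYNC_MICRO_BATCH_NUM)
    print(pp_size, "->", num_micro_batches)   # 2 -> 1024, 4 -> 1024, 2048 -> 2048
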
@@ -479,6 +493,13 @@ class PyExecutor:

        if self.dist.pp_size > 1:
            self.event_loop = self._executor_loop_pp
            # `TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE` controls whether to broadcast the sample state asynchronously.
            # If true, the executor loop can broadcast and handle sample states asynchronously to achieve the best perf.
            # If false, the executor loop can only broadcast and handle each sample state in a pre-defined iteration.
            # Disabling it is only for debugging purposes; some tests disable it to get deterministic behavior.
            self.pp_async_broadcast_sample_state = os.environ.get(
                "TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE", "1") == "1"
        else:
            self.event_loop = self._executor_loop if self.disable_overlap_scheduler else self._executor_loop_overlap
        if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
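For reference, a standalone sketch of the flag-parsing convention used above (default on, exact string "1"):

import os

pp_async = os.environ.get("TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE", "1") == "1"
print(pp_async)   # True unless the variable is set to something other than "1"
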
@@ -560,7 +581,22 @@ class PyExecutor:

    def start_worker(self):
        with self.worker_lock:
            if self.worker_started == False:
            if not self.worker_started:
                if self.dist.pp_size > 1:
                    self.executed_batch_queue: Queue[BatchStatePP] = Queue(
                        maxsize=self.num_micro_batches)
                    self.executed_batch_response_queue: Queue[
                        BatchStatePP] = Queue(maxsize=-1)
                    broadcast_sample_state_loop = self._broadcast_sample_state_loop
                    if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
                        broadcast_sample_state_loop = trace_func(
                            broadcast_sample_state_loop)
                    self.broadcast_sample_state_handler = threading.Thread(
                        target=broadcast_sample_state_loop,
                        daemon=True,
                        name="broadcast_sample_state_handler",
                    )
                    self.broadcast_sample_state_handler.start()
                self.worker_thread = threading.Thread(
                    target=self._event_loop_wrapper, daemon=True)
                self.worker_thread.start()
@@ -658,6 +694,9 @@ class PyExecutor:
            logger.error("Hang detected, shutting down immediately.")
            return
        self.worker_thread.join()
        if self.dist.pp_size > 1:
            self.executed_batch_queue.put(None)
            self.broadcast_sample_state_handler.join()
        self.worker_started = False
        for manager in self.resource_manager.resource_managers.values():
            if manager:
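The start/shutdown pair above follows a standard queue-plus-daemon-thread pattern; below is a self-contained sketch (plain Python, toy integer payloads instead of BatchStatePP) of the same producer/consumer flow with a None sentinel for shutdown.

import queue
import threading

work_queue: "queue.Queue[int | None]" = queue.Queue(maxsize=8)
results: list[int] = []


def worker_loop():
    while True:
        item = work_queue.get()
        if item is None:          # sentinel pushed at shutdown
            break
        results.append(item * 2)  # stand-in for broadcasting a sample state


worker = threading.Thread(target=worker_loop, daemon=True, name="worker")
worker.start()
for i in range(4):
    work_queue.put(i)
work_queue.put(None)              # request shutdown, then wait for the thread
worker.join()
assert results == [0, 2, 4, 6]
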
@@ -1064,41 +1103,63 @@ class PyExecutor:

    def _executor_loop_cleanup(self):
        for h in self.send_handles:
            if h is not None:
                h.wait()
        for i in range(self.num_micro_batches):
            self.wait_on_pp_send_handles(self.send_handles, i)
            self.wait_on_pp_send_handles(self.send_schedule_handles, i)
            self.wait_on_pp_send_handles(self.send_expected_batch_num_handles,
                                         i)

        with self.response_cv:
            self.is_shutdown = True
            self.response_cv.notify_all()
        self.shutdown_event.set()

    def _pp_schedule_and_propagate(self):
    def _pp_schedule_and_propagate(self, microbatch_id: int):
        """The first PP rank schedules the requests and propagates the result to all other PP ranks."""

        # The first PP rank schedules the requests, other ranks receive the schedule result from the previous PP rank.
        if self.dist.is_first_pp_rank:
        # For TP/CP cases, the first rank schedules the requests.
        # For DP cases, the first PP rank schedules the requests.
        scheduled_batch = None
        serializable_schedule = None
        is_dp_broadcast = self.dist.tp_size > 1 and self.enable_attention_dp
        if self.dist.rank == 0 or (self.dist.is_first_pp_rank
                                   and is_dp_broadcast):
            scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
            )
            serializable_schedule = SerializableSchedulerOutput.from_scheduler_result(
                scheduled_batch, fitting_disagg_gen_init_requests,
                num_fitting_reqs)
        else:

        # Broadcast within first tp+cp group before send/recv chain to other tp+cp groups
        if self.dist.is_first_pp_rank:
            if self.dist.tp_size > 1 and not self.enable_attention_dp:
                with nvtx_range("tp_broadcast_schedule"):
                    serializable_schedule = self.dist.tp_broadcast(
                        serializable_schedule, root=0)
            if self.dist.cp_size > 1:
                with nvtx_range("cp_broadcast_schedule"):
                    serializable_schedule = self.dist.cp_broadcast(
                        serializable_schedule, root=0)

        # Other ranks receive the schedule result from the previous PP rank.
        if not self.dist.is_first_pp_rank:
            with nvtx_range("recv_schedule_from_prev_pp"):
                serializable_schedule = self.dist.recv_object(
                    self.dist.prev_pp_rank, PP_COMM_TAG_SCHEDULE_RESULT)
            scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result(
                self.active_requests)
                    self.dist.prev_pp_rank, PPCommTag.SCHEDULE_RESULT)

        # Propagate the schedule result to the next PP rank except the last PP rank.
        if not self.dist.is_last_pp_rank:
            if self.send_schedule_handler is not None:
                with nvtx_range("wait_send_schedule_handler"):
                    self.send_schedule_handler.wait()
            self.wait_on_pp_send_handles(self.send_schedule_handles,
                                         microbatch_id)
            with nvtx_range("send_schedule_to_next_pp"):
                self.send_schedule_handler = self.dist.isend_object(
                    serializable_schedule, self.dist.next_pp_rank,
                    PP_COMM_TAG_SCHEDULE_RESULT)
                self.send_schedule_handles[
                    microbatch_id] = self.dist.isend_object(
                        serializable_schedule, self.dist.next_pp_rank,
                        PPCommTag.SCHEDULE_RESULT)

        if scheduled_batch is None:
            scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result(
                self.active_requests)
        return scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs

    def _pp_retry_until_can_schedule(self, scheduled_batch):
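A standalone sketch (illustrative only, with a hypothetical helper named schedule_source) of the propagation order implemented by _pp_schedule_and_propagate: the schedule originates on the first PP rank, is broadcast within that stage's TP/CP group, and is then forwarded hop by hop to later PP stages.

def schedule_source(pp_rank: int) -> str:
    # hypothetical helper, for illustration only
    if pp_rank == 0:
        return "local _schedule() (plus tp/cp broadcast within the first stage)"
    return f"recv_object from pp_rank {pp_rank - 1} with tag PPCommTag.SCHEDULE_RESULT"


for pp_rank in range(4):
    print(pp_rank, "<-", schedule_source(pp_rank))
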
@@ -1177,8 +1238,8 @@ class PyExecutor:

                # Stage 0: first PP rank schedules requests and propagates the result to all other PP ranks.
                scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._pp_schedule_and_propagate(
                )
                if not self.dist.is_first_pp_rank:
                    microbatch_id)
                if self.dist.rank != 0:
                    # Retry until current rank can run first PP's schedule result.
                    self._pp_retry_until_can_schedule(scheduled_batch)
                # Run scheduler locally because scheduler may change llm requests' state.
@@ -1236,7 +1297,7 @@ class PyExecutor:
                    # Return the first token to the client
                    self._handle_first_token_response(scheduled_batch)

                # Stage 1: Async forward (all ranks) and decoding pass (last rank only)
                # Stage 1.1: Async forward (all ranks) and decoding pass (last rank only)
                if not self.dist.is_last_pp_rank:
                    with torch.cuda.nvtx.range(
                            f"_forward_step_inter_pp pp_rank {self.dist.pp_rank}"
@@ -1269,9 +1330,9 @@ class PyExecutor:
                                name: tensor.clone()
                                for name, tensor in batch_outputs.items()
                            }
                            self.start_sample_event.record()
                            self.sample_stream.wait_stream(
                                torch.cuda.current_stream())
                            with torch.cuda.stream(self.sample_stream):
                                self.start_sample_event.wait()
                                sample_state = self._sample_async(
                                    scheduled_batch, batch_outputs_copy)
                                self.finish_sample_event.record()
@@ -1302,103 +1363,222 @@ class PyExecutor:

                self.micro_batches[microbatch_id] = batch_state

                # sync sampler for previous microbatch to start new sample state comm chain.
                prev_microbatch_id = (microbatch_id -
                                      1) % self.num_micro_batches
                previous_batch = self.micro_batches[prev_microbatch_id]
                if previous_batch is not None:
                # Stage 1.2: Sync sampler for previous microbatch to start new sample state comm chain.
                # For last PP rank, we must synchronize the previous batch
                # since we need to broadcast its sample state soon afterwards in the same iteration.
                # For other PP ranks, we can delay the synchronization if the current batch cannot be queued.
                previous_batch = self.previous_batch
                if can_queue:
                    self.previous_batch = batch_state
                if (self.dist.is_last_pp_rank
                        or can_queue) and previous_batch is not None:
                    with nvtx_range("sync_previous_sampler_event"):
                        previous_batch.sample_state.sampler_event.synchronize()

                # Stage 2: Communicate sample state for previous batch between ranks
                # Stage 2: Enqueue sample state for executed batch to ring broadcast it in background thread asynchronously.
                # send/recv chain: (pp_size - 1) -> 0 -> 1 -> ... -> (pp_size - 2)
                # intermediate ranks: send/recv sample state for next microbatch to allow overlap
                offset = -1 if self.dist.is_last_pp_rank else 1
                prev_microbatch_id = (microbatch_id +
                                      offset) % self.num_micro_batches
                previous_batch = self.micro_batches[prev_microbatch_id]
                tag = PP_COMM_TAG_SAMPLE_STATE_BASE + prev_microbatch_id
                if previous_batch is not None:
                    sample_state = previous_batch.sample_state
                    if not self.dist.is_last_pp_rank:
                        # Receive tokens from previous pp rank (w.r.t model forward direction)
                        with nvtx_range("recv_sample_state"):
                            sample_state.host = self.dist.recv_object(
                                src=self.dist.prev_pp_rank,
                                tag=tag,
                offset = -1 if self.dist.is_last_pp_rank else (
                    1 - self.dist.pp_size)
                executed_microbatch_id = (microbatch_id +
                                          offset) % self.num_micro_batches
                executed_batch = self.micro_batches[executed_microbatch_id]
                if executed_batch is not None:
                    self.executed_batch_queue.put(executed_batch)
                    self.unhandled_batch_counter += 1
                    self.micro_batches[executed_microbatch_id] = None

                def fetch_executed_batches() -> list[BatchStatePP]:
                    executed_batches = []
                    if self.pp_async_broadcast_sample_state:
                        # Wait for at least one batch to finish if no new request is available.
                        must_get = not can_queue
                    else:
                        must_get = True
                    while not self.executed_batch_response_queue.empty() or (
                            must_get and self.unhandled_batch_counter > 0):
                        with nvtx_range("get_executed_batch"):
                            executed_batches.append(
                                self.executed_batch_response_queue.get())
                        must_get = False
                    return executed_batches

                def ring_broadcast_executed_batch_num(
                        executed_batch_num: int) -> int:
                    if self.dist.is_first_pp_rank and self.dist.tp_size * self.dist.cp_size > 1:
                        with nvtx_range("tp_cp_broadcast_executed_batch_num"):
                            executed_batch_num = self.dist.tp_cp_broadcast(
                                executed_batch_num,
                                root=0,
                            )

                    # Send tokens to next pp rank (w.r.t model forward direction)
                    # Second last rank does not need to since last rank has original decoded tokens
                    if not self.dist.is_second_last_pp_rank:
                        self.wait_on_pp_send_handles(prev_microbatch_id)
                        with nvtx_range("send_sample_state"):
                            self.send_handles[
                                prev_microbatch_id] = self.dist.isend_object(
                                    sample_state.host,
                    if not self.dist.is_first_pp_rank:
                        with nvtx_range("recv_expected_batch_num"):
                            executed_batch_num = self.dist.recv_object(
                                src=self.dist.prev_pp_rank,
                                tag=PPCommTag.EXECUTED_BATCH_NUM,
                            )
                    if not self.dist.is_last_pp_rank:
                        self.wait_on_pp_send_handles(
                            self.send_expected_batch_num_handles, microbatch_id)
                        with nvtx_range("send_expected_batch_num"):
                            self.send_expected_batch_num_handles[
                                microbatch_id] = self.dist.isend_object(
                                    executed_batch_num,
                                    dest=self.dist.next_pp_rank,
                                    tag=tag)
                                    tag=PPCommTag.EXECUTED_BATCH_NUM,
                                )
                    return executed_batch_num

                # Stage 3: Finalize previous batch that finished sample state communication
                # In last pp rank, stage 2 and 3 process different previous batches
                prev_microbatch_id = (microbatch_id +
                                      1) % self.num_micro_batches
                previous_batch = self.micro_batches[prev_microbatch_id]
                finished_requests = []
                if previous_batch is not None:
                    with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
                        sample_state = previous_batch.sample_state
                        sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs
                        self._update_requests(previous_batch.sample_state)
                def handle_executed_batches(executed_batch_num: int):
                    if self.dist.rank != 0:
                        dequeue_counter = 0
                        while dequeue_counter < executed_batch_num:
                            with nvtx_range("get_executed_batch"):
                                executed_batch = self.executed_batch_response_queue.get(
                                )
                            self._handle_executed_batch(executed_batch)
                            dequeue_counter += 1
                    else:
                        for executed_batch in executed_batches:
                            self._handle_executed_batch(executed_batch)
                    self.unhandled_batch_counter -= executed_batch_num

                        if self.kv_cache_transceiver:
                            self._send_kv_async(
                                previous_batch.finished_ctx_reqs)
                            self._handle_canceled_requests()
                executed_batch_num = 0

                        self._handle_logits_communication(
                            previous_batch, prev_microbatch_id)
                # Stage 3.1: The first rank determines the number of executed batches.
                if self.dist.rank == 0:
                    executed_batches = fetch_executed_batches()
                    executed_batch_num = len(executed_batches)

                        finished_requests = self._handle_responses()
                        previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
                        attn_metadata = getattr(self.model_engine,
                                                'attn_metadata', None)
                        kv_cache_dtype_byte_size = getattr(
                            self.model_engine, 'kv_cache_dtype_byte_size', None)
                        self.resource_manager.update_resources(
                            previous_scheduled_batch, attn_metadata,
                            kv_cache_dtype_byte_size)
                # Stage 3.2: Broadcast the number of executed batches to other ranks.
                executed_batch_num = ring_broadcast_executed_batch_num(
                    executed_batch_num)

                        self._remove_inflight_ids(previous_batch)
                # Stage 3.3: Handle executed batches.
                handle_executed_batches(executed_batch_num)

                        self.wait_on_pp_send_handles(prev_microbatch_id)
                        self.micro_batches[prev_microbatch_id] = None

                if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
                ):
                    self._check_kv_transfer_timeout()

                if self._disagg_pp_termination_handler is not None:
                    self._disagg_pp_termination_handler.terminate_pending_requests(
                    )

                # march forward in microbatch slots
                # Stage 4: March forward in microbatch slots
                microbatch_id = (microbatch_id + 1) % self.num_micro_batches

                if self.enable_iter_perf_stats and previous_batch is not None:
                    sample_state = previous_batch.sample_state
                    sample_state.scheduled_requests.context_requests = previous_batch.scheduled_ctx_reqs
                    self._process_iter_stats(finished_requests,
                                             self.active_requests,
                                             previous_batch, microbatch_id)

                self.iter_counter += 1

            # Stage 5: Handle remaining executed batches in the queue.
            while self.unhandled_batch_counter > 0:
                with nvtx_range("get_executed_batch"):
                    executed_batch = self.executed_batch_response_queue.get()
                self._handle_executed_batch(executed_batch)
                self.unhandled_batch_counter -= 1

    def _broadcast_sample_state_loop(self):
        logger.debug(
            f"Starting broadcast sample state loop for pp_rank {self.dist.pp_rank}"
        )
        torch.cuda.set_device(self.device_id)
        # Ensure the CUDA context is created; otherwise some MPI calls will fail.
        CUASSERT(cudart.cudaSetDevice(self.device_id))
        # Acquiring pkl5.Intracomm's send/recv locks from both the executor loop thread
        # and this thread would cause a perf drop and even deadlock.
        # We create a new MPI comm to avoid these issues.
        logger.info(
            "Create new MPI comm for broadcast sample state thread to avoid deadlock."
        )
        new_mpi_comm = mpi_comm().Dup()
        set_thread_local_mpi_comm(new_mpi_comm)
        while True:
            executed_batch = self.executed_batch_queue.get()
            if executed_batch is None:
                break
            self._ring_broadcast_sample_state(executed_batch)
        set_thread_local_mpi_comm(None)
        new_mpi_comm.Free()

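A minimal sketch of the per-thread communicator duplication used above, assuming mpi4py is installed and the script is launched under an MPI runtime with multi-thread support; this is not TensorRT-LLM code, just the underlying idiom.

import threading

from mpi4py import MPI


def background(comm: MPI.Comm):
    # ... use comm.send / comm.recv here, independently of COMM_WORLD ...
    # (requires an MPI build providing MPI_THREAD_MULTIPLE for real traffic)
    comm.Free()


thread_comm = MPI.COMM_WORLD.Dup()     # one duplicate communicator per extra thread
t = threading.Thread(target=background, args=(thread_comm,), daemon=True)
t.start()
t.join()
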
    def _ring_broadcast_sample_state(
        self,
        executed_batch: Optional[BatchStatePP],
    ) -> None:
        if executed_batch is None:
            return

        tag = PPCommTag.SAMPLE_STATE
        microbatch_id = executed_batch.microbatch_id
        sample_state = executed_batch.sample_state
        requests = sample_state.scheduled_requests.all_requests()

        if not self.dist.is_last_pp_rank:
            # Receive tokens from the previous pp rank (w.r.t. model forward direction)
            with nvtx_range("recv_sample_state"):
                sample_state.host, py_result_diffs = self.dist.recv_object(
                    src=self.dist.prev_pp_rank,
                    tag=tag,
                )

            for request, py_result_diff in zip(requests, py_result_diffs):
                request.py_result.apply_diff(py_result_diff)

        self.executed_batch_response_queue.put(executed_batch)

        # Send tokens to the next pp rank (w.r.t. model forward direction)
        # The second-to-last rank does not need to, since the last rank has the original decoded tokens
        if not self.dist.is_second_last_pp_rank:
            py_result_diffs = []
            for request in requests:
                diff = request.py_result.get_diff()
                py_result_diffs.append(diff)
                request.py_result.reset_diff()
            self.wait_on_pp_send_handles(self.send_handles, microbatch_id)
            with nvtx_range("send_sample_state"):
                self.send_handles[microbatch_id] = self.dist.isend_object(
                    (sample_state.host, py_result_diffs),
                    dest=self.dist.next_pp_rank,
                    tag=tag,
                )

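A standalone sketch of the ring order used by the sample-state chain above, i.e. (pp_size - 1) -> 0 -> 1 -> ... -> (pp_size - 2):

def ring_order(pp_size: int) -> list[int]:
    # the last PP rank starts the chain, the second-to-last rank ends it
    return [pp_size - 1] + list(range(pp_size - 1))


print(ring_order(4))   # [3, 0, 1, 2]
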
    def _handle_executed_batch(self, executed_batch: Optional[BatchStatePP]):
        finished_requests = []
        if executed_batch is not None:
            with torch.cuda.nvtx.range("_handle_executed_batch_pp"):
                sample_state = executed_batch.sample_state
                sample_state.scheduled_requests.context_requests = executed_batch.finished_ctx_reqs
                self._update_requests(executed_batch.sample_state)

                if self.kv_cache_transceiver:
                    self._send_kv_async(executed_batch.finished_ctx_reqs)
                    self._handle_canceled_requests()

                finished_requests = self._handle_responses()
                previous_scheduled_batch = executed_batch.sample_state.scheduled_requests
                attn_metadata = getattr(self.model_engine, 'attn_metadata',
                                        None)
                kv_cache_dtype_byte_size = getattr(self.model_engine,
                                                   'kv_cache_dtype_byte_size',
                                                   None)
                self.resource_manager.update_resources(
                    previous_scheduled_batch, attn_metadata,
                    kv_cache_dtype_byte_size)

                self._remove_inflight_ids(executed_batch)

        if self.kv_cache_transceiver and self.async_transfer_manager.has_any_inflight_requests(
        ):
            self._check_kv_transfer_timeout()

        if self._disagg_pp_termination_handler is not None:
            self._disagg_pp_termination_handler.terminate_pending_requests()

        if self.enable_iter_perf_stats and executed_batch is not None:
            sample_state = executed_batch.sample_state
            sample_state.scheduled_requests.context_requests = executed_batch.scheduled_ctx_reqs
            self._process_iter_stats(
                finished_requests,
                self.active_requests,
                executed_batch,
                executed_batch.microbatch_id % self.dist.pp_size,
            )

    @nvtx_range("wait_on_pp_send_handles")
    def wait_on_pp_send_handles(self, microbatch_id):
        if self.send_handles[microbatch_id] is not None:
            self.send_handles[microbatch_id].wait()
            self.send_handles[microbatch_id] = None
    def wait_on_pp_send_handles(self, send_handles, microbatch_id):
        if send_handles[microbatch_id] is not None:
            send_handles[microbatch_id].wait()
            send_handles[microbatch_id] = None

    def _can_queue(self, scheduled_batch):

@@ -3022,41 +3202,6 @@ class PyExecutor:
                self._terminate_request(request)
        return requests_to_terminate

    def _handle_logits_communication(self, previous_batch, prev_microbatch_id):
        """Handle logits communication between pipeline parallel ranks.

        If logits were requested, the last PP rank sends to the first PP rank (which sends responses)
        the logits of the requests that have finished.

        Args:
            previous_batch: The previous batch state
            prev_microbatch_id: The microbatch ID for the previous batch
        """
        # NOTE: If the rank processing the logits ever becomes the same as
        # the rank sending the responses, this code can be removed.
        finished_reqs = [
            r for r in
            previous_batch.sample_state.scheduled_requests.all_requests()
            if r.state == LlmRequestState.GENERATION_COMPLETE and (
                r.py_return_context_logits or r.py_return_generation_logits
                or r.py_additional_outputs is not None)
        ]
        if self.dist.is_first_pp_rank and len(finished_reqs):
            finished_reqs_py_results = [r.py_result for r in finished_reqs]
            finished_reqs_py_results = self.dist.recv_object(
                src=self.dist.prev_pp_rank,
                tag=prev_microbatch_id,
            )
            for req, py_result in zip(finished_reqs, finished_reqs_py_results):
                req.py_result = py_result

        elif self.dist.is_last_pp_rank and len(finished_reqs):
            self.wait_on_pp_send_handles(prev_microbatch_id)
            self.send_handles[prev_microbatch_id] = self.dist.isend_object(
                [r.py_result for r in finished_reqs],
                dest=self.dist.next_pp_rank,
                tag=prev_microbatch_id)

    def _await_any_response(self,
                            timeout: Optional[float] = None
                            ) -> List[LlmResponse]:
@@ -3198,7 +3343,7 @@ class DisaggPPTerminationHandler:
        self._pending_termination = {}
        self._terminating_iteration = 0
        self._send_handle = None
        self._comm_tag = TERMINATION_COMM_TAG_BASE
        self._comm_tag = PPCommTag.TERMINATION

    def terminate(self, request: LlmRequest):
        self._pending_termination[request.py_request_id] = request

@@ -673,7 +673,8 @@ class RequestBroadcaster:
        # Broadcast within first PP stage before send/recv chain to other PP stages.
        # This needs to cover both TP and CP ranks within the first PP stage.
        if self.dist.is_first_pp_rank:
            payloads = self.dist.tp_cp_broadcast(payloads, root=0)
            with nvtx_range("tp_broadcast_requests"):
                payloads = self.dist.tp_cp_broadcast(payloads, root=0)

        # Tag for communication
        tag = self.dist.pp_size  # Use pp_size as tag to avoid conflicts

@@ -8,7 +8,8 @@ from typing import Dict, List
import torch
from torch.nn import functional as F

from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
from tensorrt_llm._utils import (TensorWrapper, convert_to_torch_tensor,
                                 torch_dtype_to_str)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.math_utils import ceil_div, pad_up
from tensorrt_llm.quantization.utils import fp4_utils
@@ -416,6 +417,29 @@ def relu2(x: torch.Tensor) -> torch.Tensor:
    return torch.square(F.relu(x))


def tensor_to_str(x: torch.Tensor, num_elements: int = 10) -> str:
    # Passing num_elements=-1 will print the whole tensor
    if num_elements < 0:
        num_elements = torch.numel(x)
    if x.dtype in (torch.int32, torch.int64):
        float_x = x.to(dtype=float)
    else:
        float_x = x
    return ("Tensor("
            f"shape={tuple(x.shape)}, "
            f"dtype={torch_dtype_to_str(x.dtype)}, "
            f"device={x.device}, "
            f"stats=("
            f"abs_mean={float_x.abs().mean().item():.3f}, "
            f"mean={float_x.mean().item():.3f}, "
            f"std={float_x.std().item():.3f}, "
            f"max={x.max().item():.3f}, "
            f"min={x.min().item():.3f}"
            "), "
            f"values={x.flatten()[:num_elements].tolist()}"
            ")")


@maybe_compile
def maybe_compiled_copy_(dst, src):
    dst.copy_(src)

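A usage sketch for tensor_to_str, assuming the helper above is importable from this module and torch is available; it is handy for logging a compact summary instead of dumping a full tensor.

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
print(tensor_to_str(x, num_elements=4))
# roughly: Tensor(shape=(3, 4), dtype=float32, device=cpu, stats=(...), values=[0.0, 1.0, 2.0, 3.0])
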
@@ -23,6 +23,7 @@ import socket
import struct
import sys
import tempfile
import threading
import trace
import traceback
import weakref
@@ -509,7 +510,17 @@ def set_mpi_comm(new_comm):
    comm = new_comm


thread_local_comm = threading.local()


def set_thread_local_mpi_comm(new_comm):
    thread_local_comm.value = new_comm


def mpi_comm():
    if hasattr(thread_local_comm,
               "value") and thread_local_comm.value is not None:
        return thread_local_comm.value
    return comm

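A self-contained sketch (plain strings instead of MPI communicators) of the thread-local override pattern introduced above: a per-thread value, when set, shadows the process-wide default returned by the accessor.

import threading

_default = "global-comm"
_local = threading.local()


def set_thread_local(value):
    _local.value = value


def current():
    if getattr(_local, "value", None) is not None:
        return _local.value
    return _default


def worker():
    set_thread_local("per-thread-comm")
    assert current() == "per-thread-comm"


t = threading.Thread(target=worker)
t.start()
t.join()
assert current() == "global-comm"   # the main thread still sees the default
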
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Literal, Optional, Tuple

import yaml
from mpi4py.MPI import COMM_WORLD, Comm
from mpi4py.util import pkl5

from .._utils import global_mpi_rank, global_mpi_size

@@ -327,7 +328,7 @@ def split_world_comm(
        f"global_rank: {global_rank}, instance_idx: {instance_idx}, sub_rank: {sub_rank}, is_leader: {is_leader}"
    )

    return is_leader, instance_idx, sub_comm
    return is_leader, instance_idx, pkl5.Intracomm(sub_comm)


def parse_metadata_server_config_file(

@@ -170,8 +170,14 @@ class MpiPoolSession(MpiSession):
    def _start_mpi_pool(self):
        assert not self.mpi_pool, 'MPI session already started'

        env = {
            key: value
            for key, value in os.environ.items()
            if key.startswith("TRTLLM") or key.startswith("TLLM")
        }
        self.mpi_pool = MPIPoolExecutor(max_workers=self.n_workers,
                                        path=sys.path)
                                        path=sys.path,
                                        env=env)

    def __del__(self):
        self.shutdown_abort()

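A standalone sketch of the environment filtering above: only variables with the TRTLLM/TLLM prefixes are forwarded to the MPI pool workers.

import os

env = {
    key: value
    for key, value in os.environ.items()
    if key.startswith("TRTLLM") or key.startswith("TLLM")
}
print(sorted(env))   # e.g. ['TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE', ...] if set
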
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import datetime
import gc
import json
@@ -57,7 +58,8 @@ from llmapi.lora_test_utils import (
    check_llama_7b_multi_lora_from_request_test_harness,
    check_llama_7b_multi_unique_lora_adapters_from_request)
from utils.llm_data import llm_models_root
from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_pre_hopper, skip_single_gpu
from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb, skip_pre_hopper, skip_single_gpu, altered_env

# isort: on

# The unittests are based on the tiny-llama, which is fast to build and run.
@@ -2171,11 +2173,16 @@ def llm_get_stats_test_harness(tp_size: int = 1,
        llm_args_extra["fast_build"] = True
        LLM_CLASS = LLM

    with LLM_CLASS(model=llama_model_path,
                   kv_cache_config=global_kvcache_config,
                   tensor_parallel_size=tp_size,
                   pipeline_parallel_size=pp_size,
                   **llm_args_extra) as llm:
    # Since we need to check pp's internal states, we disable the async broadcast
    # to get a deterministic behavior.
    env_ctx = altered_env(TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE="0") \
        if pp_size > 1 else contextlib.nullcontext()

    with env_ctx, LLM_CLASS(model=llama_model_path,
                            kv_cache_config=global_kvcache_config,
                            tensor_parallel_size=tp_size,
                            pipeline_parallel_size=pp_size,
                            **llm_args_extra) as llm:

        max_tokens = 5
        sampling_params = SamplingParams(max_tokens=max_tokens,
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import faulthandler
import math
import os
@@ -391,6 +392,23 @@ def run_session(session: Session,
    return outputs


@contextlib.contextmanager
def altered_env(**kwargs):
    old = {}
    for k, v in kwargs.items():
        if k in os.environ:
            old[k] = os.environ[k]
        os.environ[k] = v
    try:
        yield
    finally:
        for k in kwargs:
            if k not in old:
                os.environ.pop(k)
            else:
                os.environ[k] = old[k]


def similarity_score(a, b):
    "similar compare a and b "
    return SequenceMatcher(None, a, b).ratio()

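A usage sketch for altered_env, assuming it is imported from utils.util as in the test changes above: it temporarily forces the deterministic, non-async PP broadcast mode for the duration of a test body.

import os

with altered_env(TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE="0"):
    assert os.environ["TLLM_PP_ASYNC_BROADCAST_SAMPLE_STATE"] == "0"
# On exit the variable is restored to its previous value, or removed entirely
# if it was not set before entering the context.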