Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[TRTLLM-10666][chore] Refactor request fetching logic for better separation of concerns (#10988)
Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
Signed-off-by: Lance Liao <108499334+lancelly@users.noreply.github.com>
Signed-off-by: Liao Lanyu <108499334+lancelly@users.noreply.github.com>
Co-authored-by: Lanyu Liao <lancelly@users.noreply.github.com>
Parent: b00e8338ec
Commit: fef0e4b17d
@@ -1,24 +1,16 @@
import dataclasses
import datetime
import heapq
import os
import queue
import threading
import time
from collections import deque, namedtuple
from itertools import repeat
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple

import torch

from tensorrt_llm._utils import mpi_disabled, nvtx_range
from tensorrt_llm.llmapi.disagg_utils import get_local_request_id
from tensorrt_llm.mapping import CpType

from ..distributed import Distributed
from .hang_detector import HangDetector
from .llm_request import (ExecutorRequest, LlmRequest,
                          executor_request_to_llm_request)
from .llm_request import ExecutorRequest
from .request_utils import get_num_child_requests

SHUTDOWN_REQUEST_ID = -1
CONTROL_REQUEST_ID = -2
@@ -48,165 +40,24 @@ class RequestQueueItem:


class ExecutorRequestQueue:
    """Handles fetching and processing of new requests from the request queue."""
    """Handles basic queue operations for executor requests."""

    def __init__(
        self,
        dist: Distributed,
        enable_attention_dp: bool,
        max_batch_size: int,
        max_beam_width: int,
        max_num_active_requests: int,
        enable_iter_perf_stats: bool,
        batch_wait_timeout_ms: float,
        hang_detector: Optional[HangDetector] = None,
    ):
        self.dist = dist
        self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
        self.waiting_queue: deque[RequestQueueItem] = deque()
        self.canceled_req_ids = []
        self.enable_attention_dp = enable_attention_dp
        self.max_batch_size = max_batch_size
        self.max_beam_width = max_beam_width
        self.max_num_active_requests = max_num_active_requests
        self.enqueue_lock = threading.Lock()
        self.next_request_id = max_batch_size
        self.enable_iter_perf_stats = enable_iter_perf_stats
        self.start_times = {}
        self.active = True
        self.batch_wait_timeout_ms = batch_wait_timeout_ms
        self.send_requests_handler = None
        self.hang_detector = hang_detector or HangDetector()

        # State tracking
        self.num_fetch_requests = 0
        self.num_fetch_requests_cur_rank = 0
        self.expected_num_active_requests = 0
        self.new_active_requests_queue_latency_ms = 0
        self.is_shutdown = False
        self.should_exclude_last_generation_logits = False
        self.control_requests: List[RequestQueueItem] = []
        self.request_accumulated: List[RequestQueueItem] = []

        self._disable_mpi = mpi_disabled()
|
||||
|
||||
def _get_from_request_queue(
|
||||
self,
|
||||
timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]:
|
||||
|
||||
items = []
|
||||
timeout_secs = timeout.total_seconds() if timeout is not None else None
|
||||
|
||||
try:
|
||||
if self.request_queue.empty() and (timeout_secs is None
|
||||
or timeout_secs > 0):
|
||||
# if queue is empty and want to wait, wait
|
||||
items.append(self.request_queue.get(timeout=timeout_secs))
|
||||
else:
|
||||
# if not empty or don't want to wait, just return all items in queue
|
||||
while True:
|
||||
queue_item = self.request_queue.get_nowait()
|
||||
items.append(queue_item)
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
if self.batch_wait_timeout_ms == 0:
|
||||
return items
|
||||
|
||||
if len(items) >= self.max_batch_size:
|
||||
return items
|
||||
|
||||
deadline = time.monotonic() + self.batch_wait_timeout_ms / 1000.0
|
||||
while len(items) < self.max_batch_size:
|
||||
remaining_timeout = deadline - time.monotonic()
|
||||
|
||||
if remaining_timeout <= 0:
|
||||
break
|
||||
|
||||
try:
|
||||
item = self.request_queue.get(timeout=remaining_timeout)
|
||||
items.append(item)
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def _get_num_child_requests(request: ExecutorRequest) -> int:
|
||||
sampling_config = request.sampling_config
|
||||
return 0 if sampling_config.beam_width > 1 else (
|
||||
sampling_config.num_return_sequences or 1) - 1
|
||||
|
||||
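For context, a minimal sketch of the child-request rule above, using a stand-in dataclass rather than the real sampling-config binding (attribute names are taken from the code shown here, everything else is hypothetical):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeSamplingConfig:  # hypothetical stand-in for request.sampling_config
    beam_width: int = 1
    num_return_sequences: Optional[int] = None

def num_child_requests(cfg: FakeSamplingConfig) -> int:
    # Beam search returns all beams under the parent request, so no children;
    # otherwise every extra returned sequence is materialized as a child request.
    return 0 if cfg.beam_width > 1 else (cfg.num_return_sequences or 1) - 1

assert num_child_requests(FakeSamplingConfig(beam_width=4)) == 0
assert num_child_requests(FakeSamplingConfig(num_return_sequences=3)) == 2

The count matters for queue accounting: popping a parent with two children consumes three slots of max_req_count in the waiting-queue logic below.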
def _get_from_waiting_queue(
|
||||
self,
|
||||
waiting_queue: deque[RequestQueueItem],
|
||||
max_req_count: int,
|
||||
enable_attention_dp: bool,
|
||||
all_ranks_num_active_requests: Optional[List[int]] = None,
|
||||
) -> List[RequestQueueItem]:
|
||||
"""
|
||||
Args:
|
||||
waiting_queue: The queue to pop items from.
|
||||
max_req_count: Maximum items to retrieve. Returns empty list if <=0.
|
||||
enable_attention_dp: Whether to enable attention DP scheduling.
|
||||
all_ranks_num_active_requests: Number of active requests for each rank.
|
||||
Returns:
|
||||
List of requests that can be processed.
|
||||
"""
|
||||
|
||||
if max_req_count <= 0:
|
||||
return []
|
||||
|
||||
req_count = 0
|
||||
items = []
|
||||
pending_requests = []
|
||||
|
||||
# Track the request with strict requirements
|
||||
scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
|
||||
) if enable_attention_dp else None
|
||||
while req_count < max_req_count and waiting_queue:
|
||||
req_item = waiting_queue[0]
|
||||
num_children = len(
|
||||
req_item.child_req_ids) if req_item.child_req_ids else 0
|
||||
if (req_count + 1 + num_children) > max_req_count:
|
||||
break
|
||||
req_item = waiting_queue.popleft()
|
||||
can_process = self._can_process_attention_dp_request(
|
||||
req_item, scheduling_all_ranks_num_active_requests
|
||||
) if enable_attention_dp else True
|
||||
|
||||
if can_process:
|
||||
items.append(req_item)
|
||||
req_count += 1 + num_children
|
||||
else:
|
||||
pending_requests.append(req_item)
|
||||
|
||||
# Put the pending requests back to the waiting queue
|
||||
# All ranks should have the same waiting queue
|
||||
waiting_queue.extendleft(reversed(pending_requests))
|
||||
|
||||
return items
|
||||
|
||||
def _can_process_attention_dp_request(
|
||||
self, req_item: RequestQueueItem,
|
||||
all_ranks_num_active_requests: List[int]) -> bool:
|
||||
"""Return True if the request can be processed immediately, else False."""
|
||||
|
||||
scheduling_params = getattr(req_item.request, 'py_scheduling_params',
|
||||
None)
|
||||
if scheduling_params is None:
|
||||
return True
|
||||
|
||||
target_dp_rank = scheduling_params.attention_dp_rank
|
||||
if target_dp_rank is None or scheduling_params.attention_dp_relax:
|
||||
return True
|
||||
|
||||
if all_ranks_num_active_requests[
|
||||
target_dp_rank] < self.max_num_active_requests:
|
||||
all_ranks_num_active_requests[target_dp_rank] += 1
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
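To make the admission rule concrete, here is a condensed, self-contained sketch under the same assumptions (a pinned attention-DP rank plus a relax flag); names are simplified and the in-place slot reservation mirrors the increment above:

from typing import List, Optional

def can_admit(target_rank: Optional[int], relax: bool,
              active_per_rank: List[int], max_active: int) -> bool:
    # Requests without a pinned rank, or in relax mode, are always admissible;
    # pinned requests only pass if their target DP rank still has capacity.
    if target_rank is None or relax:
        return True
    if active_per_rank[target_rank] < max_active:
        active_per_rank[target_rank] += 1  # reserve the slot for this request
        return True
    return False

active = [2, 3]                          # active requests on DP ranks 0 and 1
print(can_admit(1, False, active, 3))    # True: rank 1 still has room
print(can_admit(1, False, active, 3))    # False: rank 1 is now full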
def _get_request_id(self, request: Optional[ExecutorRequest] = None):
|
||||
# if request has a disagg_request_id, use it as request id so that
|
||||
@ -223,7 +74,7 @@ class ExecutorRequestQueue:
|
||||
self, request: ExecutorRequest) -> List[int] | None:
|
||||
""" Generate child request IDs if needed. """
|
||||
child_req_ids = None
|
||||
num_children = self._get_num_child_requests(request)
|
||||
num_children = get_num_child_requests(request)
|
||||
if num_children > 0:
|
||||
child_req_ids = []
|
||||
for _ in range(num_children):
|
||||
@ -288,602 +139,53 @@ class ExecutorRequestQueue:
|
||||
with self.enqueue_lock:
|
||||
return self.active and self.dist.rank == 0
|
||||
|
||||
def _fetch_and_process_requests(
|
||||
self,
|
||||
total_num_active_requests: int,
|
||||
total_max_num_active_requests: int,
|
||||
enable_attention_dp: bool,
|
||||
all_ranks_num_active_requests: Optional[List[int]] = None
|
||||
) -> List[RequestQueueItem]:
|
||||
"""Common logic for fetching and processing requests from the queue."""
|
||||
# Block new request processing while control requests are pending.
|
||||
# Control requests must be handled exclusively to ensure proper synchronization.
|
||||
if len(self.control_requests) != 0:
|
||||
return []
|
||||
# Calculate timeout
|
||||
idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0
|
||||
if idle:
|
||||
# In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0
|
||||
# reaches the broadcast path regularly to prevent trtllm-serve timeout when idle.
|
||||
timeout = datetime.timedelta(
|
||||
seconds=1200) if self._disable_mpi else None
|
||||
else:
|
||||
timeout = datetime.timedelta(0)
|
||||
|
||||
# Fetch requests from rank 0
|
||||
new_requests = []
|
||||
if self.dist.rank == 0:
|
||||
# Process accumulated requests that were queued during control request handling.
|
||||
if len(self.request_accumulated) != 0:
|
||||
new_requests.extend(self.request_accumulated)
|
||||
self.request_accumulated.clear()
|
||||
# Reset timeout to 0 to avoid hanging when no new requests are available
|
||||
timeout = datetime.timedelta(0)
|
||||
with self.hang_detector.pause():
|
||||
new_requests.extend(self._get_from_request_queue(timeout))
|
||||
|
||||
# Broadcast requests and handle Python objects
|
||||
new_requests, py_request_objects = self._handle_request_broadcasting(
|
||||
new_requests)
|
||||
|
||||
# Validate and filter requests
|
||||
new_requests = self._handle_special_queue_items(new_requests)
|
||||
|
||||
# Attach Python objects to requests
|
||||
if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp
|
||||
or self.dist.cp_size
|
||||
> 1) and self.dist.rank > 0:
|
||||
self._attach_py_objects_to_requests(new_requests,
|
||||
py_request_objects)
|
||||
self.waiting_queue.extend(new_requests)
|
||||
|
||||
new_requests = self._get_from_waiting_queue(
|
||||
self.waiting_queue,
|
||||
total_max_num_active_requests - total_num_active_requests,
|
||||
enable_attention_dp, all_ranks_num_active_requests)
|
||||
|
||||
# Update performance metrics
|
||||
if self.enable_iter_perf_stats and self.dist.rank == 0:
|
||||
self._update_new_active_requests_queue_latency(new_requests)
|
||||
|
||||
return new_requests
|
||||
|
||||
@nvtx_range("_fetch_new_requests")
|
||||
def fetch_new_requests(
|
||||
self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
|
||||
|
||||
if self.enable_attention_dp:
|
||||
return self._fetch_new_requests_attention_dp(activate_requests)
|
||||
else:
|
||||
return self._fetch_new_requests_attention_tp(len(activate_requests))
|
||||
|
||||
def _fetch_new_requests_attention_tp(
|
||||
self, num_active_requests: int) -> List[LlmRequest]:
|
||||
"""Handle standard (non-attention DP) request fetching."""
|
||||
total_num_active_requests = num_active_requests
|
||||
total_max_num_active_requests = self.max_num_active_requests
|
||||
|
||||
# fetch and process requests into waiting queue
|
||||
new_requests = self._fetch_and_process_requests(
|
||||
total_num_active_requests,
|
||||
total_max_num_active_requests,
|
||||
enable_attention_dp=False)
|
||||
|
||||
# Merge requests and add to active list
|
||||
merged_requests = self._merge_requests(new_requests)
|
||||
return merged_requests
|
||||
|
||||
def _fetch_new_requests_attention_dp(
|
||||
self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
|
||||
"""Handle attention DP request fetching with load balancing."""
|
||||
# Get active request counts across all ranks.
|
||||
all_ranks_num_active_requests = []
|
||||
all_ranks_num_active_tokens = []
|
||||
|
||||
if self.dist.has_cp_helix:
|
||||
num_active_tokens = sum(
|
||||
[req.total_input_len_cp for req in activate_requests])
|
||||
else:
|
||||
num_active_tokens = sum(
|
||||
[req.py_orig_prompt_len for req in activate_requests])
|
||||
|
||||
# Note: We use tp_allgather even for CP assuming that all CP ranks with the
|
||||
# same dp_rank have the same num_active_tokens and num_active_requests.
|
||||
responses_list = self.dist.tp_allgather(
|
||||
[len(activate_requests), num_active_tokens])
|
||||
|
||||
for num_active_requests, num_active_tokens in responses_list:
|
||||
all_ranks_num_active_requests.append(num_active_requests)
|
||||
all_ranks_num_active_tokens.append(num_active_tokens)
|
||||
|
||||
total_num_active_requests = sum(all_ranks_num_active_requests)
|
||||
total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
|
||||
|
||||
# fetch and process requests into waiting queue
|
||||
new_requests = self._fetch_and_process_requests(
|
||||
total_num_active_requests,
|
||||
total_max_num_active_requests,
|
||||
enable_attention_dp=True,
|
||||
all_ranks_num_active_requests=all_ranks_num_active_requests)
|
||||
|
||||
# Schedule attention dp requests
|
||||
all_ranks_new_requests = self._schedule_attention_dp_requests(
|
||||
new_requests, all_ranks_num_active_requests,
|
||||
all_ranks_num_active_tokens)
|
||||
new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]
|
||||
|
||||
# Update performance metrics
|
||||
if self.enable_iter_perf_stats and self.start_times:
|
||||
self._update_new_active_requests_queue_latency(
|
||||
new_requests_cur_rank)
|
||||
|
||||
# Update counters
|
||||
self.num_fetch_requests += len(new_requests)
|
||||
self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)
|
||||
|
||||
# Merge requests and add to active list
|
||||
new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
|
||||
return new_requests_cur_rank
|
||||
|
||||
def _schedule_attention_dp_requests(
|
||||
self, new_requests: List[RequestQueueItem],
|
||||
all_ranks_num_active_requests: List[int],
|
||||
all_ranks_num_active_tokens: List[int]) -> List[RequestQueueItem]:
|
||||
"""Schedule attention dp requests."""
|
||||
|
||||
# Map from ranks to new requests
|
||||
all_ranks_new_requests = {
|
||||
tp_rank: []
|
||||
for tp_rank in range(self.dist.tp_size)
|
||||
}
|
||||
|
||||
# Prioritize the requests that are not in relax mode
|
||||
def get_relax_value(req_item):
|
||||
scheduling_params = getattr(req_item.request,
|
||||
'py_scheduling_params', None)
|
||||
if scheduling_params is None:
|
||||
return True
|
||||
return scheduling_params.attention_dp_relax
|
||||
|
||||
new_requests = sorted(new_requests, key=get_relax_value)
|
||||
|
||||
# Try to put the requests to the target dp rank until the max_num_active_requests is reached
|
||||
remaining_unscheduled = []
|
||||
for req_item in new_requests:
|
||||
scheduled = False
|
||||
scheduling_params = getattr(req_item.request,
|
||||
'py_scheduling_params', None)
|
||||
if scheduling_params is not None:
|
||||
target_dp_rank = scheduling_params.attention_dp_rank
|
||||
if target_dp_rank is not None and all_ranks_num_active_requests[
|
||||
target_dp_rank] < self.max_num_active_requests:
|
||||
all_ranks_num_active_requests[target_dp_rank] += 1
|
||||
scheduled = True
|
||||
all_ranks_new_requests[target_dp_rank].append(req_item)
|
||||
|
||||
if not scheduled:
|
||||
remaining_unscheduled.append(req_item)
|
||||
|
||||
# Balance the remaining unscheduled requests across ranks
|
||||
num_new_requests_all_ranks = len(remaining_unscheduled)
|
||||
total_num_active_requests = sum(all_ranks_num_active_requests)
|
||||
self.expected_num_active_requests = max(
|
||||
(total_num_active_requests + num_new_requests_all_ranks +
|
||||
self.dist.tp_size - 1) // self.dist.tp_size,
|
||||
max(all_ranks_num_active_requests),
|
||||
)
|
||||
|
||||
all_ranks_new_requests = self._balance_requests_across_ranks(
|
||||
remaining_unscheduled, all_ranks_new_requests,
|
||||
all_ranks_num_active_requests, all_ranks_num_active_tokens)
|
||||
|
||||
return all_ranks_new_requests
|
||||
|
||||
def _handle_request_broadcasting(self,
|
||||
new_requests: List[RequestQueueItem]):
|
||||
"""Handle broadcasting of requests and Python objects across ranks."""
|
||||
if self.dist.rank == 0:
|
||||
py_logits_post_processors = self._collect_py_objects_from_requests(
|
||||
new_requests, "py_logits_post_processors")
|
||||
py_multimodal_data = self._collect_py_objects_from_requests(
|
||||
new_requests, "py_multimodal_data")
|
||||
py_scheduling_params = self._collect_py_objects_from_requests(
|
||||
new_requests, "py_scheduling_params")
|
||||
py_num_logprobs = self._collect_py_objects_from_requests(
|
||||
new_requests, "py_num_logprobs")
|
||||
py_disaggregated_params = self._collect_py_objects_from_requests(
|
||||
new_requests, "py_disaggregated_params")
|
||||
py_request_objects = tuple(
|
||||
filter(None, [
|
||||
py_logits_post_processors, py_multimodal_data,
|
||||
py_scheduling_params, py_num_logprobs,
|
||||
py_disaggregated_params
|
||||
]))
|
||||
else:
|
||||
py_request_objects = None
|
||||
|
||||
if self.dist.rank == 0:
|
||||
# Preserve original `new_requests` on rank 0
|
||||
_ = self._broadcast_new_requests(new_requests, py_request_objects)
|
||||
else:
|
||||
with self.hang_detector.pause():
|
||||
new_requests, py_request_objects = self._broadcast_new_requests(
|
||||
new_requests, py_request_objects)
|
||||
|
||||
return new_requests, py_request_objects
|
||||
|
||||
    def _handle_special_queue_items(
            self,
            new_requests: List[RequestQueueItem]) -> List[RequestQueueItem]:
        """Handle special signals."""
        accepted_new_requests = []
        for idx, req_item in enumerate(new_requests):
            if req_item.is_shutdown_request:
                self.is_shutdown = True
            elif req_item.is_canceled_request:
                self.canceled_req_ids.append(req_item.id)
            elif req_item.is_control_request:
                self.control_requests.append(req_item)
                if self.dist.rank == 0:
                    self.request_accumulated.extend(new_requests[idx + 1:])
            else:
                accepted_new_requests.append(req_item)

        return accepted_new_requests

    def get_from_request_queue(
            self,
            timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]:
        """Fetch requests from the queue with optional timeout.

        Args:
            timeout: Optional timeout for waiting on queue.

        Returns:
            List of RequestQueueItem fetched from the queue.
        """
        items = []
        timeout_secs = timeout.total_seconds() if timeout is not None else None

        try:
            if self.request_queue.empty() and (timeout_secs is None
                                               or timeout_secs > 0):
                # if queue is empty and want to wait, wait
                items.append(self.request_queue.get(timeout=timeout_secs))
            else:
                # if not empty or don't want to wait, just return all items in queue
                while True:
                    queue_item = self.request_queue.get_nowait()
                    items.append(queue_item)
        except queue.Empty:
            pass

        if self.batch_wait_timeout_ms == 0:
            return items

        if len(items) >= self.max_batch_size:
            return items

        deadline = time.monotonic() + self.batch_wait_timeout_ms / 1000.0
        while len(items) < self.max_batch_size:
            remaining_timeout = deadline - time.monotonic()

            if remaining_timeout <= 0:
                break

            try:
                item = self.request_queue.get(timeout=remaining_timeout)
                items.append(item)
            except queue.Empty:
                break
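The batching behaviour of get_from_request_queue can be reproduced with a plain queue.Queue; the following is an illustrative sketch of the same deadline pattern, not the class method itself:

import queue
import time
from typing import List

def drain_with_batch_wait(q: "queue.Queue", max_batch_size: int,
                          batch_wait_timeout_ms: float) -> List[object]:
    """Drain whatever is immediately available, then keep waiting until the
    batch is full or the wait budget is spent (simplified stand-in)."""
    items = []
    try:
        while True:
            items.append(q.get_nowait())
    except queue.Empty:
        pass

    if batch_wait_timeout_ms == 0 or len(items) >= max_batch_size:
        return items

    deadline = time.monotonic() + batch_wait_timeout_ms / 1000.0
    while len(items) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            items.append(q.get(timeout=remaining))
        except queue.Empty:
            break
    return items

A small positive batch_wait_timeout_ms trades a little latency for larger, better-packed batches.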
|
||||
|
||||
def _balance_requests_across_ranks(
|
||||
self, new_requests: List[RequestQueueItem],
|
||||
all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
|
||||
all_ranks_num_active_requests: List[int],
|
||||
all_ranks_num_active_tokens: List[int]) -> List[RequestQueueItem]:
|
||||
"""Balance requests across ranks for attention DP."""
|
||||
if new_requests:
|
||||
# Balance context tokens across ranks using heap
|
||||
HeapVal = namedtuple(
|
||||
'HeapVal',
|
||||
['num_tokens', 'num_requests', 'rank', 'request_list'])
|
||||
|
||||
all_ranks_new_requests_heap = [
|
||||
HeapVal(all_ranks_num_active_tokens[tp_rank], val, tp_rank, [])
|
||||
for tp_rank, val in enumerate(all_ranks_num_active_requests)
|
||||
]
|
||||
|
||||
all_ranks_new_requests_heap = [
|
||||
val for val in all_ranks_new_requests_heap
|
||||
if val.num_requests < self.expected_num_active_requests
|
||||
]
|
||||
|
||||
all_ranks_new_scheduled_requests = {
|
||||
val.rank: val.request_list
|
||||
for val in all_ranks_new_requests_heap
|
||||
}
|
||||
|
||||
heapq.heapify(all_ranks_new_requests_heap)
|
||||
|
||||
# Sort by token count (descending) for better load balancing
|
||||
new_requests = sorted(
|
||||
new_requests,
|
||||
key=lambda x: len(getattr(x.request, 'input_token_ids', []))
|
||||
if x.request else 0,
|
||||
reverse=True)
|
||||
|
||||
# Distribute requests across ranks
|
||||
for req_item in new_requests:
|
||||
|
||||
val = heapq.heappop(all_ranks_new_requests_heap)
|
||||
token_count = len(
|
||||
getattr(req_item.request, 'input_token_ids',
|
||||
[])) if req_item.request else 0
|
||||
# Update the heap value with the new request
|
||||
val = val._replace(
|
||||
num_tokens=val.num_tokens + token_count,
|
||||
num_requests=val.num_requests + 1,
|
||||
)
|
||||
|
||||
val.request_list.append(req_item)
|
||||
# If rank still has room for new requests, push back into heap
|
||||
if val.num_requests < self.expected_num_active_requests:
|
||||
heapq.heappush(all_ranks_new_requests_heap, val)
|
||||
|
||||
# Extend all_ranks_new_requests with the new requests that have been scheduled
|
||||
for rank, reqs in all_ranks_new_scheduled_requests.items():
|
||||
all_ranks_new_requests[rank].extend(reqs)
|
||||
|
||||
return all_ranks_new_requests
|
||||
|
||||
def _collect_py_objects_from_requests(
|
||||
self, requests: List[RequestQueueItem],
|
||||
attribute_name: str) -> Optional[Tuple[str, Dict]]:
|
||||
"""Collect Python-only objects from requests."""
|
||||
req_id_to_obj = {}
|
||||
for item in requests:
|
||||
if not item.is_normal_request:
|
||||
continue
|
||||
if item.request:
|
||||
obj = getattr(item.request, attribute_name, None)
|
||||
if obj is not None:
|
||||
req_id_to_obj[item.id] = obj
|
||||
return None if not req_id_to_obj else (attribute_name, req_id_to_obj)
|
||||
|
||||
@nvtx_range("_broadcast_new_requests")
|
||||
def _broadcast_new_requests(
|
||||
self, new_requests: List[RequestQueueItem], py_request_objects
|
||||
) -> Tuple[List[RequestQueueItem], Optional[Dict]]:
|
||||
"""Broadcast new_requests and optional Python-only metadata across pipeline stages."""
|
||||
payloads = (new_requests, py_request_objects)
|
||||
|
||||
if not self.dist.has_pp:
|
||||
return self.dist.broadcast(payloads, root=0)
|
||||
|
||||
# Broadcast within first PP stage before send/recv chain to other PP stages.
|
||||
# This needs to cover both TP and CP ranks within the first PP stage.
|
||||
if self.dist.is_first_pp_rank:
|
||||
payloads = self.dist.tp_cp_broadcast(payloads, root=0)
|
||||
|
||||
# Tag for communication
|
||||
tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts
|
||||
|
||||
# Send payloads
|
||||
if not self.dist.is_first_pp_rank:
|
||||
with nvtx_range("recv_requests_from_prev_pp"):
|
||||
payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag)
|
||||
|
||||
# isend new requests may cause deadlock, when CUDA_LAUNCH_BLOCKING=1 or PP microbatches can't overlap,
|
||||
# the deadlock will happen deterministicly:
|
||||
# 1. rank1 will wait on nccl.send(rank2), without invoking mpi.wait(isend-handle)
|
||||
# 2. rank2 will wait on mpi.recv(rank1) but never receive the new requests.
|
||||
# 3. rank1 will hang on nccl.send because rank2 will never reach nccl.recv(rank1).
|
||||
pp_send_func = self.dist.isend_object if os.environ.get(
|
||||
"TRTLLM_PP_REQ_SEND_ASYNC", "0") == "1" else self.dist.send_object
|
||||
|
||||
if not self.dist.is_last_pp_rank:
|
||||
if self.send_requests_handler is not None:
|
||||
with nvtx_range("wait_prev_send_requests_handler"):
|
||||
self.send_requests_handler.wait()
|
||||
with nvtx_range("send_requests_to_next_pp"):
|
||||
self.send_requests_handler = pp_send_func(
|
||||
payloads, self.dist.next_pp_rank, tag)
|
||||
|
||||
return payloads
|
||||
|
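The env-gated choice between blocking and asynchronous pipeline-parallel sends boils down to the pattern below (a sketch with placeholder send callables; only the environment variable name comes from the code above):

import os

def pick_send_func(send_blocking, send_async):
    # Default to the blocking send; the async path is opt-in because it can
    # deadlock when PP microbatches cannot overlap (see the comment above).
    use_async = os.environ.get("TRTLLM_PP_REQ_SEND_ASYNC", "0") == "1"
    return send_async if use_async else send_blocking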
||||
def _attach_py_objects_to_requests(self, requests: List[RequestQueueItem],
|
||||
py_request_objects) -> None:
|
||||
"""Attach Python-only objects to each request."""
|
||||
for attr_name, req_obj_dict in py_request_objects:
|
||||
for item in requests:
|
||||
if item.request:
|
||||
py_obj = req_obj_dict.get(item.id)
|
||||
if py_obj is not None:
|
||||
setattr(item.request, attr_name, py_obj)
|
||||
|
||||
def _update_new_active_requests_queue_latency(
|
||||
self, new_requests: List[RequestQueueItem]):
|
||||
"""Update queue latency metrics for new requests."""
|
||||
now = time.time()
|
||||
for req_item in new_requests:
|
||||
if req_item.id in self.start_times:
|
||||
self.new_active_requests_queue_latency_ms += now - self.start_times.pop(
|
||||
req_item.id)
|
||||
if req_item.child_req_ids:
|
||||
for child_id in req_item.child_req_ids:
|
||||
self.new_active_requests_queue_latency_ms += now - self.start_times.pop(
|
||||
child_id)
|
||||
|
||||
# Note: Helix parallelism is a decode-only feature run with disaggregated serving. This function gets called on gen server
|
||||
# during initialization of a new request.
|
||||
def _merge_helix_requests(self, new_requests: list[RequestQueueItem],
|
||||
tokens_per_block: int):
|
||||
req_with_children = []
|
||||
num_cp_ranks = self.dist.cp_size
|
||||
curr_cp_rank = self.dist.cp_rank
|
||||
|
||||
# For each request, partition the input_token_ids into blocks and then partition blocks across CP ranks.
|
||||
# Currently, the partitioning is such that contiguous blocks are assigned to the same CP rank (as opposed
|
||||
# to round-robin).
|
||||
for req_item in new_requests:
|
||||
all_input_ids = torch.tensor(req_item.request.input_token_ids,
|
||||
dtype=torch.int64).unsqueeze(0)
|
||||
input_len = all_input_ids.shape[-1]
|
||||
|
||||
num_total_blocks = (input_len + tokens_per_block -
|
||||
1) // tokens_per_block
|
||||
if num_total_blocks < num_cp_ranks:
|
||||
raise ValueError(
|
||||
f"There aren't enough tokens to get at least one block per CP rank. num_total_blocks {num_total_blocks} < num_cp_ranks {num_cp_ranks}. Please use smaller tokens_per_block for KV cache or reduce the number of CP ranks."
|
||||
)
|
||||
|
||||
# Padding to ensure torch.stack used with torch.tensor_split works properly.
|
||||
padding_len = 0
|
||||
if input_len % tokens_per_block != 0:
|
||||
padding_len = tokens_per_block - (input_len % tokens_per_block)
|
||||
padding_ids = torch.zeros([1, padding_len], dtype=torch.int64)
|
||||
all_input_ids = torch.cat((all_input_ids, padding_ids), dim=-1)
|
||||
all_position_ids = torch.arange(0,
|
||||
input_len + padding_len,
|
||||
dtype=torch.int64).unsqueeze(0)
|
||||
|
||||
input_id_blocks_per_rank = torch.tensor_split(
|
||||
torch.stack(all_input_ids.split(tokens_per_block, dim=-1)),
|
||||
num_cp_ranks)
|
||||
position_id_blocks_per_rank = torch.tensor_split(
|
||||
torch.stack(all_position_ids.split(tokens_per_block, dim=-1)),
|
||||
num_cp_ranks)
|
||||
|
||||
# Get the input_ids and position_ids for this rank.
|
||||
input_ids_this_rank = input_id_blocks_per_rank[
|
||||
curr_cp_rank].flatten().tolist()
|
||||
position_ids_this_rank = position_id_blocks_per_rank[
|
||||
curr_cp_rank].flatten().tolist()
|
||||
|
||||
# Undo the padding. Only last rank's last block will be padded right now
|
||||
# given contiguous block assignment.
|
||||
if curr_cp_rank == num_cp_ranks - 1 and padding_len > 0:
|
||||
input_ids_this_rank = input_ids_this_rank[:-padding_len]
|
||||
position_ids_this_rank = position_ids_this_rank[:-padding_len]
|
||||
|
||||
req = executor_request_to_llm_request(
|
||||
req_id=req_item.id,
|
||||
executor_request=req_item.request,
|
||||
child_req_ids=req_item.child_req_ids,
|
||||
exclude_last_generation_logits=self.
|
||||
_should_exclude_last_generation_logits(),
|
||||
input_token_ids=input_ids_this_rank,
|
||||
position_ids=position_ids_this_rank,
|
||||
)
|
||||
req.total_input_len_cp = input_len
|
||||
req.seqlen_this_rank_cp = len(input_ids_this_rank)
|
||||
req_with_children.append(req)
|
||||
if req.child_requests:
|
||||
req_with_children.extend(req.child_requests)
|
||||
return req_with_children
|
||||
|
||||
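As a toy illustration of the contiguous-block assignment performed above (tokens_per_block, rank count and sequence length are chosen so the length divides evenly and no padding is needed; real inputs go through the padding branch first):

import torch

tokens_per_block, num_cp_ranks = 4, 3
input_ids = torch.arange(20, dtype=torch.int64).unsqueeze(0)  # shape [1, 20]

blocks = torch.stack(input_ids.split(tokens_per_block, dim=-1))  # [5, 1, 4]
per_rank = torch.tensor_split(blocks, num_cp_ranks)              # contiguous groups of blocks

for rank, chunk in enumerate(per_rank):
    print(rank, chunk.flatten().tolist())
# rank 0 gets blocks 0-1, rank 1 gets blocks 2-3, rank 2 gets block 4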
@nvtx_range("_merge_requests")
|
||||
def _merge_requests(
|
||||
self, new_requests: list[RequestQueueItem]) -> List[LlmRequest]:
|
||||
cp_config = self.dist.cp_config
|
||||
if 'cp_type' in cp_config:
|
||||
cp_type = cp_config['cp_type']
|
||||
if cp_type == CpType.STAR:
|
||||
return self._merge_star_attention_requests(new_requests)
|
||||
elif cp_type == CpType.HELIX:
|
||||
return self._merge_helix_requests(
|
||||
new_requests,
|
||||
tokens_per_block=cp_config['tokens_per_block'])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Unsupported cp type {cp_type.name}.')
|
||||
|
||||
req_with_children = []
|
||||
for req_item in new_requests:
|
||||
req = executor_request_to_llm_request(
|
||||
req_item.id, req_item.request, req_item.child_req_ids,
|
||||
self._should_exclude_last_generation_logits())
|
||||
req_with_children.append(req)
|
||||
if req.child_requests:
|
||||
req_with_children.extend(req.child_requests)
|
||||
return req_with_children
|
||||
|
||||
def _merge_star_attention_requests(
|
||||
self, new_requests: list[RequestQueueItem]) -> List[LlmRequest]:
|
||||
result = []
|
||||
for req_item in new_requests:
|
||||
req_id, exe_req, query_token_ids = req_item.id, req_item.request, req_item.query
|
||||
ctx_len0 = len(exe_req.input_token_ids)
|
||||
ctx_blocks, position_blocks, last_block_padding_num = [
|
||||
exe_req.input_token_ids
|
||||
], [[i for i in range(ctx_len0)]], 0
|
||||
ctx_blocks, position_blocks, last_block_padding_num = self._partition_context(
|
||||
exe_req.input_token_ids)
|
||||
if self.dist.cp_rank == self.dist.cp_size - 1 and last_block_padding_num > 0:
|
||||
ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
|
||||
position_blocks[-1] = position_blocks[
|
||||
-1][:-last_block_padding_num]
|
||||
#if has query
|
||||
if query_token_ids:
|
||||
ctx_blocks.append(query_token_ids)
|
||||
position_blocks.append([
|
||||
i for i in range(ctx_len0, ctx_len0 + len(query_token_ids))
|
||||
])
|
||||
|
||||
# insert the dummy block to align the number of ctx iterations of each rank
|
||||
block_size = self.dist.cp_config['block_size']
|
||||
total_blocks = (ctx_len0 + block_size - 1) // block_size
|
||||
num_blocks_per_rank = (
|
||||
total_blocks + self.dist.cp_size -
|
||||
1) // self.dist.cp_size + 1 # 1 for query block
|
||||
if len(ctx_blocks) == num_blocks_per_rank:
|
||||
ctx_blocks.insert(1, [])
|
||||
position_blocks.insert(1, [])
|
||||
elif len(ctx_blocks) == num_blocks_per_rank + 1:
|
||||
# anchor + ctx_blocks + qry_block
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f'rank = {self.dist.cp_rank}, len(ctx_blocks) = {len(ctx_blocks) }, num_blocks_per_rank = {num_blocks_per_rank}'
|
||||
)
|
||||
assert False, f'invalid context partition'
|
||||
|
||||
# fake data for scheduler
|
||||
ctx_blocks_list = [0] * (block_size +
|
||||
self.dist.cp_config['cp_anchor_size'])
|
||||
|
||||
req = executor_request_to_llm_request(
|
||||
req_id, exe_req, self._should_exclude_last_generation_logits(),
|
||||
ctx_blocks_list)
|
||||
req.gen_iters = 0
|
||||
req.ctx_iters = 0
|
||||
req.ctx_blocks = ctx_blocks
|
||||
req.ctx_position_blocks = position_blocks
|
||||
req.query_id = query_token_ids
|
||||
|
||||
result.append(req)
|
||||
|
||||
return result
|
||||
|
||||
def _partition_context(self, ctx_ids_list):
|
||||
ctx_ids = torch.tensor(ctx_ids_list).unsqueeze(0)
|
||||
ctx_len = ctx_ids.shape[-1]
|
||||
block_size = self.dist.cp_config['block_size']
|
||||
if block_size is None:
|
||||
block_size = ctx_len // self.dist.cp_size
|
||||
anchor_block_size = self.dist.cp_config['cp_anchor_size']
|
||||
if anchor_block_size is None:
|
||||
anchor_block_size = block_size
|
||||
|
||||
assert anchor_block_size <= block_size, f'cp_anchor_size {anchor_block_size} should be smaller than block_size {block_size}'
|
||||
padding = 0
|
||||
if ctx_len % block_size != 0:
|
||||
padding = block_size - (ctx_len % block_size)
|
||||
assert padding <= ctx_len, f'block size is too large for context, please set it smaller'
|
||||
ctx_ids = torch.cat(
|
||||
(ctx_ids, torch.zeros_like(ctx_ids)[:, :padding]), dim=-1)
|
||||
position_ids = torch.arange(0, ctx_ids.shape[-1]).unsqueeze(0)
|
||||
|
||||
ctx_ids_blocks = torch.tensor_split(
|
||||
torch.stack(ctx_ids.split(block_size, dim=-1)), self.dist.cp_size)
|
||||
position_ids_blocks = torch.tensor_split(
|
||||
torch.stack(position_ids.split(block_size, dim=-1)),
|
||||
self.dist.cp_size)
|
||||
if self.dist.cp_rank != 0:
|
||||
ctx_blocks, position_blocks = [
|
||||
ctx_ids_blocks[0][0].tolist()[0][:anchor_block_size]
|
||||
], [position_ids_blocks[0][0].tolist()[0][:anchor_block_size]]
|
||||
else:
|
||||
ctx_blocks, position_blocks = [], []
|
||||
|
||||
for idx in range(len(ctx_ids_blocks[self.dist.cp_rank])):
|
||||
ctx_block = ctx_ids_blocks[self.dist.cp_rank][idx]
|
||||
position_block = position_ids_blocks[self.dist.cp_rank][idx]
|
||||
ctx_blocks.append(ctx_block.tolist()[0])
|
||||
position_blocks.append(position_block.tolist()[0])
|
||||
return ctx_blocks, position_blocks, padding
|
||||
|
||||
def set_exclude_last_generation_logits(self,
|
||||
disable_overlap_scheduler: bool,
|
||||
pp_size: int) -> None:
|
||||
# When overlap scheduler is enabled then when starting to handle a new prompt,
|
||||
# sample_async is called twice before the first call to update_requests:
|
||||
# - 1st time as a context request that handles on the 1st generated token
|
||||
# - 2nd time as a generation request that handles on the 2nd generated token.
|
||||
# and only after these two calls the sampler's update_request method is called.
|
||||
# So in a sampler that works by the expected flow of handling the logits in
|
||||
# sample_async, every update_request doesn't handle the newest token, but one
|
||||
# before it. Since all these calls work on the same request object, then its
|
||||
# logits storage contains the logits of both the token update_requests should work
|
||||
# on, and also its next token. Thus, excluding the last generation logits from any
|
||||
# getter is required.
|
||||
self.should_exclude_last_generation_logits = not disable_overlap_scheduler and pp_size == 1
|
||||
|
||||
def _should_exclude_last_generation_logits(self) -> bool:
|
||||
return self.should_exclude_last_generation_logits
|
||||
|
||||
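The resulting rule is small enough to spell out directly; this is a restatement of the condition above, not new behaviour:

def should_exclude_last_generation_logits(disable_overlap_scheduler: bool,
                                          pp_size: int) -> bool:
    # Only the single-pipeline-stage overlap path runs sample_async one step
    # ahead of update_requests, so only that combination hides the newest logits.
    return not disable_overlap_scheduler and pp_size == 1

assert should_exclude_last_generation_logits(False, 1) is True
assert should_exclude_last_generation_logits(True, 1) is False
assert should_exclude_last_generation_logits(False, 2) is False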
def get_new_active_requests_queue_latency(self) -> float:
|
||||
return self.new_active_requests_queue_latency_ms
|
||||
|
||||
def get_expected_num_active_requests(self) -> int:
|
||||
return self.expected_num_active_requests
|
||||
return items
|
||||
|
||||
def get_request_queue_size(self) -> int:
|
||||
return self.request_queue.qsize()
|
||||
@ -891,22 +193,22 @@ class ExecutorRequestQueue:
|
||||
def get_request_queue(self) -> queue.Queue[RequestQueueItem]:
|
||||
return self.request_queue
|
||||
|
||||
def get_waiting_queue(self) -> deque[RequestQueueItem]:
|
||||
return self.waiting_queue
|
||||
    def calculate_queue_latency(self, request_items: List[RequestQueueItem],
                                now: float) -> float:
        if not self.enable_iter_perf_stats:
            return 0.0

        total_latency = 0.0

        for req_item in request_items:
            # Handle parent request
            if req_item.id in self.start_times:
                total_latency += now - self.start_times.pop(req_item.id)

            # Handle child requests
            if req_item.child_req_ids:
                for child_id in req_item.child_req_ids:
                    if child_id in self.start_times:
                        total_latency += now - self.start_times.pop(child_id)

        return total_latency

    def update_waiting_queue(self):
        # Remove cancel request in the waiting queue
        self.waiting_queue = deque(req for req in self.waiting_queue
                                   if req.id not in self.canceled_req_ids)

    def get_waiting_queue_size(self) -> int:
        return len(self.waiting_queue)

    def get_canceled_req_ids_size(self) -> int:
        return len(self.canceled_req_ids)

    def get_canceled_req_ids(self) -> List[int]:
        return self.canceled_req_ids

    def clear_canceled_req_ids(self):
        self.canceled_req_ids.clear()
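For instance, with made-up timestamps the parent/child accounting above behaves as follows (a sketch against a bare dict of start times, not the class itself):

start_times = {7: 10.0, 8: 10.2, 9: 10.2}   # request id -> enqueue time (seconds)
now = 10.5

latency = 0.0
for parent_id, child_ids in [(7, [8, 9])]:   # one parent request with two children
    latency += now - start_times.pop(parent_id)
    for child_id in child_ids:
        if child_id in start_times:
            latency += now - start_times.pop(child_id)

print(round(latency, 3))  # 0.5 + 0.3 + 0.3 = 1.1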
|
||||
|
||||
@ -6,6 +6,7 @@ import pickle # nosec B403
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
@ -22,7 +23,7 @@ except ImportError:
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import (
|
||||
ResourceManagerType, request_context)
|
||||
from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
|
||||
nvtx_range, trace_func)
|
||||
mpi_disabled, nvtx_range, trace_func)
|
||||
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
|
||||
FinishReason, InflightBatchingStats,
|
||||
IterationStats, KvCacheStats,
|
||||
@ -53,6 +54,9 @@ from .kv_cache_transceiver import KvCacheTransceiver
|
||||
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
|
||||
LlmResponse, get_draft_token_length)
|
||||
from .model_engine import ModelEngine
|
||||
from .request_utils import (RequestBroadcaster, attach_py_objects_to_requests,
|
||||
get_from_waiting_queue, merge_requests,
|
||||
schedule_attention_dp_requests)
|
||||
from .resource_manager import ResourceManager
|
||||
from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
|
||||
SampleStateTensors, TRTLLMSampler)
|
||||
@ -424,16 +428,36 @@ class PyExecutor:
|
||||
self._set_global_steady_clock_offset()
|
||||
self.executor_request_queue = ExecutorRequestQueue(
|
||||
dist=self.dist,
|
||||
enable_attention_dp=self.enable_attention_dp,
|
||||
max_batch_size=max_batch_size,
|
||||
max_beam_width=self.max_beam_width,
|
||||
max_num_active_requests=self.max_num_active_requests,
|
||||
enable_iter_perf_stats=self.enable_iter_perf_stats,
|
||||
batch_wait_timeout_ms=self.batch_wait_timeout_ms,
|
||||
hang_detector=self.hang_detector,
|
||||
)
|
||||
self.executor_request_queue.set_exclude_last_generation_logits(
|
||||
self.disable_overlap_scheduler, self.dist.pp_size)
|
||||
# When overlap scheduler is enabled then when starting to handle a new prompt,
|
||||
# sample_async is called twice before the first call to update_requests:
|
||||
# - 1st time as a context request that handles on the 1st generated token
|
||||
# - 2nd time as a generation request that handles on the 2nd generated token.
|
||||
# and only after these two calls the sampler's update_request method is called.
|
||||
# So in a sampler that works by the expected flow of handling the logits in
|
||||
# sample_async, every update_request doesn't handle the newest token, but one
|
||||
# before it. Since all these calls work on the same request object, then its
|
||||
# logits storage contains the logits of both the token update_requests should work
|
||||
# on, and also its next token. Thus, excluding the last generation logits from any
|
||||
# getter is required.
|
||||
self.should_exclude_last_generation_logits = (
|
||||
not self.disable_overlap_scheduler and self.dist.pp_size == 1)
|
||||
|
||||
# Request processing state (managed by executor)
|
||||
self.canceled_req_ids: List[int] = []
|
||||
self.control_requests: List[RequestQueueItem] = []
|
||||
self.request_accumulated: List[RequestQueueItem] = []
|
||||
self.new_active_requests_queue_latency_ms = 0.0
|
||||
self._disable_mpi = mpi_disabled()
|
||||
self.request_broadcaster = RequestBroadcaster(self.dist,
|
||||
self.hang_detector)
|
||||
|
||||
# Waiting queue for requests that have been fetched but not yet scheduled
|
||||
self.waiting_queue: deque[RequestQueueItem] = deque()
|
||||
|
||||
self.control_request_barrier = threading.Event()
|
||||
self.control_action_done = threading.Event()
|
||||
|
||||
@ -698,7 +722,7 @@ class PyExecutor:
|
||||
@property
|
||||
def should_stop_processing(self):
|
||||
return self.is_shutdown and len(self.active_requests) == 0 and \
|
||||
self.executor_request_queue.get_waiting_queue_size() == 0
|
||||
len(self.waiting_queue) == 0
|
||||
|
||||
@contextmanager
|
||||
def _profiler(self):
|
||||
@ -778,8 +802,8 @@ class PyExecutor:
|
||||
f"iter = {self.iter_counter}, "
|
||||
f"global_rank = {self.global_rank}, "
|
||||
f"rank = {self.dist.rank}, "
|
||||
f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/"
|
||||
f"{self.executor_request_queue.num_fetch_requests}, "
|
||||
f"currank_total_requests = {self.num_fetch_requests_cur_rank}/"
|
||||
f"{self.num_fetch_requests}, "
|
||||
f"host_step_time = {host_step_time}ms, "
|
||||
f"prev_device_step_time = {prev_device_step_time}, "
|
||||
f"timestamp = {formatted_timestamp}, "
|
||||
@ -1143,8 +1167,7 @@ class PyExecutor:
|
||||
if self.enable_iter_perf_stats:
|
||||
iter_stats = self._get_init_iter_stats(
|
||||
len(new_requests),
|
||||
self.executor_request_queue.
|
||||
get_new_active_requests_queue_latency())
|
||||
self._get_new_active_requests_queue_latency())
|
||||
|
||||
self._pad_attention_dp_dummy_request()
|
||||
|
||||
@ -1400,8 +1423,7 @@ class PyExecutor:
|
||||
if self.enable_iter_perf_stats:
|
||||
iter_stats = self._get_init_iter_stats(
|
||||
len(new_requests),
|
||||
self.executor_request_queue.
|
||||
get_new_active_requests_queue_latency())
|
||||
self._get_new_active_requests_queue_latency())
|
||||
|
||||
self._pad_attention_dp_dummy_request()
|
||||
|
||||
@ -1639,14 +1661,14 @@ class PyExecutor:
|
||||
|
||||
def _handle_control_request(self):
|
||||
if len(self.active_requests) == 0 and \
|
||||
self.executor_request_queue.get_waiting_queue_size() == 0 and \
|
||||
len(self.executor_request_queue.control_requests) > 0:
|
||||
assert len(self.executor_request_queue.control_requests) == 1, (
|
||||
len(self.waiting_queue) == 0 and \
|
||||
len(self.control_requests) > 0:
|
||||
assert len(self.control_requests) == 1, (
|
||||
f"Expected exactly one control request to be processed at a time, "
|
||||
f"but found {len(self.executor_request_queue.control_requests)} control requests. "
|
||||
f"but found {len(self.control_requests)} control requests. "
|
||||
f"This may indicate a race condition or improper control request handling."
|
||||
)
|
||||
self.executor_request_queue.control_requests.pop(0)
|
||||
self.control_requests.pop(0)
|
||||
self.control_request_barrier.set()
|
||||
self.control_action_done.wait()
|
||||
self.control_action_done.clear()
|
||||
@ -1704,7 +1726,7 @@ class PyExecutor:
|
||||
# to ensure consistent batch sizes for accurate performance measurement.
|
||||
if not self.is_warmup and not can_forward:
|
||||
if self.enable_attention_dp:
|
||||
local_can_forward = self.executor_request_queue.num_fetch_requests + \
|
||||
local_can_forward = self.num_fetch_requests + \
|
||||
len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size
|
||||
all_can_forward = self.dist.tp_allgather(
|
||||
local_can_forward)
|
||||
@ -1714,7 +1736,7 @@ class PyExecutor:
|
||||
else:
|
||||
if self.dist.rank == 0:
|
||||
logger.info(
|
||||
f"sleep 10 seconds, num_fetched_requests: {self.executor_request_queue.num_fetch_requests}, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}"
|
||||
f"sleep 10 seconds, num_fetched_requests: {self.num_fetch_requests}, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}"
|
||||
)
|
||||
time.sleep(10)
|
||||
continue
|
||||
@ -2026,6 +2048,169 @@ class PyExecutor:
|
||||
self.model_engine.model.lm_head.num_embeddings):
|
||||
raise ValueError("Token ID out of range")
|
||||
|
||||
def _fetch_and_enqueue_requests(self,
|
||||
waiting_queue: deque[RequestQueueItem],
|
||||
total_num_active_requests: int) -> None:
|
||||
"""Fetch requests from request_queue and enqueue to waiting_queue."""
|
||||
# Block new requests while control requests are pending
|
||||
if len(self.control_requests) != 0:
|
||||
return
|
||||
|
||||
# Calculate timeout
|
||||
idle = (total_num_active_requests == 0) and len(waiting_queue) == 0
|
||||
if idle:
|
||||
# In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0
|
||||
# reaches the broadcast path regularly to prevent trtllm-serve timeout when idle.
|
||||
timeout = datetime.timedelta(
|
||||
seconds=1200) if self._disable_mpi else None
|
||||
else:
|
||||
timeout = datetime.timedelta(0)
|
||||
|
||||
# Fetch requests from rank 0
|
||||
new_requests = []
|
||||
if self.dist.rank == 0:
|
||||
# Process accumulated requests that were queued during control request handling.
|
||||
if len(self.request_accumulated) != 0:
|
||||
new_requests.extend(self.request_accumulated)
|
||||
self.request_accumulated.clear()
|
||||
# Reset timeout to 0 to avoid hanging when no new requests are available
|
||||
timeout = datetime.timedelta(0)
|
||||
with self.hang_detector.pause():
|
||||
new_requests.extend(
|
||||
self.executor_request_queue.get_from_request_queue(timeout))
|
||||
|
||||
# Broadcast requests and handle Python objects
|
||||
new_requests, py_request_objects = self.request_broadcaster.broadcast(
|
||||
new_requests)
|
||||
|
||||
# Validate and filter requests
|
||||
new_requests = self._handle_special_queue_items(new_requests)
|
||||
|
||||
# Attach Python objects to requests
|
||||
if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp
|
||||
or self.dist.cp_size
|
||||
> 1) and self.dist.rank > 0:
|
||||
attach_py_objects_to_requests(new_requests, py_request_objects)
|
||||
|
||||
waiting_queue.extend(new_requests)
|
||||
|
||||
def _pop_from_waiting_queue(
|
||||
self,
|
||||
waiting_queue: deque[RequestQueueItem],
|
||||
total_num_active_requests: int,
|
||||
all_ranks_num_active_requests: Optional[List[int]] = None
|
||||
) -> List[RequestQueueItem]:
|
||||
"""Pop requests from waiting_queue based on available capacity."""
|
||||
if self.enable_attention_dp:
|
||||
total_max = self.dist.tp_size * self.max_num_active_requests
|
||||
else:
|
||||
total_max = self.max_num_active_requests
|
||||
|
||||
max_new_requests = total_max - total_num_active_requests
|
||||
|
||||
return get_from_waiting_queue(
|
||||
waiting_queue,
|
||||
max_new_requests,
|
||||
enable_attention_dp=self.enable_attention_dp,
|
||||
max_num_active_requests=self.max_num_active_requests,
|
||||
all_ranks_num_active_requests=all_ranks_num_active_requests)
|
||||
|
||||
@nvtx_range("_fetch_new_requests")
|
||||
def _fetch_new_requests(
|
||||
self, waiting_queue: deque[RequestQueueItem],
|
||||
activate_requests: List[LlmRequest]) -> List[LlmRequest]:
|
||||
"""Fetch new requests and return LlmRequests ready for execution."""
|
||||
# 1. Gather info and calculate total_num_active_requests
|
||||
if self.enable_attention_dp:
|
||||
all_ranks_num_active_requests = []
|
||||
all_ranks_num_active_tokens = []
|
||||
if self.dist.has_cp_helix:
|
||||
num_active_tokens = sum(
|
||||
[req.total_input_len_cp for req in activate_requests])
|
||||
else:
|
||||
num_active_tokens = sum(
|
||||
[req.py_orig_prompt_len for req in activate_requests])
|
||||
responses_list = self.dist.tp_allgather(
|
||||
[len(activate_requests), num_active_tokens])
|
||||
for num_active_requests, num_active_tokens in responses_list:
|
||||
all_ranks_num_active_requests.append(num_active_requests)
|
||||
all_ranks_num_active_tokens.append(num_active_tokens)
|
||||
total_num_active_requests = sum(all_ranks_num_active_requests)
|
||||
else:
|
||||
total_num_active_requests = len(activate_requests)
|
||||
all_ranks_num_active_requests = None
|
||||
|
||||
# 2. Fetch and enqueue to waiting queue
|
||||
self._fetch_and_enqueue_requests(waiting_queue,
|
||||
total_num_active_requests)
|
||||
|
||||
# 3. Pop requests from waiting queue
|
||||
new_requests = self._pop_from_waiting_queue(
|
||||
waiting_queue, total_num_active_requests,
|
||||
all_ranks_num_active_requests)
|
||||
|
||||
# 4. Update performance metrics (before DP scheduling to clear all start_times)
|
||||
if self.enable_iter_perf_stats and self.dist.rank == 0:
|
||||
self._update_new_active_requests_queue_latency(new_requests)
|
||||
|
||||
# 5. Schedule requests across ranks (DP only)
|
||||
if self.enable_attention_dp:
|
||||
all_ranks_new_requests, self.expected_num_active_requests = \
|
||||
schedule_attention_dp_requests(
|
||||
new_requests, all_ranks_num_active_requests,
|
||||
all_ranks_num_active_tokens, self.dist.tp_size,
|
||||
self.max_num_active_requests)
|
||||
new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]
|
||||
|
||||
# Update counters for DP
|
||||
self.num_fetch_requests += len(new_requests)
|
||||
self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)
|
||||
|
||||
new_requests = new_requests_cur_rank
|
||||
|
||||
# 6. Merge requests
|
||||
return merge_requests(new_requests,
|
||||
cp_config=self.dist.cp_config,
|
||||
cp_rank=self.dist.cp_rank,
|
||||
cp_size=self.dist.cp_size,
|
||||
exclude_last_generation_logits=self.
|
||||
_should_exclude_last_generation_logits())
|
||||
|
||||
def _handle_special_queue_items(
|
||||
self,
|
||||
new_requests: List[RequestQueueItem]) -> List[RequestQueueItem]:
|
||||
"""Handle special signals."""
|
||||
accepted_new_requests = []
|
||||
for idx, req_item in enumerate(new_requests):
|
||||
if req_item.is_shutdown_request:
|
||||
self.is_shutdown = True
|
||||
break
|
||||
elif req_item.is_canceled_request:
|
||||
self.canceled_req_ids.append(req_item.id)
|
||||
elif req_item.is_control_request:
|
||||
self.control_requests.append(req_item)
|
||||
if self.dist.rank == 0:
|
||||
self.request_accumulated.extend(new_requests[idx + 1:])
|
||||
break
|
||||
else:
|
||||
accepted_new_requests.append(req_item)
|
||||
|
||||
return accepted_new_requests
|
||||
|
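A reduced sketch of the signal routing implemented above (RequestQueueItem is replaced by a tiny stand-in, and the rank-0 accumulation of deferred items is omitted for brevity):

from dataclasses import dataclass
from typing import List

@dataclass
class Item:  # hypothetical stand-in for RequestQueueItem
    id: int
    kind: str = "normal"   # "normal" | "shutdown" | "cancel" | "control"

def split_special_items(items: List[Item]):
    accepted, canceled_ids, control = [], [], []
    shutdown = False
    for item in items:
        if item.kind == "shutdown":
            shutdown = True
            break                  # nothing after a shutdown signal is processed
        elif item.kind == "cancel":
            canceled_ids.append(item.id)
        elif item.kind == "control":
            control.append(item)
            break                  # later items are deferred until the control op completes
        else:
            accepted.append(item)
    return accepted, canceled_ids, control, shutdown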
||||
def _update_new_active_requests_queue_latency(
|
||||
self, new_requests: List[RequestQueueItem]):
|
||||
"""Update queue latency metrics for new requests."""
|
||||
now = time.time()
|
||||
latency = self.executor_request_queue.calculate_queue_latency(
|
||||
new_requests, now)
|
||||
self.new_active_requests_queue_latency_ms += latency
|
||||
|
||||
def _get_new_active_requests_queue_latency(self) -> float:
|
||||
return self.new_active_requests_queue_latency_ms
|
||||
|
||||
def _should_exclude_last_generation_logits(self) -> bool:
|
||||
return self.should_exclude_last_generation_logits
|
||||
|
||||
def _fetch_and_activate_new_requests(self) -> List[LlmRequest]:
|
||||
|
||||
def _respond_if_invalid(request: LlmRequest) -> bool:
|
||||
@ -2041,11 +2226,8 @@ class PyExecutor:
|
||||
self._handle_errors(str(e), requests=[request])
|
||||
return True
|
||||
|
||||
new_requests_cur_rank = self.executor_request_queue.fetch_new_requests(
|
||||
self.active_requests)
|
||||
self.is_shutdown = self.executor_request_queue.is_shutdown
|
||||
self.expected_num_active_requests = self.executor_request_queue.get_expected_num_active_requests(
|
||||
)
|
||||
new_requests_cur_rank = self._fetch_new_requests(
|
||||
self.waiting_queue, self.active_requests)
|
||||
|
||||
validated_requests = [
|
||||
request for request in new_requests_cur_rank
|
||||
@ -2647,20 +2829,20 @@ class PyExecutor:
|
||||
|
||||
@nvtx_range("_handle_canceled_requests")
|
||||
def _handle_canceled_requests(self):
|
||||
if self.executor_request_queue.get_canceled_req_ids_size() == 0:
|
||||
if len(self.canceled_req_ids) == 0:
|
||||
return
|
||||
|
||||
# Remove cancel request in the waiting queue
|
||||
self.executor_request_queue.update_waiting_queue()
|
||||
|
||||
# Create set from list of canceled request ids to speed up canceled test
|
||||
canceled_req_ids = set(
|
||||
self.executor_request_queue.get_canceled_req_ids())
|
||||
canceled_req_ids_set = set(self.canceled_req_ids)
|
||||
|
||||
# Remove canceled requests from the waiting queue
|
||||
self.waiting_queue = deque(req for req in self.waiting_queue
|
||||
if req.id not in canceled_req_ids_set)
|
||||
|
||||
still_pending_canceled_ids = []
|
||||
for request in self.active_requests:
|
||||
req_id = request.py_request_id if not request.is_child else request.parent_request_id
|
||||
if req_id not in canceled_req_ids:
|
||||
if req_id not in canceled_req_ids_set:
|
||||
continue
|
||||
|
||||
is_cancelled = self._try_cancel_request(request)
|
||||
@ -2673,9 +2855,8 @@ class PyExecutor:
|
||||
still_pending_canceled_ids.append(req_id)
|
||||
|
||||
# Clear list of requests marked for cancellation and add back those that failed to cancel.
|
||||
self.executor_request_queue.canceled_req_ids.clear()
|
||||
self.executor_request_queue.canceled_req_ids.extend(
|
||||
still_pending_canceled_ids)
|
||||
self.canceled_req_ids.clear()
|
||||
self.canceled_req_ids.extend(still_pending_canceled_ids)
|
||||
|
||||
@nvtx_range("_enqueue_responses")
|
||||
def _enqueue_responses(self, responses: Iterable[Tuple[int, LlmResponse]]):
|
||||
|
||||
tensorrt_llm/_torch/pyexecutor/request_utils.py (new file, 704 lines added)
@@ -0,0 +1,704 @@
|
||||
"""Utility functions for request processing."""
|
||||
|
||||
import heapq
|
||||
import os
|
||||
from collections import deque, namedtuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._utils import nvtx_range
|
||||
from tensorrt_llm.mapping import CpType
|
||||
|
||||
from ..distributed import Distributed
|
||||
from .hang_detector import HangDetector
|
||||
from .llm_request import ExecutorRequest, LlmRequest, executor_request_to_llm_request
|
||||
|
||||
# Type alias for request queue items (to avoid circular import)
|
||||
# The actual RequestQueueItem class is defined in executor_request_queue.py
|
||||
|
||||
|
||||
def get_num_child_requests(request: ExecutorRequest) -> int:
|
||||
"""Get the number of child requests for a given request.
|
||||
|
||||
Args:
|
||||
request: The executor request to check.
|
||||
|
||||
Returns:
|
||||
Number of child requests (0 if beam search, otherwise num_return_sequences - 1).
|
||||
"""
|
||||
sampling_config = request.sampling_config
|
||||
return 0 if sampling_config.beam_width > 1 else (sampling_config.num_return_sequences or 1) - 1
|
||||
|
||||
|
||||
def collect_py_objects_from_requests(
|
||||
requests: List, attribute_name: str
|
||||
) -> Optional[Tuple[str, Dict]]:
|
||||
"""Collect Python-only objects from requests.
|
||||
|
||||
Args:
|
||||
requests: List of RequestQueueItem objects.
|
||||
attribute_name: Name of the attribute to collect.
|
||||
|
||||
Returns:
|
||||
Tuple of (attribute_name, dict mapping request_id to object) or None if empty.
|
||||
"""
|
||||
req_id_to_obj = {}
|
||||
for item in requests:
|
||||
if not item.is_normal_request:
|
||||
continue
|
||||
if item.request:
|
||||
obj = getattr(item.request, attribute_name, None)
|
||||
if obj is not None:
|
||||
req_id_to_obj[item.id] = obj
|
||||
return None if not req_id_to_obj else (attribute_name, req_id_to_obj)
|
||||
|
||||
|
||||
def attach_py_objects_to_requests(requests: List, py_request_objects: Tuple) -> None:
|
||||
"""Attach Python-only objects to each request.
|
||||
|
||||
Args:
|
||||
requests: List of RequestQueueItem objects.
|
||||
py_request_objects: Tuple of (attribute_name, dict) pairs.
|
||||
"""
|
||||
for attr_name, req_obj_dict in py_request_objects:
|
||||
for item in requests:
|
||||
if item.request:
|
||||
py_obj = req_obj_dict.get(item.id)
|
||||
if py_obj is not None:
|
||||
setattr(item.request, attr_name, py_obj)
|
||||
|
||||
|
||||
def schedule_attention_dp_requests(
    new_requests: List[Any],
    all_ranks_num_active_requests: List[int],
    all_ranks_num_active_tokens: List[int],
    tp_size: int,
    max_num_active_requests: int,
) -> Tuple[Dict[int, List[Any]], int]:
    """Schedule attention DP requests across ranks.

    This function distributes requests across tensor parallel ranks for attention DP.
    It first tries to assign requests to their target dp_rank (if specified and has capacity),
    then balances the remaining requests across all ranks.

    Args:
        new_requests: List of RequestQueueItem to schedule.
        all_ranks_num_active_requests: Number of active requests per rank (will be modified).
        all_ranks_num_active_tokens: Number of active tokens per rank.
        tp_size: Number of tensor parallel ranks.
        max_num_active_requests: Maximum number of active requests per rank.

    Returns:
        Tuple of:
            - all_ranks_new_requests: Dict mapping rank to list of assigned requests.
            - expected_num_active_requests: Expected number of active requests per rank.
    """
    # Map from ranks to new requests
    all_ranks_new_requests = {tp_rank: [] for tp_rank in range(tp_size)}

    # Prioritize the requests that are not in relax mode
    def get_relax_value(req_item):
        scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
        if scheduling_params is None:
            return True
        return scheduling_params.attention_dp_relax

    new_requests = sorted(new_requests, key=get_relax_value)

    # Try to put the requests to the target dp rank until the max_num_active_requests is reached
    remaining_unscheduled = []
    for req_item in new_requests:
        scheduled = False
        scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
        if scheduling_params is not None:
            target_dp_rank = scheduling_params.attention_dp_rank
            if (
                target_dp_rank is not None
                and all_ranks_num_active_requests[target_dp_rank] < max_num_active_requests
            ):
                all_ranks_num_active_requests[target_dp_rank] += 1
                scheduled = True
                all_ranks_new_requests[target_dp_rank].append(req_item)

        if not scheduled:
            remaining_unscheduled.append(req_item)

    # Balance the remaining unscheduled requests across ranks
    num_new_requests_all_ranks = len(remaining_unscheduled)
    total_num_active_requests = sum(all_ranks_num_active_requests)
    expected_num_active_requests = max(
        (total_num_active_requests + num_new_requests_all_ranks + tp_size - 1) // tp_size,
        max(all_ranks_num_active_requests),
    )

    all_ranks_new_requests = balance_requests_across_ranks(
        remaining_unscheduled,
        all_ranks_new_requests,
        all_ranks_num_active_requests,
        all_ranks_num_active_tokens,
        expected_num_active_requests,
    )

    return all_ranks_new_requests, expected_num_active_requests

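# --- Editor's illustrative sketch (not part of this commit) ----------------------
# Two-rank attention-DP scheduling with hypothetical stand-in items. None of the
# requests pins a target dp_rank, so everything flows into the balancing step and the
# rank with fewer active tokens (rank 1) absorbs the longer prompts first.
from types import SimpleNamespace
from tensorrt_llm._torch.pyexecutor.request_utils import schedule_attention_dp_requests

def _item(req_id, prompt_len):
    return SimpleNamespace(
        id=req_id,
        request=SimpleNamespace(input_token_ids=list(range(prompt_len)),
                                py_scheduling_params=None))

per_rank, expected = schedule_attention_dp_requests(
    new_requests=[_item(0, 128), _item(1, 32), _item(2, 64)],
    all_ranks_num_active_requests=[1, 0],
    all_ranks_num_active_tokens=[256, 0],
    tp_size=2,
    max_num_active_requests=8,
)
assert expected == 2
assert [i.id for i in per_rank[1]] == [0, 2] and [i.id for i in per_rank[0]] == [1]
# ----------------------------------------------------------------------------------
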
def balance_requests_across_ranks(
    new_requests: List,
    all_ranks_new_requests: Dict[int, List],
    all_ranks_num_active_requests: List[int],
    all_ranks_num_active_tokens: List[int],
    expected_num_active_requests: int,
) -> Dict[int, List]:
    """Balance requests across ranks for attention DP.

    Uses a heap-based algorithm to distribute requests evenly across ranks,
    prioritizing ranks with fewer tokens for better load balancing.

    Args:
        new_requests: List of new requests to distribute.
        all_ranks_new_requests: Dict mapping rank to list of already assigned requests.
        all_ranks_num_active_requests: Number of active requests per rank.
        all_ranks_num_active_tokens: Number of active tokens per rank.
        expected_num_active_requests: Target number of active requests per rank.

    Returns:
        Updated all_ranks_new_requests dict with new requests distributed.
    """
    if new_requests:
        # Balance context tokens across ranks using heap
        HeapVal = namedtuple("HeapVal", ["num_tokens", "num_requests", "rank", "request_list"])

        all_ranks_new_requests_heap = [
            HeapVal(all_ranks_num_active_tokens[tp_rank], val, tp_rank, [])
            for tp_rank, val in enumerate(all_ranks_num_active_requests)
        ]

        all_ranks_new_requests_heap = [
            val
            for val in all_ranks_new_requests_heap
            if val.num_requests < expected_num_active_requests
        ]

        all_ranks_new_scheduled_requests = {
            val.rank: val.request_list for val in all_ranks_new_requests_heap
        }

        heapq.heapify(all_ranks_new_requests_heap)

        # Sort by token count (descending) for better load balancing
        new_requests = sorted(
            new_requests,
            key=lambda x: len(getattr(x.request, "input_token_ids", [])) if x.request else 0,
            reverse=True,
        )

        # Distribute requests across ranks
        for req_item in new_requests:
            val = heapq.heappop(all_ranks_new_requests_heap)
            token_count = (
                len(getattr(req_item.request, "input_token_ids", [])) if req_item.request else 0
            )
            # Update the heap value with the new request
            val = val._replace(
                num_tokens=val.num_tokens + token_count,
                num_requests=val.num_requests + 1,
            )

            val.request_list.append(req_item)
            # If rank still has room for new requests, push back into heap
            if val.num_requests < expected_num_active_requests:
                heapq.heappush(all_ranks_new_requests_heap, val)

        # Extend all_ranks_new_requests with the new requests that have been scheduled
        for rank, reqs in all_ranks_new_scheduled_requests.items():
            all_ranks_new_requests[rank].extend(reqs)

    return all_ranks_new_requests

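# --- Editor's illustrative sketch (not part of this commit) ----------------------
# Direct use of the heap-based balancer with hypothetical stand-in items. With a cap
# of two requests per rank, the 100-token prompt ends up alone on one rank while the
# two short prompts share the other.
from types import SimpleNamespace
from tensorrt_llm._torch.pyexecutor.request_utils import balance_requests_across_ranks

reqs = [SimpleNamespace(request=SimpleNamespace(input_token_ids=[0] * n)) for n in (100, 10, 10)]
balanced = balance_requests_across_ranks(
    new_requests=reqs,
    all_ranks_new_requests={0: [], 1: []},
    all_ranks_num_active_requests=[0, 0],
    all_ranks_num_active_tokens=[0, 0],
    expected_num_active_requests=2,
)
assert len(balanced[0]) + len(balanced[1]) == 3
assert {len(balanced[0]), len(balanced[1])} == {1, 2}
# ----------------------------------------------------------------------------------
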
def can_process_attention_dp_request(
    req_item, all_ranks_num_active_requests: List[int], max_num_active_requests: int
) -> bool:
    """Check if a request can be processed immediately for attention DP.

    Args:
        req_item: The request queue item to check.
        all_ranks_num_active_requests: Number of active requests for each rank.
        max_num_active_requests: Maximum number of active requests per rank.

    Returns:
        True if the request can be processed, False otherwise.
    """
    scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
    if scheduling_params is None:
        return True

    target_dp_rank = scheduling_params.attention_dp_rank
    if target_dp_rank is None or scheduling_params.attention_dp_relax:
        return True

    if all_ranks_num_active_requests[target_dp_rank] < max_num_active_requests:
        all_ranks_num_active_requests[target_dp_rank] += 1
        return True

    return False

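# --- Editor's illustrative sketch (not part of this commit) ----------------------
# The helper reserves a slot only for strict (non-relaxed) requests that pin a dp_rank.
# The items below are hypothetical SimpleNamespace stand-ins.
from types import SimpleNamespace
from tensorrt_llm._torch.pyexecutor.request_utils import can_process_attention_dp_request

active = [2, 0]  # active request count on ranks 0 and 1
strict = SimpleNamespace(request=SimpleNamespace(
    py_scheduling_params=SimpleNamespace(attention_dp_rank=0, attention_dp_relax=False)))
relaxed = SimpleNamespace(request=SimpleNamespace(py_scheduling_params=None))

assert can_process_attention_dp_request(relaxed, active, max_num_active_requests=2)
assert not can_process_attention_dp_request(strict, active, max_num_active_requests=2)
assert active == [2, 0]  # nothing was reserved for the rejected strict request
# ----------------------------------------------------------------------------------
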
def get_from_waiting_queue(
    waiting_queue: deque,
    max_req_count: int,
    enable_attention_dp: bool,
    max_num_active_requests: int,
    all_ranks_num_active_requests: Optional[List[int]] = None,
) -> List:
    """Get requests from the waiting queue.

    Args:
        waiting_queue: The queue to pop items from.
        max_req_count: Maximum items to retrieve. Returns empty list if <=0.
        enable_attention_dp: Whether to enable attention DP scheduling.
        max_num_active_requests: Maximum number of active requests per rank.
        all_ranks_num_active_requests: Number of active requests for each rank.

    Returns:
        List of requests that can be processed.
    """
    if max_req_count <= 0:
        return []

    req_count = 0
    items = []
    pending_requests = []

    # Track the request with strict requirements
    scheduling_all_ranks_num_active_requests = (
        all_ranks_num_active_requests.copy() if enable_attention_dp else None
    )

    while req_count < max_req_count and waiting_queue:
        req_item = waiting_queue[0]
        num_children = len(req_item.child_req_ids) if req_item.child_req_ids else 0
        if (req_count + 1 + num_children) > max_req_count:
            break
        req_item = waiting_queue.popleft()

        can_process = (
            can_process_attention_dp_request(
                req_item, scheduling_all_ranks_num_active_requests, max_num_active_requests
            )
            if enable_attention_dp
            else True
        )

        if can_process:
            items.append(req_item)
            req_count += 1 + num_children
        else:
            pending_requests.append(req_item)

    # Put the pending requests back to the waiting queue
    # All ranks should have the same waiting queue
    waiting_queue.extendleft(reversed(pending_requests))

    return items

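# --- Editor's illustrative sketch (not part of this commit) ----------------------
# Popping from a waiting queue of hypothetical stand-in items. Each item exposes the
# attributes the helper touches (`child_req_ids`, `request`); the budget counts a
# parent together with its children, so the two-child request below does not fit once
# the first item has been taken.
from collections import deque
from types import SimpleNamespace
from tensorrt_llm._torch.pyexecutor.request_utils import get_from_waiting_queue

queue = deque([
    SimpleNamespace(id=0, child_req_ids=None, request=SimpleNamespace()),
    SimpleNamespace(id=1, child_req_ids=[100, 101], request=SimpleNamespace()),
])
taken = get_from_waiting_queue(
    queue, max_req_count=2, enable_attention_dp=False, max_num_active_requests=8)
assert [i.id for i in taken] == [0]  # item 1 needs 3 slots (itself + 2 children)
assert [i.id for i in queue] == [1]  # it stays at the front of the waiting queue
# ----------------------------------------------------------------------------------
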
def partition_context_for_star_attention(
    ctx_ids_list: List[int], cp_rank: int, cp_size: int, block_size: int, anchor_block_size: int
) -> Tuple[List[List[int]], List[List[int]], int]:
    """Partition context for Star Attention CP.

    Args:
        ctx_ids_list: List of context token IDs.
        cp_rank: Current CP rank.
        cp_size: Total number of CP ranks.
        block_size: Size of each block.
        anchor_block_size: Size of anchor block.

    Returns:
        Tuple of (ctx_blocks, position_blocks, padding).
    """
    ctx_ids = torch.tensor(ctx_ids_list).unsqueeze(0)
    ctx_len = ctx_ids.shape[-1]

    if block_size is None:
        block_size = ctx_len // cp_size
    if anchor_block_size is None:
        anchor_block_size = block_size

    assert anchor_block_size <= block_size, (
        f"cp_anchor_size {anchor_block_size} should be smaller than block_size {block_size}"
    )

    padding = 0
    if ctx_len % block_size != 0:
        padding = block_size - (ctx_len % block_size)
        assert padding <= ctx_len, "block size is too large for context, please set it smaller"
        ctx_ids = torch.cat((ctx_ids, torch.zeros_like(ctx_ids)[:, :padding]), dim=-1)
    position_ids = torch.arange(0, ctx_ids.shape[-1]).unsqueeze(0)

    ctx_ids_blocks = torch.tensor_split(torch.stack(ctx_ids.split(block_size, dim=-1)), cp_size)
    position_ids_blocks = torch.tensor_split(
        torch.stack(position_ids.split(block_size, dim=-1)), cp_size
    )

    if cp_rank != 0:
        ctx_blocks = [ctx_ids_blocks[0][0].tolist()[0][:anchor_block_size]]
        position_blocks = [position_ids_blocks[0][0].tolist()[0][:anchor_block_size]]
    else:
        ctx_blocks, position_blocks = [], []

    for idx in range(len(ctx_ids_blocks[cp_rank])):
        ctx_block = ctx_ids_blocks[cp_rank][idx]
        position_block = position_ids_blocks[cp_rank][idx]
        ctx_blocks.append(ctx_block.tolist()[0])
        position_blocks.append(position_block.tolist()[0])

    return ctx_blocks, position_blocks, padding

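# --- Editor's illustrative sketch (not part of this commit) ----------------------
# Partitioning 8 context tokens across 2 CP ranks with block_size=4 and a 2-token
# anchor: rank 0 holds only its own block, while rank 1 additionally prepends the
# anchor slice of rank 0's first block.
from tensorrt_llm._torch.pyexecutor.request_utils import partition_context_for_star_attention

tokens = list(range(8))
blocks_r0, pos_r0, pad0 = partition_context_for_star_attention(tokens, 0, 2, 4, 2)
blocks_r1, pos_r1, pad1 = partition_context_for_star_attention(tokens, 1, 2, 4, 2)
assert pad0 == pad1 == 0
assert blocks_r0 == [[0, 1, 2, 3]]          # rank 0: its own block, no anchor
assert blocks_r1 == [[0, 1], [4, 5, 6, 7]]  # rank 1: anchor (first 2 tokens) + its block
# ----------------------------------------------------------------------------------
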
def partition_context_for_helix(
    input_token_ids: List[int], cp_rank: int, cp_size: int, tokens_per_block: int
) -> Tuple[List[int], List[int], int, int]:
    """Partition context for Helix CP.

    Args:
        input_token_ids: List of input token IDs.
        cp_rank: Current CP rank.
        cp_size: Total number of CP ranks.
        tokens_per_block: Number of tokens per block.

    Returns:
        Tuple of (input_ids_this_rank, position_ids_this_rank, input_len, padding_len).

    Raises:
        ValueError: If there aren't enough tokens for at least one block per CP rank.
    """
    all_input_ids = torch.tensor(input_token_ids, dtype=torch.int64).unsqueeze(0)
    input_len = all_input_ids.shape[-1]

    num_total_blocks = (input_len + tokens_per_block - 1) // tokens_per_block
    if num_total_blocks < cp_size:
        raise ValueError(
            f"There aren't enough tokens to get at least one block per CP rank. "
            f"num_total_blocks {num_total_blocks} < num_cp_ranks {cp_size}. "
            f"Please use smaller tokens_per_block for KV cache or reduce the number of CP ranks."
        )

    # Padding to ensure torch.stack used with torch.tensor_split works properly.
    padding_len = 0
    if input_len % tokens_per_block != 0:
        padding_len = tokens_per_block - (input_len % tokens_per_block)
        padding_ids = torch.zeros([1, padding_len], dtype=torch.int64)
        all_input_ids = torch.cat((all_input_ids, padding_ids), dim=-1)
    all_position_ids = torch.arange(0, input_len + padding_len, dtype=torch.int64).unsqueeze(0)

    input_id_blocks_per_rank = torch.tensor_split(
        torch.stack(all_input_ids.split(tokens_per_block, dim=-1)), cp_size
    )
    position_id_blocks_per_rank = torch.tensor_split(
        torch.stack(all_position_ids.split(tokens_per_block, dim=-1)), cp_size
    )

    # Get the input_ids and position_ids for this rank.
    input_ids_this_rank = input_id_blocks_per_rank[cp_rank].flatten().tolist()
    position_ids_this_rank = position_id_blocks_per_rank[cp_rank].flatten().tolist()

    # Undo the padding. Only last rank's last block will be padded right now
    # given contiguous block assignment.
    if cp_rank == cp_size - 1 and padding_len > 0:
        input_ids_this_rank = input_ids_this_rank[:-padding_len]
        position_ids_this_rank = position_ids_this_rank[:-padding_len]

    return input_ids_this_rank, position_ids_this_rank, input_len, padding_len

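# --- Editor's illustrative sketch (not part of this commit) ----------------------
# Splitting 10 tokens across 2 CP ranks with tokens_per_block=4: blocks are assigned
# contiguously, and the last rank receives the padded tail block, then strips the
# padding again.
from tensorrt_llm._torch.pyexecutor.request_utils import partition_context_for_helix

tokens = list(range(10))
ids_r0, pos_r0, in_len0, pad0 = partition_context_for_helix(tokens, 0, 2, 4)
ids_r1, pos_r1, in_len1, pad1 = partition_context_for_helix(tokens, 1, 2, 4)
assert (in_len0, pad0) == (10, 2) and ids_r0 == [0, 1, 2, 3, 4, 5, 6, 7]
assert ids_r1 == [8, 9] and pos_r1 == [8, 9]  # padding removed on the last rank
# ----------------------------------------------------------------------------------
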
def merge_requests_to_llm_requests(
    new_requests: List, exclude_last_generation_logits: bool
) -> List[LlmRequest]:
    """Merge RequestQueueItems to LlmRequests (basic case without CP).

    Args:
        new_requests: List of RequestQueueItem objects.
        exclude_last_generation_logits: Whether to exclude last generation logits.

    Returns:
        List of LlmRequest objects including child requests.
    """
    req_with_children = []
    for req_item in new_requests:
        req = executor_request_to_llm_request(
            req_item.id, req_item.request, req_item.child_req_ids, exclude_last_generation_logits
        )
        req_with_children.append(req)
        if req.child_requests:
            req_with_children.extend(req.child_requests)
    return req_with_children

def merge_helix_requests(
    new_requests: List,
    cp_rank: int,
    cp_size: int,
    tokens_per_block: int,
    exclude_last_generation_logits: bool,
) -> List[LlmRequest]:
    """Merge requests for Helix CP.

    Note: Helix parallelism is a decode-only feature run with disaggregated serving.
    This function gets called on the gen server during initialization of a new request.

    Args:
        new_requests: List of RequestQueueItem objects.
        cp_rank: Current CP rank.
        cp_size: Total number of CP ranks.
        tokens_per_block: Number of tokens per block.
        exclude_last_generation_logits: Whether to exclude last generation logits.

    Returns:
        List of LlmRequest objects including child requests.
    """
    req_with_children = []

    for req_item in new_requests:
        input_ids_this_rank, position_ids_this_rank, input_len, _ = partition_context_for_helix(
            req_item.request.input_token_ids, cp_rank, cp_size, tokens_per_block
        )

        req = executor_request_to_llm_request(
            req_id=req_item.id,
            executor_request=req_item.request,
            child_req_ids=req_item.child_req_ids,
            exclude_last_generation_logits=exclude_last_generation_logits,
            input_token_ids=input_ids_this_rank,
            position_ids=position_ids_this_rank,
        )
        req.total_input_len_cp = input_len
        req.seqlen_this_rank_cp = len(input_ids_this_rank)
        req_with_children.append(req)
        if req.child_requests:
            req_with_children.extend(req.child_requests)

    return req_with_children

def merge_star_attention_requests(
    new_requests: List,
    cp_rank: int,
    cp_size: int,
    cp_config: dict,
    exclude_last_generation_logits: bool,
) -> List[LlmRequest]:
    """Merge requests for Star Attention CP.

    Args:
        new_requests: List of RequestQueueItem objects.
        cp_rank: Current CP rank.
        cp_size: Total number of CP ranks.
        cp_config: CP configuration dict containing 'block_size' and 'cp_anchor_size'.
        exclude_last_generation_logits: Whether to exclude last generation logits.

    Returns:
        List of LlmRequest objects.
    """
    result = []
    block_size = cp_config["block_size"]
    anchor_block_size = cp_config["cp_anchor_size"]

    for req_item in new_requests:
        req_id, exe_req, query_token_ids = req_item.id, req_item.request, req_item.query
        ctx_len0 = len(exe_req.input_token_ids)

        ctx_blocks, position_blocks, last_block_padding_num = partition_context_for_star_attention(
            exe_req.input_token_ids, cp_rank, cp_size, block_size, anchor_block_size
        )

        if cp_rank == cp_size - 1 and last_block_padding_num > 0:
            ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
            position_blocks[-1] = position_blocks[-1][:-last_block_padding_num]

        # if has query
        if query_token_ids:
            ctx_blocks.append(query_token_ids)
            position_blocks.append([i for i in range(ctx_len0, ctx_len0 + len(query_token_ids))])

        # insert the dummy block to align the number of ctx iterations of each rank
        total_blocks = (ctx_len0 + block_size - 1) // block_size
        num_blocks_per_rank = (total_blocks + cp_size - 1) // cp_size + 1  # 1 for query block
        if len(ctx_blocks) == num_blocks_per_rank:
            ctx_blocks.insert(1, [])
            position_blocks.insert(1, [])
        elif len(ctx_blocks) == num_blocks_per_rank + 1:
            # anchor + ctx_blocks + qry_block
            pass
        else:
            raise ValueError(
                f"Invalid context partition: rank = {cp_rank}, "
                f"len(ctx_blocks) = {len(ctx_blocks)}, "
                f"num_blocks_per_rank = {num_blocks_per_rank}"
            )

        # fake data for scheduler
        ctx_blocks_list = [0] * (block_size + anchor_block_size)

        req = executor_request_to_llm_request(
            req_id, exe_req, exclude_last_generation_logits, ctx_blocks_list
        )
        req.gen_iters = 0
        req.ctx_iters = 0
        req.ctx_blocks = ctx_blocks
        req.ctx_position_blocks = position_blocks
        req.query_id = query_token_ids

        result.append(req)

    return result

@nvtx_range("merge_requests")
|
||||
def merge_requests(
|
||||
new_requests: List,
|
||||
cp_config: dict,
|
||||
cp_rank: int,
|
||||
cp_size: int,
|
||||
exclude_last_generation_logits: bool,
|
||||
) -> List[LlmRequest]:
|
||||
"""Merge RequestQueueItems to LlmRequests based on CP configuration.
|
||||
|
||||
This is a router function that dispatches to the appropriate merge function
|
||||
based on the CP (Context Parallelism) configuration.
|
||||
|
||||
Args:
|
||||
new_requests: List of RequestQueueItem objects.
|
||||
cp_config: CP configuration dict. May contain 'cp_type', 'tokens_per_block',
|
||||
'block_size', 'cp_anchor_size'.
|
||||
cp_rank: Current CP rank.
|
||||
cp_size: Total number of CP ranks.
|
||||
exclude_last_generation_logits: Whether to exclude last generation logits.
|
||||
|
||||
Returns:
|
||||
List of LlmRequest objects.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If cp_type is not supported.
|
||||
"""
|
||||
if "cp_type" in cp_config:
|
||||
cp_type = cp_config["cp_type"]
|
||||
if cp_type == CpType.STAR:
|
||||
return merge_star_attention_requests(
|
||||
new_requests,
|
||||
cp_rank=cp_rank,
|
||||
cp_size=cp_size,
|
||||
cp_config=cp_config,
|
||||
exclude_last_generation_logits=exclude_last_generation_logits,
|
||||
)
|
||||
elif cp_type == CpType.HELIX:
|
||||
return merge_helix_requests(
|
||||
new_requests,
|
||||
cp_rank=cp_rank,
|
||||
cp_size=cp_size,
|
||||
tokens_per_block=cp_config["tokens_per_block"],
|
||||
exclude_last_generation_logits=exclude_last_generation_logits,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported cp type {cp_type.name}.")
|
||||
|
||||
return merge_requests_to_llm_requests(new_requests, exclude_last_generation_logits)
|
||||
|
||||
|
||||
class RequestBroadcaster:
    """Broadcasts requests across distributed ranks (TP, PP, CP)."""

    def __init__(self, dist: Distributed, hang_detector: HangDetector):
        self.dist = dist
        self.hang_detector = hang_detector
        self.send_requests_handler = None

    def broadcast(self, new_requests: List) -> Tuple[List, Optional[Tuple]]:
        """Broadcast requests and Python objects across ranks."""
        if self.dist.rank == 0:
            py_request_objects = self._collect_py_objects(new_requests)
        else:
            py_request_objects = None

        if self.dist.rank == 0:
            # Preserve original `new_requests` on rank 0
            _ = self._broadcast_requests(new_requests, py_request_objects)
        else:
            with self.hang_detector.pause():
                new_requests, py_request_objects = self._broadcast_requests(
                    new_requests, py_request_objects
                )

        return new_requests, py_request_objects

    def _collect_py_objects(self, new_requests: List) -> Tuple:
        """Collect Python-only objects from requests."""
        py_logits_post_processors = collect_py_objects_from_requests(
            new_requests, "py_logits_post_processors"
        )
        py_multimodal_data = collect_py_objects_from_requests(new_requests, "py_multimodal_data")
        py_scheduling_params = collect_py_objects_from_requests(
            new_requests, "py_scheduling_params"
        )
        py_num_logprobs = collect_py_objects_from_requests(new_requests, "py_num_logprobs")
        py_disaggregated_params = collect_py_objects_from_requests(
            new_requests, "py_disaggregated_params"
        )

        return tuple(
            filter(
                None,
                [
                    py_logits_post_processors,
                    py_multimodal_data,
                    py_scheduling_params,
                    py_num_logprobs,
                    py_disaggregated_params,
                ],
            )
        )

    @nvtx_range("broadcast_requests")
    def _broadcast_requests(
        self, new_requests: List, py_request_objects
    ) -> Tuple[List, Optional[Dict]]:
        """Broadcast requests across pipeline stages."""
        payloads = (new_requests, py_request_objects)

        if not self.dist.has_pp:
            return self.dist.broadcast(payloads, root=0)

        # Broadcast within first PP stage before send/recv chain to other PP stages.
        # This needs to cover both TP and CP ranks within the first PP stage.
        if self.dist.is_first_pp_rank:
            payloads = self.dist.tp_cp_broadcast(payloads, root=0)

        # Tag for communication
        tag = self.dist.pp_size  # Use pp_size as tag to avoid conflicts

        # Send payloads
        if not self.dist.is_first_pp_rank:
            with nvtx_range("recv_requests_from_prev_pp"):
                payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag)

        # isend new requests may cause deadlock, when CUDA_LAUNCH_BLOCKING=1
        # or PP microbatches can't overlap, the deadlock will happen:
        # 1. rank1 will wait on nccl.send(rank2), without invoking mpi.wait(isend-handle)
        # 2. rank2 will wait on mpi.recv(rank1) but never receive the new requests.
        # 3. rank1 will hang on nccl.send because rank2 will never reach nccl.recv(rank1).
        pp_send_func = (
            self.dist.isend_object
            if os.environ.get("TRTLLM_PP_REQ_SEND_ASYNC", "0") == "1"
            else self.dist.send_object
        )

        if not self.dist.is_last_pp_rank:
            if self.send_requests_handler is not None:
                with nvtx_range("wait_prev_send_requests_handler"):
                    self.send_requests_handler.wait()
            with nvtx_range("send_requests_to_next_pp"):
                self.send_requests_handler = pp_send_func(payloads, self.dist.next_pp_rank, tag)

        return payloads
File diff suppressed because it is too large
182 tests/unittest/_torch/executor/test_py_executor.py Normal file
@ -0,0 +1,182 @@
"""Tests for PyExecutor request handling functionality.

This module tests the request handling logic that was moved from ExecutorRequestQueue
to PyExecutor, including:
- _handle_special_queue_items method
- canceled_req_ids management
- waiting_queue management
- is_shutdown state management
- expected_num_active_requests tracking
"""

from collections import deque
from unittest.mock import Mock

import pytest

from tensorrt_llm._torch.pyexecutor.executor_request_queue import (
    SHUTDOWN_REQUEST_ID,
    RequestQueueItem,
)


class MockPyExecutor:
    """A mock PyExecutor class for testing request handling logic.

    This mock contains only the attributes and methods needed to test
    the _handle_special_queue_items functionality.
    """

    def __init__(self, dist):
        self.dist = dist
        self.canceled_req_ids = []
        self.control_requests = []
        self.request_accumulated = []
        self.is_shutdown = False
        self.expected_num_active_requests = 0
        self.new_active_requests_queue_latency_ms = 0.0
        self.waiting_queue = deque()

    def _handle_special_queue_items(self, new_requests):
        """Handle special signals.

        This method mirrors PyExecutor._handle_special_queue_items.
        """
        accepted_new_requests = []
        for idx, req_item in enumerate(new_requests):
            if req_item.is_shutdown_request:
                self.is_shutdown = True
                break
            elif req_item.is_canceled_request:
                self.canceled_req_ids.append(req_item.id)
            elif req_item.is_control_request:
                self.control_requests.append(req_item)
                if self.dist.rank == 0:
                    self.request_accumulated.extend(new_requests[idx + 1 :])
                break
            else:
                accepted_new_requests.append(req_item)

        return accepted_new_requests

    def update_waiting_queue(self):
        """Update waiting queue to remove canceled requests.

        This method mirrors PyExecutor.update_waiting_queue.
        """
        if self.canceled_req_ids:
            canceled_set = set(self.canceled_req_ids)
            self.waiting_queue = deque(
                item for item in self.waiting_queue if item.id not in canceled_set
            )

    def clear_canceled_req_ids(self):
        """Clear the list of canceled request IDs."""
        self.canceled_req_ids.clear()

    def get_canceled_req_ids(self):
        """Get the list of canceled request IDs."""
        return self.canceled_req_ids

    def get_canceled_req_ids_size(self):
        """Get the number of canceled request IDs."""
        return len(self.canceled_req_ids)

    def get_expected_num_active_requests(self):
        """Get the expected number of active requests."""
        return self.expected_num_active_requests

    def get_waiting_queue_size(self):
        """Get the size of the waiting queue."""
        return len(self.waiting_queue)

    def _get_new_active_requests_queue_latency(self):
        """Get the queue latency for new active requests."""
        return self.new_active_requests_queue_latency_ms

@pytest.fixture
def mock_dist():
    """Create a mock Distributed instance for testing."""
    mock_dist = Mock()
    mock_dist.rank = 0
    mock_dist.tp_size = 1
    return mock_dist


@pytest.fixture
def mock_executor(mock_dist):
    """Create a MockPyExecutor instance for testing."""
    return MockPyExecutor(dist=mock_dist)


def test_handle_special_queue_items(mock_executor):
    """Test special queue item handling."""
    # Create a mock request
    mock_request = Mock()
    if hasattr(mock_request, "sampling_config"):
        delattr(mock_request, "sampling_config")

    normal_req = RequestQueueItem(1, mock_request)
    cancel_req = RequestQueueItem(2, is_canceled_request=True)
    shutdown_req = RequestQueueItem(SHUTDOWN_REQUEST_ID)

    requests = [normal_req, cancel_req, shutdown_req]

    valid_requests = mock_executor._handle_special_queue_items(requests)

    assert len(valid_requests) == 1
    assert valid_requests[0] == normal_req
    assert mock_executor.is_shutdown
    assert 2 in mock_executor.canceled_req_ids


def test_clear_canceled_req_ids(mock_executor):
    """Test clearing canceled request IDs."""
    mock_executor.canceled_req_ids = [1, 2, 3]
    assert len(mock_executor.canceled_req_ids) == 3

    mock_executor.clear_canceled_req_ids()

    assert len(mock_executor.canceled_req_ids) == 0


def test_update_waiting_queue(mock_executor):
    """Test updating waiting queue to remove canceled requests."""
    items = [
        RequestQueueItem(1, Mock()),
        RequestQueueItem(2, Mock()),
        RequestQueueItem(3, Mock()),
    ]
    mock_executor.waiting_queue.extend(items)
    mock_executor.canceled_req_ids = [2]

    mock_executor.update_waiting_queue()

    assert len(mock_executor.waiting_queue) == 2
    remaining_ids = [item.id for item in mock_executor.waiting_queue]
    assert 1 in remaining_ids
    assert 3 in remaining_ids
    assert 2 not in remaining_ids


def test_getter_methods(mock_executor):
    """Test various getter methods."""
    # Test initial values
    assert mock_executor._get_new_active_requests_queue_latency() == 0
    assert mock_executor.get_expected_num_active_requests() == 0
    assert mock_executor.get_canceled_req_ids_size() == 0
    assert mock_executor.get_canceled_req_ids() == []
    assert mock_executor.get_waiting_queue_size() == 0

    # Add some data and test
    mock_executor.canceled_req_ids = [3, 4]
    mock_executor.expected_num_active_requests = 5
    mock_executor.new_active_requests_queue_latency_ms = 10.5
    mock_executor.waiting_queue.append(RequestQueueItem(1, Mock()))

    assert mock_executor.get_canceled_req_ids_size() == 2
    assert mock_executor.get_canceled_req_ids() == [3, 4]
    assert mock_executor.get_expected_num_active_requests() == 5
    assert mock_executor._get_new_active_requests_queue_latency() == 10.5
    assert mock_executor.get_waiting_queue_size() == 1
1103 tests/unittest/_torch/executor/test_request_utils.py Normal file
File diff suppressed because it is too large