[TRTLLM-10666][chore] Refactor request fetching logic for better separation of concerns (#10988)

Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
Signed-off-by: Lance Liao <108499334+lancelly@users.noreply.github.com>
Signed-off-by: Liao Lanyu <108499334+lancelly@users.noreply.github.com>
Co-authored-by: Lanyu Liao <lancelly@users.noreply.github.com>
Liao Lanyu 2026-02-02 10:36:08 +08:00 committed by GitHub
parent b00e8338ec
commit fef0e4b17d
6 changed files with 2301 additions and 1877 deletions

File: executor_request_queue.py

@@ -1,24 +1,16 @@
import dataclasses
import datetime
import heapq
import os
import queue
import threading
import time
from collections import deque, namedtuple
from itertools import repeat
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
import torch
from tensorrt_llm._utils import mpi_disabled, nvtx_range
from tensorrt_llm.llmapi.disagg_utils import get_local_request_id
from tensorrt_llm.mapping import CpType
from ..distributed import Distributed
from .hang_detector import HangDetector
from .llm_request import (ExecutorRequest, LlmRequest,
executor_request_to_llm_request)
from .llm_request import ExecutorRequest
from .request_utils import get_num_child_requests
SHUTDOWN_REQUEST_ID = -1
CONTROL_REQUEST_ID = -2
@@ -48,165 +40,24 @@ class RequestQueueItem:
class ExecutorRequestQueue:
"""Handles fetching and processing of new requests from the request queue."""
"""Handles basic queue operations for executor requests."""
def __init__(
self,
dist: Distributed,
enable_attention_dp: bool,
max_batch_size: int,
max_beam_width: int,
max_num_active_requests: int,
enable_iter_perf_stats: bool,
batch_wait_timeout_ms: float,
hang_detector: Optional[HangDetector] = None,
):
self.dist = dist
self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
self.waiting_queue: deque[RequestQueueItem] = deque()
self.canceled_req_ids = []
self.enable_attention_dp = enable_attention_dp
self.max_batch_size = max_batch_size
self.max_beam_width = max_beam_width
self.max_num_active_requests = max_num_active_requests
self.enqueue_lock = threading.Lock()
self.next_request_id = max_batch_size
self.enable_iter_perf_stats = enable_iter_perf_stats
self.start_times = {}
self.active = True
self.batch_wait_timeout_ms = batch_wait_timeout_ms
self.send_requests_handler = None
self.hang_detector = hang_detector or HangDetector()
# State tracking
self.num_fetch_requests = 0
self.num_fetch_requests_cur_rank = 0
self.expected_num_active_requests = 0
self.new_active_requests_queue_latency_ms = 0
self.is_shutdown = False
self.should_exclude_last_generation_logits = False
self.control_requests: List[RequestQueueItem] = []
self.request_accumulated: List[RequestQueueItem] = []
self._disable_mpi = mpi_disabled()
def _get_from_request_queue(
self,
timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]:
items = []
timeout_secs = timeout.total_seconds() if timeout is not None else None
try:
if self.request_queue.empty() and (timeout_secs is None
or timeout_secs > 0):
# if queue is empty and want to wait, wait
items.append(self.request_queue.get(timeout=timeout_secs))
else:
# if not empty or don't want to wait, just return all items in queue
while True:
queue_item = self.request_queue.get_nowait()
items.append(queue_item)
except queue.Empty:
pass
if self.batch_wait_timeout_ms == 0:
return items
if len(items) >= self.max_batch_size:
return items
deadline = time.monotonic() + self.batch_wait_timeout_ms / 1000.0
while len(items) < self.max_batch_size:
remaining_timeout = deadline - time.monotonic()
if remaining_timeout <= 0:
break
try:
item = self.request_queue.get(timeout=remaining_timeout)
items.append(item)
except queue.Empty:
break
return items
@staticmethod
def _get_num_child_requests(request: ExecutorRequest) -> int:
sampling_config = request.sampling_config
return 0 if sampling_config.beam_width > 1 else (
sampling_config.num_return_sequences or 1) - 1
def _get_from_waiting_queue(
self,
waiting_queue: deque[RequestQueueItem],
max_req_count: int,
enable_attention_dp: bool,
all_ranks_num_active_requests: Optional[List[int]] = None,
) -> List[RequestQueueItem]:
"""
Args:
waiting_queue: The queue to pop items from.
max_req_count: Maximum items to retrieve. Returns empty list if <=0.
enable_attention_dp: Whether to enable attention DP scheduling.
all_ranks_num_active_requests: Number of active requests for each rank.
Returns:
List of requests that can be processed.
"""
if max_req_count <= 0:
return []
req_count = 0
items = []
pending_requests = []
# Track the request with strict requirements
scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
) if enable_attention_dp else None
while req_count < max_req_count and waiting_queue:
req_item = waiting_queue[0]
num_children = len(
req_item.child_req_ids) if req_item.child_req_ids else 0
if (req_count + 1 + num_children) > max_req_count:
break
req_item = waiting_queue.popleft()
can_process = self._can_process_attention_dp_request(
req_item, scheduling_all_ranks_num_active_requests
) if enable_attention_dp else True
if can_process:
items.append(req_item)
req_count += 1 + num_children
else:
pending_requests.append(req_item)
# Put the pending requests back to the waiting queue
# All ranks should have the same waiting queue
waiting_queue.extendleft(reversed(pending_requests))
return items
def _can_process_attention_dp_request(
self, req_item: RequestQueueItem,
all_ranks_num_active_requests: List[int]) -> bool:
"""Return True if the request can be processed immediately, else False."""
scheduling_params = getattr(req_item.request, 'py_scheduling_params',
None)
if scheduling_params is None:
return True
target_dp_rank = scheduling_params.attention_dp_rank
if target_dp_rank is None or scheduling_params.attention_dp_relax:
return True
if all_ranks_num_active_requests[
target_dp_rank] < self.max_num_active_requests:
all_ranks_num_active_requests[target_dp_rank] += 1
return True
return False
def _get_request_id(self, request: Optional[ExecutorRequest] = None):
# if request has a disagg_request_id, use it as request id so that
@@ -223,7 +74,7 @@ class ExecutorRequestQueue:
self, request: ExecutorRequest) -> List[int] | None:
""" Generate child request IDs if needed. """
child_req_ids = None
num_children = self._get_num_child_requests(request)
num_children = get_num_child_requests(request)
if num_children > 0:
child_req_ids = []
for _ in range(num_children):
@@ -288,602 +139,53 @@ class ExecutorRequestQueue:
with self.enqueue_lock:
return self.active and self.dist.rank == 0
def _fetch_and_process_requests(
self,
total_num_active_requests: int,
total_max_num_active_requests: int,
enable_attention_dp: bool,
all_ranks_num_active_requests: Optional[List[int]] = None
) -> List[RequestQueueItem]:
"""Common logic for fetching and processing requests from the queue."""
# Block new request processing while control requests are pending.
# Control requests must be handled exclusively to ensure proper synchronization.
if len(self.control_requests) != 0:
return []
# Calculate timeout
idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0
if idle:
# In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0
# reaches the broadcast path regularly to prevent trtllm-serve timeout when idle.
timeout = datetime.timedelta(
seconds=1200) if self._disable_mpi else None
else:
timeout = datetime.timedelta(0)
# Fetch requests from rank 0
new_requests = []
if self.dist.rank == 0:
# Process accumulated requests that were queued during control request handling.
if len(self.request_accumulated) != 0:
new_requests.extend(self.request_accumulated)
self.request_accumulated.clear()
# Reset timeout to 0 to avoid hanging when no new requests are available
timeout = datetime.timedelta(0)
with self.hang_detector.pause():
new_requests.extend(self._get_from_request_queue(timeout))
# Broadcast requests and handle Python objects
new_requests, py_request_objects = self._handle_request_broadcasting(
new_requests)
# Validate and filter requests
new_requests = self._handle_special_queue_items(new_requests)
# Attach Python objects to requests
if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp
or self.dist.cp_size
> 1) and self.dist.rank > 0:
self._attach_py_objects_to_requests(new_requests,
py_request_objects)
self.waiting_queue.extend(new_requests)
new_requests = self._get_from_waiting_queue(
self.waiting_queue,
total_max_num_active_requests - total_num_active_requests,
enable_attention_dp, all_ranks_num_active_requests)
# Update performance metrics
if self.enable_iter_perf_stats and self.dist.rank == 0:
self._update_new_active_requests_queue_latency(new_requests)
return new_requests
@nvtx_range("_fetch_new_requests")
def fetch_new_requests(
self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
if self.enable_attention_dp:
return self._fetch_new_requests_attention_dp(activate_requests)
else:
return self._fetch_new_requests_attention_tp(len(activate_requests))
def _fetch_new_requests_attention_tp(
self, num_active_requests: int) -> List[LlmRequest]:
"""Handle standard (non-attention DP) request fetching."""
total_num_active_requests = num_active_requests
total_max_num_active_requests = self.max_num_active_requests
# fetch and process requests into waiting queue
new_requests = self._fetch_and_process_requests(
total_num_active_requests,
total_max_num_active_requests,
enable_attention_dp=False)
# Merge requests and add to active list
merged_requests = self._merge_requests(new_requests)
return merged_requests
def _fetch_new_requests_attention_dp(
self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
"""Handle attention DP request fetching with load balancing."""
# Get active request counts across all ranks.
all_ranks_num_active_requests = []
all_ranks_num_active_tokens = []
if self.dist.has_cp_helix:
num_active_tokens = sum(
[req.total_input_len_cp for req in activate_requests])
else:
num_active_tokens = sum(
[req.py_orig_prompt_len for req in activate_requests])
# Note: We use tp_allgather even for CP assuming that all CP ranks with the
# same dp_rank have the same num_active_tokens and num_active_requests.
responses_list = self.dist.tp_allgather(
[len(activate_requests), num_active_tokens])
for num_active_requests, num_active_tokens in responses_list:
all_ranks_num_active_requests.append(num_active_requests)
all_ranks_num_active_tokens.append(num_active_tokens)
total_num_active_requests = sum(all_ranks_num_active_requests)
total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
# fetch and process requests into waiting queue
new_requests = self._fetch_and_process_requests(
total_num_active_requests,
total_max_num_active_requests,
enable_attention_dp=True,
all_ranks_num_active_requests=all_ranks_num_active_requests)
# Schedule attention dp requests
all_ranks_new_requests = self._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests,
all_ranks_num_active_tokens)
new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]
# Update performance metrics
if self.enable_iter_perf_stats and self.start_times:
self._update_new_active_requests_queue_latency(
new_requests_cur_rank)
# Update counters
self.num_fetch_requests += len(new_requests)
self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)
# Merge requests and add to active list
new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
return new_requests_cur_rank
def _schedule_attention_dp_requests(
self, new_requests: List[RequestQueueItem],
all_ranks_num_active_requests: List[int],
all_ranks_num_active_tokens: List[int]) -> List[RequestQueueItem]:
"""Schedule attention dp requests."""
# Map from ranks to new requests
all_ranks_new_requests = {
tp_rank: []
for tp_rank in range(self.dist.tp_size)
}
# Prioritize the requests that are not in relax mode
def get_relax_value(req_item):
scheduling_params = getattr(req_item.request,
'py_scheduling_params', None)
if scheduling_params is None:
return True
return scheduling_params.attention_dp_relax
new_requests = sorted(new_requests, key=get_relax_value)
# Try to put the requests to the target dp rank until the max_num_active_requests is reached
remaining_unscheduled = []
for req_item in new_requests:
scheduled = False
scheduling_params = getattr(req_item.request,
'py_scheduling_params', None)
if scheduling_params is not None:
target_dp_rank = scheduling_params.attention_dp_rank
if target_dp_rank is not None and all_ranks_num_active_requests[
target_dp_rank] < self.max_num_active_requests:
all_ranks_num_active_requests[target_dp_rank] += 1
scheduled = True
all_ranks_new_requests[target_dp_rank].append(req_item)
if not scheduled:
remaining_unscheduled.append(req_item)
# Balance the remaining unscheduled requests across ranks
num_new_requests_all_ranks = len(remaining_unscheduled)
total_num_active_requests = sum(all_ranks_num_active_requests)
self.expected_num_active_requests = max(
(total_num_active_requests + num_new_requests_all_ranks +
self.dist.tp_size - 1) // self.dist.tp_size,
max(all_ranks_num_active_requests),
)
all_ranks_new_requests = self._balance_requests_across_ranks(
remaining_unscheduled, all_ranks_new_requests,
all_ranks_num_active_requests, all_ranks_num_active_tokens)
return all_ranks_new_requests
def _handle_request_broadcasting(self,
new_requests: List[RequestQueueItem]):
"""Handle broadcasting of requests and Python objects across ranks."""
if self.dist.rank == 0:
py_logits_post_processors = self._collect_py_objects_from_requests(
new_requests, "py_logits_post_processors")
py_multimodal_data = self._collect_py_objects_from_requests(
new_requests, "py_multimodal_data")
py_scheduling_params = self._collect_py_objects_from_requests(
new_requests, "py_scheduling_params")
py_num_logprobs = self._collect_py_objects_from_requests(
new_requests, "py_num_logprobs")
py_disaggregated_params = self._collect_py_objects_from_requests(
new_requests, "py_disaggregated_params")
py_request_objects = tuple(
filter(None, [
py_logits_post_processors, py_multimodal_data,
py_scheduling_params, py_num_logprobs,
py_disaggregated_params
]))
else:
py_request_objects = None
if self.dist.rank == 0:
# Preserve original `new_requests` on rank 0
_ = self._broadcast_new_requests(new_requests, py_request_objects)
else:
with self.hang_detector.pause():
new_requests, py_request_objects = self._broadcast_new_requests(
new_requests, py_request_objects)
return new_requests, py_request_objects
def _handle_special_queue_items(
        self,
        new_requests: List[RequestQueueItem]) -> List[RequestQueueItem]:
    """Handle special signals."""
    accepted_new_requests = []
    for idx, req_item in enumerate(new_requests):
        if req_item.is_shutdown_request:
            self.is_shutdown = True
            break
        elif req_item.is_canceled_request:
            self.canceled_req_ids.append(req_item.id)
        elif req_item.is_control_request:
            self.control_requests.append(req_item)
            if self.dist.rank == 0:
                self.request_accumulated.extend(new_requests[idx + 1:])
            break
        else:
            accepted_new_requests.append(req_item)
    return accepted_new_requests
def get_from_request_queue(
        self,
        timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]:
    """Fetch requests from the queue with optional timeout.
    Args:
        timeout: Optional timeout for waiting on queue.
    Returns:
        List of RequestQueueItem fetched from the queue.
    """
    items = []
    timeout_secs = timeout.total_seconds() if timeout is not None else None
    try:
        if self.request_queue.empty() and (timeout_secs is None
                                           or timeout_secs > 0):
            # if queue is empty and want to wait, wait
            items.append(self.request_queue.get(timeout=timeout_secs))
        else:
            # if not empty or don't want to wait, just return all items in queue
            while True:
                queue_item = self.request_queue.get_nowait()
                items.append(queue_item)
    except queue.Empty:
        pass
    if self.batch_wait_timeout_ms == 0:
        return items
    if len(items) >= self.max_batch_size:
        return items
    deadline = time.monotonic() + self.batch_wait_timeout_ms / 1000.0
    while len(items) < self.max_batch_size:
        remaining_timeout = deadline - time.monotonic()
        if remaining_timeout <= 0:
            break
        try:
            item = self.request_queue.get(timeout=remaining_timeout)
            items.append(item)
        except queue.Empty:
            break
    return items
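As an aside (not part of the diff), a minimal standalone sketch of what the batch_wait_timeout_ms window in get_from_request_queue buys: after the initial non-blocking drain, items that arrive within the wait window are coalesced into the same batch. Names here are illustrative only.

import queue
import threading
import time

def collect_with_wait(q, max_batch_size, batch_wait_timeout_ms):
    # Drain whatever is already queued without blocking.
    items = []
    while True:
        try:
            items.append(q.get_nowait())
        except queue.Empty:
            break
    # Then keep waiting, up to the batch-wait window, to grow the batch.
    deadline = time.monotonic() + batch_wait_timeout_ms / 1000.0
    while len(items) < max_batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            items.append(q.get(timeout=remaining))
        except queue.Empty:
            break
    return items

q = queue.Queue()
q.put("r0")
threading.Thread(target=lambda: (time.sleep(0.01), q.put("r1"))).start()
print(collect_with_wait(q, max_batch_size=8, batch_wait_timeout_ms=100))
# ['r0', 'r1'] -- the late arrival lands in the same batch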
def _balance_requests_across_ranks(
self, new_requests: List[RequestQueueItem],
all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
all_ranks_num_active_requests: List[int],
all_ranks_num_active_tokens: List[int]) -> List[RequestQueueItem]:
"""Balance requests across ranks for attention DP."""
if new_requests:
# Balance context tokens across ranks using heap
HeapVal = namedtuple(
'HeapVal',
['num_tokens', 'num_requests', 'rank', 'request_list'])
all_ranks_new_requests_heap = [
HeapVal(all_ranks_num_active_tokens[tp_rank], val, tp_rank, [])
for tp_rank, val in enumerate(all_ranks_num_active_requests)
]
all_ranks_new_requests_heap = [
val for val in all_ranks_new_requests_heap
if val.num_requests < self.expected_num_active_requests
]
all_ranks_new_scheduled_requests = {
val.rank: val.request_list
for val in all_ranks_new_requests_heap
}
heapq.heapify(all_ranks_new_requests_heap)
# Sort by token count (descending) for better load balancing
new_requests = sorted(
new_requests,
key=lambda x: len(getattr(x.request, 'input_token_ids', []))
if x.request else 0,
reverse=True)
# Distribute requests across ranks
for req_item in new_requests:
val = heapq.heappop(all_ranks_new_requests_heap)
token_count = len(
getattr(req_item.request, 'input_token_ids',
[])) if req_item.request else 0
# Update the heap value with the new request
val = val._replace(
num_tokens=val.num_tokens + token_count,
num_requests=val.num_requests + 1,
)
val.request_list.append(req_item)
# If rank still has room for new requests, push back into heap
if val.num_requests < self.expected_num_active_requests:
heapq.heappush(all_ranks_new_requests_heap, val)
# Extend all_ranks_new_requests with the new requests that have been scheduled
for rank, reqs in all_ranks_new_scheduled_requests.items():
all_ranks_new_requests[rank].extend(reqs)
return all_ranks_new_requests
def _collect_py_objects_from_requests(
self, requests: List[RequestQueueItem],
attribute_name: str) -> Optional[Tuple[str, Dict]]:
"""Collect Python-only objects from requests."""
req_id_to_obj = {}
for item in requests:
if not item.is_normal_request:
continue
if item.request:
obj = getattr(item.request, attribute_name, None)
if obj is not None:
req_id_to_obj[item.id] = obj
return None if not req_id_to_obj else (attribute_name, req_id_to_obj)
@nvtx_range("_broadcast_new_requests")
def _broadcast_new_requests(
self, new_requests: List[RequestQueueItem], py_request_objects
) -> Tuple[List[RequestQueueItem], Optional[Dict]]:
"""Broadcast new_requests and optional Python-only metadata across pipeline stages."""
payloads = (new_requests, py_request_objects)
if not self.dist.has_pp:
return self.dist.broadcast(payloads, root=0)
# Broadcast within first PP stage before send/recv chain to other PP stages.
# This needs to cover both TP and CP ranks within the first PP stage.
if self.dist.is_first_pp_rank:
payloads = self.dist.tp_cp_broadcast(payloads, root=0)
# Tag for communication
tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts
# Send payloads
if not self.dist.is_first_pp_rank:
with nvtx_range("recv_requests_from_prev_pp"):
payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag)
# isend new requests may cause deadlock, when CUDA_LAUNCH_BLOCKING=1 or PP microbatches can't overlap,
# the deadlock will happen deterministicly:
# 1. rank1 will wait on nccl.send(rank2), without invoking mpi.wait(isend-handle)
# 2. rank2 will wait on mpi.recv(rank1) but never receive the new requests.
# 3. rank1 will hang on nccl.send because rank2 will never reach nccl.recv(rank1).
pp_send_func = self.dist.isend_object if os.environ.get(
"TRTLLM_PP_REQ_SEND_ASYNC", "0") == "1" else self.dist.send_object
if not self.dist.is_last_pp_rank:
if self.send_requests_handler is not None:
with nvtx_range("wait_prev_send_requests_handler"):
self.send_requests_handler.wait()
with nvtx_range("send_requests_to_next_pp"):
self.send_requests_handler = pp_send_func(
payloads, self.dist.next_pp_rank, tag)
return payloads
def _attach_py_objects_to_requests(self, requests: List[RequestQueueItem],
py_request_objects) -> None:
"""Attach Python-only objects to each request."""
for attr_name, req_obj_dict in py_request_objects:
for item in requests:
if item.request:
py_obj = req_obj_dict.get(item.id)
if py_obj is not None:
setattr(item.request, attr_name, py_obj)
def _update_new_active_requests_queue_latency(
self, new_requests: List[RequestQueueItem]):
"""Update queue latency metrics for new requests."""
now = time.time()
for req_item in new_requests:
if req_item.id in self.start_times:
self.new_active_requests_queue_latency_ms += now - self.start_times.pop(
req_item.id)
if req_item.child_req_ids:
for child_id in req_item.child_req_ids:
self.new_active_requests_queue_latency_ms += now - self.start_times.pop(
child_id)
# Note: Helix parallelism is a decode-only feature run with disaggregated serving. This function gets called on gen server
# during initialization of a new request.
def _merge_helix_requests(self, new_requests: list[RequestQueueItem],
tokens_per_block: int):
req_with_children = []
num_cp_ranks = self.dist.cp_size
curr_cp_rank = self.dist.cp_rank
# For each request, partition the input_token_ids into blocks and then partition blocks across CP ranks.
# Currently, the partitioning is such that contiguous blocks are assigned to the same CP rank (as opposed
# to round-robin).
for req_item in new_requests:
all_input_ids = torch.tensor(req_item.request.input_token_ids,
dtype=torch.int64).unsqueeze(0)
input_len = all_input_ids.shape[-1]
num_total_blocks = (input_len + tokens_per_block -
1) // tokens_per_block
if num_total_blocks < num_cp_ranks:
raise ValueError(
f"There aren't enough tokens to get at least one block per CP rank. num_total_blocks {num_total_blocks} < num_cp_ranks {num_cp_ranks}. Please use smaller tokens_per_block for KV cache or reduce the number of CP ranks."
)
# Padding to ensure torch.stack used with torch.tensor_split works properly.
padding_len = 0
if input_len % tokens_per_block != 0:
padding_len = tokens_per_block - (input_len % tokens_per_block)
padding_ids = torch.zeros([1, padding_len], dtype=torch.int64)
all_input_ids = torch.cat((all_input_ids, padding_ids), dim=-1)
all_position_ids = torch.arange(0,
input_len + padding_len,
dtype=torch.int64).unsqueeze(0)
input_id_blocks_per_rank = torch.tensor_split(
torch.stack(all_input_ids.split(tokens_per_block, dim=-1)),
num_cp_ranks)
position_id_blocks_per_rank = torch.tensor_split(
torch.stack(all_position_ids.split(tokens_per_block, dim=-1)),
num_cp_ranks)
# Get the input_ids and position_ids for this rank.
input_ids_this_rank = input_id_blocks_per_rank[
curr_cp_rank].flatten().tolist()
position_ids_this_rank = position_id_blocks_per_rank[
curr_cp_rank].flatten().tolist()
# Undo the padding. Only last rank's last block will be padded right now
# given contiguous block assignment.
if curr_cp_rank == num_cp_ranks - 1 and padding_len > 0:
input_ids_this_rank = input_ids_this_rank[:-padding_len]
position_ids_this_rank = position_ids_this_rank[:-padding_len]
req = executor_request_to_llm_request(
req_id=req_item.id,
executor_request=req_item.request,
child_req_ids=req_item.child_req_ids,
exclude_last_generation_logits=self.
_should_exclude_last_generation_logits(),
input_token_ids=input_ids_this_rank,
position_ids=position_ids_this_rank,
)
req.total_input_len_cp = input_len
req.seqlen_this_rank_cp = len(input_ids_this_rank)
req_with_children.append(req)
if req.child_requests:
req_with_children.extend(req.child_requests)
return req_with_children
@nvtx_range("_merge_requests")
def _merge_requests(
self, new_requests: list[RequestQueueItem]) -> List[LlmRequest]:
cp_config = self.dist.cp_config
if 'cp_type' in cp_config:
cp_type = cp_config['cp_type']
if cp_type == CpType.STAR:
return self._merge_star_attention_requests(new_requests)
elif cp_type == CpType.HELIX:
return self._merge_helix_requests(
new_requests,
tokens_per_block=cp_config['tokens_per_block'])
else:
raise NotImplementedError(
f'Unsupported cp type {cp_type.name}.')
req_with_children = []
for req_item in new_requests:
req = executor_request_to_llm_request(
req_item.id, req_item.request, req_item.child_req_ids,
self._should_exclude_last_generation_logits())
req_with_children.append(req)
if req.child_requests:
req_with_children.extend(req.child_requests)
return req_with_children
def _merge_star_attention_requests(
self, new_requests: list[RequestQueueItem]) -> List[LlmRequest]:
result = []
for req_item in new_requests:
req_id, exe_req, query_token_ids = req_item.id, req_item.request, req_item.query
ctx_len0 = len(exe_req.input_token_ids)
ctx_blocks, position_blocks, last_block_padding_num = [
exe_req.input_token_ids
], [[i for i in range(ctx_len0)]], 0
ctx_blocks, position_blocks, last_block_padding_num = self._partition_context(
exe_req.input_token_ids)
if self.dist.cp_rank == self.dist.cp_size - 1 and last_block_padding_num > 0:
ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
position_blocks[-1] = position_blocks[
-1][:-last_block_padding_num]
#if has query
if query_token_ids:
ctx_blocks.append(query_token_ids)
position_blocks.append([
i for i in range(ctx_len0, ctx_len0 + len(query_token_ids))
])
# insert the dummy block to align the number of ctx iterations of each rank
block_size = self.dist.cp_config['block_size']
total_blocks = (ctx_len0 + block_size - 1) // block_size
num_blocks_per_rank = (
total_blocks + self.dist.cp_size -
1) // self.dist.cp_size + 1 # 1 for query block
if len(ctx_blocks) == num_blocks_per_rank:
ctx_blocks.insert(1, [])
position_blocks.insert(1, [])
elif len(ctx_blocks) == num_blocks_per_rank + 1:
# anchor + ctx_blocks + qry_block
pass
else:
print(
f'rank = {self.dist.cp_rank}, len(ctx_blocks) = {len(ctx_blocks) }, num_blocks_per_rank = {num_blocks_per_rank}'
)
assert False, f'invalid context partition'
# fake data for scheduler
ctx_blocks_list = [0] * (block_size +
self.dist.cp_config['cp_anchor_size'])
req = executor_request_to_llm_request(
req_id, exe_req, self._should_exclude_last_generation_logits(),
ctx_blocks_list)
req.gen_iters = 0
req.ctx_iters = 0
req.ctx_blocks = ctx_blocks
req.ctx_position_blocks = position_blocks
req.query_id = query_token_ids
result.append(req)
return result
def _partition_context(self, ctx_ids_list):
ctx_ids = torch.tensor(ctx_ids_list).unsqueeze(0)
ctx_len = ctx_ids.shape[-1]
block_size = self.dist.cp_config['block_size']
if block_size is None:
block_size = ctx_len // self.dist.cp_size
anchor_block_size = self.dist.cp_config['cp_anchor_size']
if anchor_block_size is None:
anchor_block_size = block_size
assert anchor_block_size <= block_size, f'cp_anchor_size {anchor_block_size} should be smaller than block_size {block_size}'
padding = 0
if ctx_len % block_size != 0:
padding = block_size - (ctx_len % block_size)
assert padding <= ctx_len, f'block size is too large for context, please set it smaller'
ctx_ids = torch.cat(
(ctx_ids, torch.zeros_like(ctx_ids)[:, :padding]), dim=-1)
position_ids = torch.arange(0, ctx_ids.shape[-1]).unsqueeze(0)
ctx_ids_blocks = torch.tensor_split(
torch.stack(ctx_ids.split(block_size, dim=-1)), self.dist.cp_size)
position_ids_blocks = torch.tensor_split(
torch.stack(position_ids.split(block_size, dim=-1)),
self.dist.cp_size)
if self.dist.cp_rank != 0:
ctx_blocks, position_blocks = [
ctx_ids_blocks[0][0].tolist()[0][:anchor_block_size]
], [position_ids_blocks[0][0].tolist()[0][:anchor_block_size]]
else:
ctx_blocks, position_blocks = [], []
for idx in range(len(ctx_ids_blocks[self.dist.cp_rank])):
ctx_block = ctx_ids_blocks[self.dist.cp_rank][idx]
position_block = position_ids_blocks[self.dist.cp_rank][idx]
ctx_blocks.append(ctx_block.tolist()[0])
position_blocks.append(position_block.tolist()[0])
return ctx_blocks, position_blocks, padding
def set_exclude_last_generation_logits(self,
disable_overlap_scheduler: bool,
pp_size: int) -> None:
# When overlap scheduler is enabled then when starting to handle a new prompt,
# sample_async is called twice before the first call to update_requests:
# - 1st time as a context request that handles on the 1st generated token
# - 2nd time as a generation request that handles on the 2nd generated token.
# and only after these two calls the sampler's update_request method is called.
# So in a sampler that works by the expected flow of handling the logits in
# sample_async, every update_request doesn't handle the newest token, but one
# before it. Since all these calls work on the same request object, then its
# logits storage contains the logits of both the token update_requests should work
# on, and also its next token. Thus, excluding the last generation logits from any
# getter is required.
self.should_exclude_last_generation_logits = not disable_overlap_scheduler and pp_size == 1
def _should_exclude_last_generation_logits(self) -> bool:
return self.should_exclude_last_generation_logits
def get_new_active_requests_queue_latency(self) -> float:
return self.new_active_requests_queue_latency_ms
def get_expected_num_active_requests(self) -> int:
return self.expected_num_active_requests
def get_request_queue_size(self) -> int:
    return self.request_queue.qsize()
@@ -891,22 +193,22 @@ class ExecutorRequestQueue:
def get_request_queue(self) -> queue.Queue[RequestQueueItem]:
    return self.request_queue
def get_waiting_queue(self) -> deque[RequestQueueItem]:
    return self.waiting_queue
def update_waiting_queue(self):
    # Remove cancel request in the waiting queue
    self.waiting_queue = deque(req for req in self.waiting_queue
                               if req.id not in self.canceled_req_ids)
def get_waiting_queue_size(self) -> int:
    return len(self.waiting_queue)
def get_canceled_req_ids_size(self) -> int:
    return len(self.canceled_req_ids)
def get_canceled_req_ids(self) -> List[int]:
    return self.canceled_req_ids
def clear_canceled_req_ids(self):
    self.canceled_req_ids.clear()
def calculate_queue_latency(self, request_items: List[RequestQueueItem],
                            now: float) -> float:
    if not self.enable_iter_perf_stats:
        return 0.0
    total_latency = 0.0
    for req_item in request_items:
        # Handle parent request
        if req_item.id in self.start_times:
            total_latency += now - self.start_times.pop(req_item.id)
        # Handle child requests
        if req_item.child_req_ids:
            for child_id in req_item.child_req_ids:
                if child_id in self.start_times:
                    total_latency += now - self.start_times.pop(child_id)
    return total_latency

File: py_executor.py

@@ -6,6 +6,7 @@ import pickle # nosec B403
import threading
import time
import traceback
from collections import deque
from contextlib import contextmanager
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
@@ -22,7 +23,7 @@ except ImportError:
from tensorrt_llm._torch.pyexecutor.resource_manager import (
ResourceManagerType, request_context)
from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
nvtx_range, trace_func)
mpi_disabled, nvtx_range, trace_func)
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
FinishReason, InflightBatchingStats,
IterationStats, KvCacheStats,
@@ -53,6 +54,9 @@ from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
LlmResponse, get_draft_token_length)
from .model_engine import ModelEngine
from .request_utils import (RequestBroadcaster, attach_py_objects_to_requests,
get_from_waiting_queue, merge_requests,
schedule_attention_dp_requests)
from .resource_manager import ResourceManager
from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState,
SampleStateTensors, TRTLLMSampler)
@@ -424,16 +428,36 @@ class PyExecutor:
self._set_global_steady_clock_offset()
self.executor_request_queue = ExecutorRequestQueue(
dist=self.dist,
enable_attention_dp=self.enable_attention_dp,
max_batch_size=max_batch_size,
max_beam_width=self.max_beam_width,
max_num_active_requests=self.max_num_active_requests,
enable_iter_perf_stats=self.enable_iter_perf_stats,
batch_wait_timeout_ms=self.batch_wait_timeout_ms,
hang_detector=self.hang_detector,
)
self.executor_request_queue.set_exclude_last_generation_logits(
self.disable_overlap_scheduler, self.dist.pp_size)
# When overlap scheduler is enabled then when starting to handle a new prompt,
# sample_async is called twice before the first call to update_requests:
# - 1st time as a context request that handles on the 1st generated token
# - 2nd time as a generation request that handles on the 2nd generated token.
# and only after these two calls the sampler's update_request method is called.
# So in a sampler that works by the expected flow of handling the logits in
# sample_async, every update_request doesn't handle the newest token, but one
# before it. Since all these calls work on the same request object, then its
# logits storage contains the logits of both the token update_requests should work
# on, and also its next token. Thus, excluding the last generation logits from any
# getter is required.
self.should_exclude_last_generation_logits = (
not self.disable_overlap_scheduler and self.dist.pp_size == 1)
# Request processing state (managed by executor)
self.canceled_req_ids: List[int] = []
self.control_requests: List[RequestQueueItem] = []
self.request_accumulated: List[RequestQueueItem] = []
self.new_active_requests_queue_latency_ms = 0.0
self._disable_mpi = mpi_disabled()
self.request_broadcaster = RequestBroadcaster(self.dist,
self.hang_detector)
# Waiting queue for requests that have been fetched but not yet scheduled
self.waiting_queue: deque[RequestQueueItem] = deque()
self.control_request_barrier = threading.Event()
self.control_action_done = threading.Event()
@@ -698,7 +722,7 @@ class PyExecutor:
@property
def should_stop_processing(self):
return self.is_shutdown and len(self.active_requests) == 0 and \
self.executor_request_queue.get_waiting_queue_size() == 0
len(self.waiting_queue) == 0
@contextmanager
def _profiler(self):
@@ -778,8 +802,8 @@ class PyExecutor:
f"iter = {self.iter_counter}, "
f"global_rank = {self.global_rank}, "
f"rank = {self.dist.rank}, "
f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/"
f"{self.executor_request_queue.num_fetch_requests}, "
f"currank_total_requests = {self.num_fetch_requests_cur_rank}/"
f"{self.num_fetch_requests}, "
f"host_step_time = {host_step_time}ms, "
f"prev_device_step_time = {prev_device_step_time}, "
f"timestamp = {formatted_timestamp}, "
@@ -1143,8 +1167,7 @@ class PyExecutor:
if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
self.executor_request_queue.
get_new_active_requests_queue_latency())
self._get_new_active_requests_queue_latency())
self._pad_attention_dp_dummy_request()
@@ -1400,8 +1423,7 @@ class PyExecutor:
if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
self.executor_request_queue.
get_new_active_requests_queue_latency())
self._get_new_active_requests_queue_latency())
self._pad_attention_dp_dummy_request()
@@ -1639,14 +1661,14 @@ class PyExecutor:
def _handle_control_request(self):
if len(self.active_requests) == 0 and \
self.executor_request_queue.get_waiting_queue_size() == 0 and \
len(self.executor_request_queue.control_requests) > 0:
assert len(self.executor_request_queue.control_requests) == 1, (
len(self.waiting_queue) == 0 and \
len(self.control_requests) > 0:
assert len(self.control_requests) == 1, (
f"Expected exactly one control request to be processed at a time, "
f"but found {len(self.executor_request_queue.control_requests)} control requests. "
f"but found {len(self.control_requests)} control requests. "
f"This may indicate a race condition or improper control request handling."
)
self.executor_request_queue.control_requests.pop(0)
self.control_requests.pop(0)
self.control_request_barrier.set()
self.control_action_done.wait()
self.control_action_done.clear()
@@ -1704,7 +1726,7 @@ class PyExecutor:
# to ensure consistent batch sizes for accurate performance measurement.
if not self.is_warmup and not can_forward:
if self.enable_attention_dp:
local_can_forward = self.executor_request_queue.num_fetch_requests + \
local_can_forward = self.num_fetch_requests + \
len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size
all_can_forward = self.dist.tp_allgather(
local_can_forward)
@@ -1714,7 +1736,7 @@ class PyExecutor:
else:
if self.dist.rank == 0:
logger.info(
f"sleep 10 seconds, num_fetched_requests: {self.executor_request_queue.num_fetch_requests}, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}"
f"sleep 10 seconds, num_fetched_requests: {self.num_fetch_requests}, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}"
)
time.sleep(10)
continue
@@ -2026,6 +2048,169 @@ class PyExecutor:
self.model_engine.model.lm_head.num_embeddings):
raise ValueError("Token ID out of range")
def _fetch_and_enqueue_requests(self,
waiting_queue: deque[RequestQueueItem],
total_num_active_requests: int) -> None:
"""Fetch requests from request_queue and enqueue to waiting_queue."""
# Block new requests while control requests are pending
if len(self.control_requests) != 0:
return
# Calculate timeout
idle = (total_num_active_requests == 0) and len(waiting_queue) == 0
if idle:
# In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0
# reaches the broadcast path regularly to prevent trtllm-serve timeout when idle.
timeout = datetime.timedelta(
seconds=1200) if self._disable_mpi else None
else:
timeout = datetime.timedelta(0)
# Fetch requests from rank 0
new_requests = []
if self.dist.rank == 0:
# Process accumulated requests that were queued during control request handling.
if len(self.request_accumulated) != 0:
new_requests.extend(self.request_accumulated)
self.request_accumulated.clear()
# Reset timeout to 0 to avoid hanging when no new requests are available
timeout = datetime.timedelta(0)
with self.hang_detector.pause():
new_requests.extend(
self.executor_request_queue.get_from_request_queue(timeout))
# Broadcast requests and handle Python objects
new_requests, py_request_objects = self.request_broadcaster.broadcast(
new_requests)
# Validate and filter requests
new_requests = self._handle_special_queue_items(new_requests)
# Attach Python objects to requests
if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp
or self.dist.cp_size
> 1) and self.dist.rank > 0:
attach_py_objects_to_requests(new_requests, py_request_objects)
waiting_queue.extend(new_requests)
def _pop_from_waiting_queue(
self,
waiting_queue: deque[RequestQueueItem],
total_num_active_requests: int,
all_ranks_num_active_requests: Optional[List[int]] = None
) -> List[RequestQueueItem]:
"""Pop requests from waiting_queue based on available capacity."""
if self.enable_attention_dp:
total_max = self.dist.tp_size * self.max_num_active_requests
else:
total_max = self.max_num_active_requests
max_new_requests = total_max - total_num_active_requests
return get_from_waiting_queue(
waiting_queue,
max_new_requests,
enable_attention_dp=self.enable_attention_dp,
max_num_active_requests=self.max_num_active_requests,
all_ranks_num_active_requests=all_ranks_num_active_requests)
@nvtx_range("_fetch_new_requests")
def _fetch_new_requests(
self, waiting_queue: deque[RequestQueueItem],
activate_requests: List[LlmRequest]) -> List[LlmRequest]:
"""Fetch new requests and return LlmRequests ready for execution."""
# 1. Gather info and calculate total_num_active_requests
if self.enable_attention_dp:
all_ranks_num_active_requests = []
all_ranks_num_active_tokens = []
if self.dist.has_cp_helix:
num_active_tokens = sum(
[req.total_input_len_cp for req in activate_requests])
else:
num_active_tokens = sum(
[req.py_orig_prompt_len for req in activate_requests])
responses_list = self.dist.tp_allgather(
[len(activate_requests), num_active_tokens])
for num_active_requests, num_active_tokens in responses_list:
all_ranks_num_active_requests.append(num_active_requests)
all_ranks_num_active_tokens.append(num_active_tokens)
total_num_active_requests = sum(all_ranks_num_active_requests)
else:
total_num_active_requests = len(activate_requests)
all_ranks_num_active_requests = None
# 2. Fetch and enqueue to waiting queue
self._fetch_and_enqueue_requests(waiting_queue,
total_num_active_requests)
# 3. Pop requests from waiting queue
new_requests = self._pop_from_waiting_queue(
waiting_queue, total_num_active_requests,
all_ranks_num_active_requests)
# 4. Update performance metrics (before DP scheduling to clear all start_times)
if self.enable_iter_perf_stats and self.dist.rank == 0:
self._update_new_active_requests_queue_latency(new_requests)
# 5. Schedule requests across ranks (DP only)
if self.enable_attention_dp:
all_ranks_new_requests, self.expected_num_active_requests = \
schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests,
all_ranks_num_active_tokens, self.dist.tp_size,
self.max_num_active_requests)
new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]
# Update counters for DP
self.num_fetch_requests += len(new_requests)
self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)
new_requests = new_requests_cur_rank
# 6. Merge requests
return merge_requests(new_requests,
cp_config=self.dist.cp_config,
cp_rank=self.dist.cp_rank,
cp_size=self.dist.cp_size,
exclude_last_generation_logits=self.
_should_exclude_last_generation_logits())
def _handle_special_queue_items(
self,
new_requests: List[RequestQueueItem]) -> List[RequestQueueItem]:
"""Handle special signals."""
accepted_new_requests = []
for idx, req_item in enumerate(new_requests):
if req_item.is_shutdown_request:
self.is_shutdown = True
break
elif req_item.is_canceled_request:
self.canceled_req_ids.append(req_item.id)
elif req_item.is_control_request:
self.control_requests.append(req_item)
if self.dist.rank == 0:
self.request_accumulated.extend(new_requests[idx + 1:])
break
else:
accepted_new_requests.append(req_item)
return accepted_new_requests
def _update_new_active_requests_queue_latency(
self, new_requests: List[RequestQueueItem]):
"""Update queue latency metrics for new requests."""
now = time.time()
latency = self.executor_request_queue.calculate_queue_latency(
new_requests, now)
self.new_active_requests_queue_latency_ms += latency
def _get_new_active_requests_queue_latency(self) -> float:
return self.new_active_requests_queue_latency_ms
def _should_exclude_last_generation_logits(self) -> bool:
return self.should_exclude_last_generation_logits
def _fetch_and_activate_new_requests(self) -> List[LlmRequest]:
def _respond_if_invalid(request: LlmRequest) -> bool:
@@ -2041,11 +2226,8 @@ class PyExecutor:
self._handle_errors(str(e), requests=[request])
return True
new_requests_cur_rank = self.executor_request_queue.fetch_new_requests(
self.active_requests)
self.is_shutdown = self.executor_request_queue.is_shutdown
self.expected_num_active_requests = self.executor_request_queue.get_expected_num_active_requests(
)
new_requests_cur_rank = self._fetch_new_requests(
self.waiting_queue, self.active_requests)
validated_requests = [
request for request in new_requests_cur_rank
@@ -2647,20 +2829,20 @@ class PyExecutor:
@nvtx_range("_handle_canceled_requests")
def _handle_canceled_requests(self):
if self.executor_request_queue.get_canceled_req_ids_size() == 0:
if len(self.canceled_req_ids) == 0:
return
# Remove cancel request in the waiting queue
self.executor_request_queue.update_waiting_queue()
# Create set from list of canceled request ids to speed up canceled test
canceled_req_ids = set(
self.executor_request_queue.get_canceled_req_ids())
canceled_req_ids_set = set(self.canceled_req_ids)
# Remove canceled requests from the waiting queue
self.waiting_queue = deque(req for req in self.waiting_queue
if req.id not in canceled_req_ids_set)
still_pending_canceled_ids = []
for request in self.active_requests:
req_id = request.py_request_id if not request.is_child else request.parent_request_id
if req_id not in canceled_req_ids:
if req_id not in canceled_req_ids_set:
continue
is_cancelled = self._try_cancel_request(request)
@@ -2673,9 +2855,8 @@ class PyExecutor:
still_pending_canceled_ids.append(req_id)
# Clear list of requests marked for cancellation and add back those that failed to cancel.
self.executor_request_queue.canceled_req_ids.clear()
self.executor_request_queue.canceled_req_ids.extend(
still_pending_canceled_ids)
self.canceled_req_ids.clear()
self.canceled_req_ids.extend(still_pending_canceled_ids)
@nvtx_range("_enqueue_responses")
def _enqueue_responses(self, responses: Iterable[Tuple[int, LlmResponse]]):

File: request_utils.py (new file)

@@ -0,0 +1,704 @@
"""Utility functions for request processing."""
import heapq
import os
from collections import deque, namedtuple
from typing import Any, Dict, List, Optional, Tuple
import torch
from tensorrt_llm._utils import nvtx_range
from tensorrt_llm.mapping import CpType
from ..distributed import Distributed
from .hang_detector import HangDetector
from .llm_request import ExecutorRequest, LlmRequest, executor_request_to_llm_request
# Type alias for request queue items (to avoid circular import)
# The actual RequestQueueItem class is defined in executor_request_queue.py
def get_num_child_requests(request: ExecutorRequest) -> int:
"""Get the number of child requests for a given request.
Args:
request: The executor request to check.
Returns:
Number of child requests (0 if beam search, otherwise num_return_sequences - 1).
"""
sampling_config = request.sampling_config
return 0 if sampling_config.beam_width > 1 else (sampling_config.num_return_sequences or 1) - 1
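As a quick worked example of the rule above (not part of the diff), using a stand-in sampling config instead of a real ExecutorRequest:

from types import SimpleNamespace

def num_children(beam_width, num_return_sequences):
    # Mirrors the rule in get_num_child_requests on a stand-in config.
    cfg = SimpleNamespace(beam_width=beam_width,
                          num_return_sequences=num_return_sequences)
    return 0 if cfg.beam_width > 1 else (cfg.num_return_sequences or 1) - 1

assert num_children(beam_width=1, num_return_sequences=4) == 3  # 3 extra sequences -> 3 children
assert num_children(beam_width=4, num_return_sequences=4) == 0  # beam search: no child requests
assert num_children(beam_width=1, num_return_sequences=None) == 0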
def collect_py_objects_from_requests(
requests: List, attribute_name: str
) -> Optional[Tuple[str, Dict]]:
"""Collect Python-only objects from requests.
Args:
requests: List of RequestQueueItem objects.
attribute_name: Name of the attribute to collect.
Returns:
Tuple of (attribute_name, dict mapping request_id to object) or None if empty.
"""
req_id_to_obj = {}
for item in requests:
if not item.is_normal_request:
continue
if item.request:
obj = getattr(item.request, attribute_name, None)
if obj is not None:
req_id_to_obj[item.id] = obj
return None if not req_id_to_obj else (attribute_name, req_id_to_obj)
def attach_py_objects_to_requests(requests: List, py_request_objects: Tuple) -> None:
"""Attach Python-only objects to each request.
Args:
requests: List of RequestQueueItem objects.
py_request_objects: Tuple of (attribute_name, dict) pairs.
"""
for attr_name, req_obj_dict in py_request_objects:
for item in requests:
if item.request:
py_obj = req_obj_dict.get(item.id)
if py_obj is not None:
setattr(item.request, attr_name, py_obj)
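A hedged round-trip sketch of collect_py_objects_from_requests / attach_py_objects_to_requests (not part of the diff): rank 0 packs a Python-only attribute into a picklable (name, {req_id: obj}) pair, and the receiving rank re-attaches it. Items are SimpleNamespace stand-ins for RequestQueueItem, and the module path tensorrt_llm._torch.pyexecutor.request_utils is an assumption.

from types import SimpleNamespace

# Module path assumed; the function signatures match the diff above.
from tensorrt_llm._torch.pyexecutor.request_utils import (
    attach_py_objects_to_requests, collect_py_objects_from_requests)

sender_items = [
    SimpleNamespace(id=7, is_normal_request=True,
                    request=SimpleNamespace(py_num_logprobs=5)),
]
payload = collect_py_objects_from_requests(sender_items, "py_num_logprobs")
# payload == ("py_num_logprobs", {7: 5}) -- picklable, safe to broadcast

receiver_items = [SimpleNamespace(id=7, request=SimpleNamespace())]
attach_py_objects_to_requests(receiver_items, [payload])
print(receiver_items[0].request.py_num_logprobs)  # 5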
def schedule_attention_dp_requests(
new_requests: List[Any],
all_ranks_num_active_requests: List[int],
all_ranks_num_active_tokens: List[int],
tp_size: int,
max_num_active_requests: int,
) -> Tuple[Dict[int, List[Any]], int]:
"""Schedule attention DP requests across ranks.
This function distributes requests across tensor parallel ranks for attention DP.
It first tries to assign requests to their target dp_rank (if specified and has capacity),
then balances the remaining requests across all ranks.
Args:
new_requests: List of RequestQueueItem to schedule.
all_ranks_num_active_requests: Number of active requests per rank (will be modified).
all_ranks_num_active_tokens: Number of active tokens per rank.
tp_size: Number of tensor parallel ranks.
max_num_active_requests: Maximum number of active requests per rank.
Returns:
Tuple of:
- all_ranks_new_requests: Dict mapping rank to list of assigned requests.
- expected_num_active_requests: Expected number of active requests per rank.
"""
# Map from ranks to new requests
all_ranks_new_requests = {tp_rank: [] for tp_rank in range(tp_size)}
# Prioritize the requests that are not in relax mode
def get_relax_value(req_item):
scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
if scheduling_params is None:
return True
return scheduling_params.attention_dp_relax
new_requests = sorted(new_requests, key=get_relax_value)
# Try to put the requests to the target dp rank until the max_num_active_requests is reached
remaining_unscheduled = []
for req_item in new_requests:
scheduled = False
scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
if scheduling_params is not None:
target_dp_rank = scheduling_params.attention_dp_rank
if (
target_dp_rank is not None
and all_ranks_num_active_requests[target_dp_rank] < max_num_active_requests
):
all_ranks_num_active_requests[target_dp_rank] += 1
scheduled = True
all_ranks_new_requests[target_dp_rank].append(req_item)
if not scheduled:
remaining_unscheduled.append(req_item)
# Balance the remaining unscheduled requests across ranks
num_new_requests_all_ranks = len(remaining_unscheduled)
total_num_active_requests = sum(all_ranks_num_active_requests)
expected_num_active_requests = max(
(total_num_active_requests + num_new_requests_all_ranks + tp_size - 1) // tp_size,
max(all_ranks_num_active_requests),
)
all_ranks_new_requests = balance_requests_across_ranks(
remaining_unscheduled,
all_ranks_new_requests,
all_ranks_num_active_requests,
all_ranks_num_active_tokens,
expected_num_active_requests,
)
return all_ranks_new_requests, expected_num_active_requests
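A hypothetical usage sketch of schedule_attention_dp_requests (not part of the diff) with stand-in request items: two DP ranks, one request pinned to rank 0, and the rest balanced by token count. The module path is an assumption.

from types import SimpleNamespace

# Module path assumed; two DP ranks (tp_size=2).
from tensorrt_llm._torch.pyexecutor.request_utils import schedule_attention_dp_requests

def make_item(num_tokens, dp_rank=None, relax=True):
    params = SimpleNamespace(attention_dp_rank=dp_rank, attention_dp_relax=relax)
    return SimpleNamespace(request=SimpleNamespace(py_scheduling_params=params,
                                                   input_token_ids=list(range(num_tokens))))

items = [
    make_item(128, dp_rank=0, relax=False),  # pinned to rank 0
    make_item(64),                           # free to go anywhere
    make_item(256),                          # free to go anywhere
]
per_rank, expected = schedule_attention_dp_requests(
    new_requests=items,
    all_ranks_num_active_requests=[1, 0],  # rank 0 already has one active request
    all_ranks_num_active_tokens=[128, 0],
    tp_size=2,
    max_num_active_requests=8)
print({rank: len(reqs) for rank, reqs in per_rank.items()}, expected)
# {0: 1, 1: 2} 2 -- the pinned request stays on rank 0, the rest fill rank 1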
def balance_requests_across_ranks(
new_requests: List,
all_ranks_new_requests: Dict[int, List],
all_ranks_num_active_requests: List[int],
all_ranks_num_active_tokens: List[int],
expected_num_active_requests: int,
) -> Dict[int, List]:
"""Balance requests across ranks for attention DP.
Uses a heap-based algorithm to distribute requests evenly across ranks,
prioritizing ranks with fewer tokens for better load balancing.
Args:
new_requests: List of new requests to distribute.
all_ranks_new_requests: Dict mapping rank to list of already assigned requests.
all_ranks_num_active_requests: Number of active requests per rank.
all_ranks_num_active_tokens: Number of active tokens per rank.
expected_num_active_requests: Target number of active requests per rank.
Returns:
Updated all_ranks_new_requests dict with new requests distributed.
"""
if new_requests:
# Balance context tokens across ranks using heap
HeapVal = namedtuple("HeapVal", ["num_tokens", "num_requests", "rank", "request_list"])
all_ranks_new_requests_heap = [
HeapVal(all_ranks_num_active_tokens[tp_rank], val, tp_rank, [])
for tp_rank, val in enumerate(all_ranks_num_active_requests)
]
all_ranks_new_requests_heap = [
val
for val in all_ranks_new_requests_heap
if val.num_requests < expected_num_active_requests
]
all_ranks_new_scheduled_requests = {
val.rank: val.request_list for val in all_ranks_new_requests_heap
}
heapq.heapify(all_ranks_new_requests_heap)
# Sort by token count (descending) for better load balancing
new_requests = sorted(
new_requests,
key=lambda x: len(getattr(x.request, "input_token_ids", [])) if x.request else 0,
reverse=True,
)
# Distribute requests across ranks
for req_item in new_requests:
val = heapq.heappop(all_ranks_new_requests_heap)
token_count = (
len(getattr(req_item.request, "input_token_ids", [])) if req_item.request else 0
)
# Update the heap value with the new request
val = val._replace(
num_tokens=val.num_tokens + token_count,
num_requests=val.num_requests + 1,
)
val.request_list.append(req_item)
# If rank still has room for new requests, push back into heap
if val.num_requests < expected_num_active_requests:
heapq.heappush(all_ranks_new_requests_heap, val)
# Extend all_ranks_new_requests with the new requests that have been scheduled
for rank, reqs in all_ranks_new_scheduled_requests.items():
all_ranks_new_requests[rank].extend(reqs)
return all_ranks_new_requests
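The balancing step above boils down to a greedy heap strategy: hand the next-largest request to the currently least-loaded rank. A minimal standalone illustration of that idea (it deliberately omits the per-rank request-count cap the real function enforces):

import heapq

def balance(token_counts, num_ranks):
    # (num_tokens, rank, assigned) -- the least-loaded rank sits at the heap top.
    heap = [(0, rank, []) for rank in range(num_ranks)]
    heapq.heapify(heap)
    for tokens in sorted(token_counts, reverse=True):  # largest requests first
        num_tokens, rank, assigned = heapq.heappop(heap)
        assigned.append(tokens)
        heapq.heappush(heap, (num_tokens + tokens, rank, assigned))
    return {rank: assigned for _, rank, assigned in heap}

print(balance([512, 256, 128, 128, 64], num_ranks=2))
# {0: [512, 64], 1: [256, 128, 128]} -- token totals of 576 vs. 512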
def can_process_attention_dp_request(
req_item, all_ranks_num_active_requests: List[int], max_num_active_requests: int
) -> bool:
"""Check if a request can be processed immediately for attention DP.
Args:
req_item: The request queue item to check.
all_ranks_num_active_requests: Number of active requests for each rank.
max_num_active_requests: Maximum number of active requests per rank.
Returns:
True if the request can be processed, False otherwise.
"""
scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
if scheduling_params is None:
return True
target_dp_rank = scheduling_params.attention_dp_rank
if target_dp_rank is None or scheduling_params.attention_dp_relax:
return True
if all_ranks_num_active_requests[target_dp_rank] < max_num_active_requests:
all_ranks_num_active_requests[target_dp_rank] += 1
return True
return False
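A small illustration of the capacity check (not part of the diff): a pinned, non-relaxed request is admitted only while its target DP rank has free slots, and admission reserves a slot by bumping the count in place. The item is a SimpleNamespace stand-in and the module path is an assumption.

from types import SimpleNamespace

# Module path assumed.
from tensorrt_llm._torch.pyexecutor.request_utils import can_process_attention_dp_request

pinned = SimpleNamespace(request=SimpleNamespace(
    py_scheduling_params=SimpleNamespace(attention_dp_rank=1, attention_dp_relax=False)))
active = [0, 2]  # rank 1 already holds 2 active requests

print(can_process_attention_dp_request(pinned, active, max_num_active_requests=2))  # False: rank 1 is full
print(can_process_attention_dp_request(pinned, active, max_num_active_requests=4))  # True
print(active)  # [0, 3] -- the accepting call reserved a slot on rank 1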
def get_from_waiting_queue(
waiting_queue: deque,
max_req_count: int,
enable_attention_dp: bool,
max_num_active_requests: int,
all_ranks_num_active_requests: Optional[List[int]] = None,
) -> List:
"""Get requests from the waiting queue.
Args:
waiting_queue: The queue to pop items from.
max_req_count: Maximum items to retrieve. Returns empty list if <=0.
enable_attention_dp: Whether to enable attention DP scheduling.
max_num_active_requests: Maximum number of active requests per rank.
all_ranks_num_active_requests: Number of active requests for each rank.
Returns:
List of requests that can be processed.
"""
if max_req_count <= 0:
return []
req_count = 0
items = []
pending_requests = []
# Track the request with strict requirements
scheduling_all_ranks_num_active_requests = (
all_ranks_num_active_requests.copy() if enable_attention_dp else None
)
while req_count < max_req_count and waiting_queue:
req_item = waiting_queue[0]
num_children = len(req_item.child_req_ids) if req_item.child_req_ids else 0
if (req_count + 1 + num_children) > max_req_count:
break
req_item = waiting_queue.popleft()
can_process = (
can_process_attention_dp_request(
req_item, scheduling_all_ranks_num_active_requests, max_num_active_requests
)
if enable_attention_dp
else True
)
if can_process:
items.append(req_item)
req_count += 1 + num_children
else:
pending_requests.append(req_item)
# Put the pending requests back to the waiting queue
# All ranks should have the same waiting queue
waiting_queue.extendleft(reversed(pending_requests))
return items
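A hypothetical call to get_from_waiting_queue in the non-attention-DP case (not part of the diff), showing how child requests count against max_req_count. Items are SimpleNamespace stand-ins and the module path is an assumption.

from collections import deque
from types import SimpleNamespace

# Module path assumed.
from tensorrt_llm._torch.pyexecutor.request_utils import get_from_waiting_queue

waiting = deque([
    SimpleNamespace(request=None, child_req_ids=[101, 102]),  # parent + 2 children = 3 slots
    SimpleNamespace(request=None, child_req_ids=None),
    SimpleNamespace(request=None, child_req_ids=None),
])
taken = get_from_waiting_queue(waiting,
                               max_req_count=4,
                               enable_attention_dp=False,
                               max_num_active_requests=8)
print(len(taken), len(waiting))  # 2 1 -- the first two items fill the 4 available slots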
def partition_context_for_star_attention(
ctx_ids_list: List[int], cp_rank: int, cp_size: int, block_size: int, anchor_block_size: int
) -> Tuple[List[List[int]], List[List[int]], int]:
"""Partition context for Star Attention CP.
Args:
ctx_ids_list: List of context token IDs.
cp_rank: Current CP rank.
cp_size: Total number of CP ranks.
block_size: Size of each block.
anchor_block_size: Size of anchor block.
Returns:
Tuple of (ctx_blocks, position_blocks, padding).
"""
ctx_ids = torch.tensor(ctx_ids_list).unsqueeze(0)
ctx_len = ctx_ids.shape[-1]
if block_size is None:
block_size = ctx_len // cp_size
if anchor_block_size is None:
anchor_block_size = block_size
assert anchor_block_size <= block_size, (
f"cp_anchor_size {anchor_block_size} should be smaller than block_size {block_size}"
)
padding = 0
if ctx_len % block_size != 0:
padding = block_size - (ctx_len % block_size)
assert padding <= ctx_len, "block size is too large for context, please set it smaller"
ctx_ids = torch.cat((ctx_ids, torch.zeros_like(ctx_ids)[:, :padding]), dim=-1)
position_ids = torch.arange(0, ctx_ids.shape[-1]).unsqueeze(0)
ctx_ids_blocks = torch.tensor_split(torch.stack(ctx_ids.split(block_size, dim=-1)), cp_size)
position_ids_blocks = torch.tensor_split(
torch.stack(position_ids.split(block_size, dim=-1)), cp_size
)
if cp_rank != 0:
ctx_blocks = [ctx_ids_blocks[0][0].tolist()[0][:anchor_block_size]]
position_blocks = [position_ids_blocks[0][0].tolist()[0][:anchor_block_size]]
else:
ctx_blocks, position_blocks = [], []
for idx in range(len(ctx_ids_blocks[cp_rank])):
ctx_block = ctx_ids_blocks[cp_rank][idx]
position_block = position_ids_blocks[cp_rank][idx]
ctx_blocks.append(ctx_block.tolist()[0])
position_blocks.append(position_block.tolist()[0])
return ctx_blocks, position_blocks, padding
def partition_context_for_helix(
input_token_ids: List[int], cp_rank: int, cp_size: int, tokens_per_block: int
) -> Tuple[List[int], List[int], int, int]:
"""Partition context for Helix CP.
Args:
input_token_ids: List of input token IDs.
cp_rank: Current CP rank.
cp_size: Total number of CP ranks.
tokens_per_block: Number of tokens per block.
Returns:
Tuple of (input_ids_this_rank, position_ids_this_rank, input_len, padding_len).
Raises:
ValueError: If there aren't enough tokens for at least one block per CP rank.
"""
all_input_ids = torch.tensor(input_token_ids, dtype=torch.int64).unsqueeze(0)
input_len = all_input_ids.shape[-1]
num_total_blocks = (input_len + tokens_per_block - 1) // tokens_per_block
if num_total_blocks < cp_size:
raise ValueError(
f"There aren't enough tokens to get at least one block per CP rank. "
f"num_total_blocks {num_total_blocks} < num_cp_ranks {cp_size}. "
f"Please use smaller tokens_per_block for KV cache or reduce the number of CP ranks."
)
# Pad to a multiple of tokens_per_block so the equal-sized splits can be stacked before torch.tensor_split.
padding_len = 0
if input_len % tokens_per_block != 0:
padding_len = tokens_per_block - (input_len % tokens_per_block)
padding_ids = torch.zeros([1, padding_len], dtype=torch.int64)
all_input_ids = torch.cat((all_input_ids, padding_ids), dim=-1)
all_position_ids = torch.arange(0, input_len + padding_len, dtype=torch.int64).unsqueeze(0)
input_id_blocks_per_rank = torch.tensor_split(
torch.stack(all_input_ids.split(tokens_per_block, dim=-1)), cp_size
)
position_id_blocks_per_rank = torch.tensor_split(
torch.stack(all_position_ids.split(tokens_per_block, dim=-1)), cp_size
)
# Get the input_ids and position_ids for this rank.
input_ids_this_rank = input_id_blocks_per_rank[cp_rank].flatten().tolist()
position_ids_this_rank = position_id_blocks_per_rank[cp_rank].flatten().tolist()
# Undo the padding. Only last rank's last block will be padded right now
# given contiguous block assignment.
if cp_rank == cp_size - 1 and padding_len > 0:
input_ids_this_rank = input_ids_this_rank[:-padding_len]
position_ids_this_rank = position_ids_this_rank[:-padding_len]
return input_ids_this_rank, position_ids_this_rank, input_len, padding_len
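# A minimal sketch, assuming 10 input tokens, cp_size=2 and tokens_per_block=4
# (toy values). Blocks are assigned contiguously, and only the last rank strips
# the padding added to complete the final block:
#
#     >>> partition_context_for_helix(list(range(10)), cp_rank=0, cp_size=2, tokens_per_block=4)
#     ([0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2, 3, 4, 5, 6, 7], 10, 2)
#     >>> partition_context_for_helix(list(range(10)), cp_rank=1, cp_size=2, tokens_per_block=4)
#     ([8, 9], [8, 9], 10, 2)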
def merge_requests_to_llm_requests(
new_requests: List, exclude_last_generation_logits: bool
) -> List[LlmRequest]:
"""Merge RequestQueueItems to LlmRequests (basic case without CP).
Args:
new_requests: List of RequestQueueItem objects.
exclude_last_generation_logits: Whether to exclude last generation logits.
Returns:
List of LlmRequest objects including child requests.
"""
req_with_children = []
for req_item in new_requests:
req = executor_request_to_llm_request(
req_item.id, req_item.request, req_item.child_req_ids, exclude_last_generation_logits
)
req_with_children.append(req)
if req.child_requests:
req_with_children.extend(req.child_requests)
return req_with_children
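# Note: the returned list is flat; each parent LlmRequest is immediately followed by
# its child requests (if any), so callers can iterate it without walking the
# parent/child hierarchy themselves.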
def merge_helix_requests(
new_requests: List,
cp_rank: int,
cp_size: int,
tokens_per_block: int,
exclude_last_generation_logits: bool,
) -> List[LlmRequest]:
"""Merge requests for Helix CP.
Note: Helix parallelism is a decode-only feature run with disaggregated serving.
This function gets called on gen server during initialization of a new request.
Args:
new_requests: List of RequestQueueItem objects.
cp_rank: Current CP rank.
cp_size: Total number of CP ranks.
tokens_per_block: Number of tokens per block.
exclude_last_generation_logits: Whether to exclude last generation logits.
Returns:
List of LlmRequest objects including child requests.
"""
req_with_children = []
for req_item in new_requests:
input_ids_this_rank, position_ids_this_rank, input_len, _ = partition_context_for_helix(
req_item.request.input_token_ids, cp_rank, cp_size, tokens_per_block
)
req = executor_request_to_llm_request(
req_id=req_item.id,
executor_request=req_item.request,
child_req_ids=req_item.child_req_ids,
exclude_last_generation_logits=exclude_last_generation_logits,
input_token_ids=input_ids_this_rank,
position_ids=position_ids_this_rank,
)
req.total_input_len_cp = input_len
req.seqlen_this_rank_cp = len(input_ids_this_rank)
req_with_children.append(req)
if req.child_requests:
req_with_children.extend(req.child_requests)
return req_with_children
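# For the helix partition sketch shown after partition_context_for_helix (10 tokens,
# cp_size=2, tokens_per_block=4), every rank records total_input_len_cp = 10, while
# seqlen_this_rank_cp is 8 on rank 0 and 2 on rank 1.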
def merge_star_attention_requests(
new_requests: List,
cp_rank: int,
cp_size: int,
cp_config: dict,
exclude_last_generation_logits: bool,
) -> List[LlmRequest]:
"""Merge requests for Star Attention CP.
Args:
new_requests: List of RequestQueueItem objects.
cp_rank: Current CP rank.
cp_size: Total number of CP ranks.
cp_config: CP configuration dict containing 'block_size' and 'cp_anchor_size'.
exclude_last_generation_logits: Whether to exclude last generation logits.
Returns:
List of LlmRequest objects.
"""
result = []
block_size = cp_config["block_size"]
anchor_block_size = cp_config["cp_anchor_size"]
for req_item in new_requests:
req_id, exe_req, query_token_ids = req_item.id, req_item.request, req_item.query
ctx_len0 = len(exe_req.input_token_ids)
ctx_blocks, position_blocks, last_block_padding_num = partition_context_for_star_attention(
exe_req.input_token_ids, cp_rank, cp_size, block_size, anchor_block_size
)
if cp_rank == cp_size - 1 and last_block_padding_num > 0:
ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
position_blocks[-1] = position_blocks[-1][:-last_block_padding_num]
# Append the query tokens as an extra block, positioned right after the original context.
if query_token_ids:
ctx_blocks.append(query_token_ids)
position_blocks.append(list(range(ctx_len0, ctx_len0 + len(query_token_ids))))
# Insert a dummy block so every rank runs the same number of context iterations.
total_blocks = (ctx_len0 + block_size - 1) // block_size
num_blocks_per_rank = (total_blocks + cp_size - 1) // cp_size + 1 # 1 for query block
if len(ctx_blocks) == num_blocks_per_rank:
ctx_blocks.insert(1, [])
position_blocks.insert(1, [])
elif len(ctx_blocks) == num_blocks_per_rank + 1:
# Already aligned: anchor + context blocks + query block.
pass
else:
raise ValueError(
f"Invalid context partition: rank = {cp_rank}, "
f"len(ctx_blocks) = {len(ctx_blocks)}, "
f"num_blocks_per_rank = {num_blocks_per_rank}"
)
# Placeholder token ids sized for the scheduler; the real context blocks are attached below.
ctx_blocks_list = [0] * (block_size + anchor_block_size)
req = executor_request_to_llm_request(
req_id=req_id,
executor_request=exe_req,
child_req_ids=req_item.child_req_ids,
exclude_last_generation_logits=exclude_last_generation_logits,
input_token_ids=ctx_blocks_list,
)
req.gen_iters = 0
req.ctx_iters = 0
req.ctx_blocks = ctx_blocks
req.ctx_position_blocks = position_blocks
req.query_id = query_token_ids
result.append(req)
return result
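# Worked layout sketch (toy values: ctx_len0=8, block_size=4, anchor_block_size=2,
# cp_size=2, non-empty query):
#   total_blocks = ceil(8 / 4) = 2, num_blocks_per_rank = ceil(2 / 2) + 1 = 2
#   rank 0: [ctx_block0, query]         -> 2 == num_blocks_per_rank, so a dummy block
#           is inserted: [ctx_block0, [], query]
#   rank 1: [anchor, ctx_block1, query] -> 3 == num_blocks_per_rank + 1, kept as-is
# Every rank therefore runs the same number of context iterations.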
@nvtx_range("merge_requests")
def merge_requests(
new_requests: List,
cp_config: dict,
cp_rank: int,
cp_size: int,
exclude_last_generation_logits: bool,
) -> List[LlmRequest]:
"""Merge RequestQueueItems to LlmRequests based on CP configuration.
This is a router function that dispatches to the appropriate merge function
based on the CP (Context Parallelism) configuration.
Args:
new_requests: List of RequestQueueItem objects.
cp_config: CP configuration dict. May contain 'cp_type', 'tokens_per_block',
'block_size', 'cp_anchor_size'.
cp_rank: Current CP rank.
cp_size: Total number of CP ranks.
exclude_last_generation_logits: Whether to exclude last generation logits.
Returns:
List of LlmRequest objects.
Raises:
NotImplementedError: If cp_type is not supported.
"""
if "cp_type" in cp_config:
cp_type = cp_config["cp_type"]
if cp_type == CpType.STAR:
return merge_star_attention_requests(
new_requests,
cp_rank=cp_rank,
cp_size=cp_size,
cp_config=cp_config,
exclude_last_generation_logits=exclude_last_generation_logits,
)
elif cp_type == CpType.HELIX:
return merge_helix_requests(
new_requests,
cp_rank=cp_rank,
cp_size=cp_size,
tokens_per_block=cp_config["tokens_per_block"],
exclude_last_generation_logits=exclude_last_generation_logits,
)
else:
raise NotImplementedError(f"Unsupported cp type {cp_type.name}.")
return merge_requests_to_llm_requests(new_requests, exclude_last_generation_logits)
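# Dispatch summary for reference:
#   no "cp_type" in cp_config              -> merge_requests_to_llm_requests
#   cp_config["cp_type"] == CpType.STAR    -> merge_star_attention_requests
#   cp_config["cp_type"] == CpType.HELIX   -> merge_helix_requests
#   any other cp_type                      -> NotImplementedError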
class RequestBroadcaster:
"""Broadcasts requests across distributed ranks (TP, PP, CP)."""
def __init__(self, dist: Distributed, hang_detector: HangDetector):
self.dist = dist
self.hang_detector = hang_detector
self.send_requests_handler = None
def broadcast(self, new_requests: List) -> Tuple[List, Optional[Tuple]]:
"""Broadcast requests and Python objects across ranks."""
if self.dist.rank == 0:
py_request_objects = self._collect_py_objects(new_requests)
else:
py_request_objects = None
if self.dist.rank == 0:
# Preserve original `new_requests` on rank 0
_ = self._broadcast_requests(new_requests, py_request_objects)
else:
with self.hang_detector.pause():
new_requests, py_request_objects = self._broadcast_requests(
new_requests, py_request_objects
)
return new_requests, py_request_objects
def _collect_py_objects(self, new_requests: List) -> Tuple:
"""Collect Python-only objects from requests."""
py_logits_post_processors = collect_py_objects_from_requests(
new_requests, "py_logits_post_processors"
)
py_multimodal_data = collect_py_objects_from_requests(new_requests, "py_multimodal_data")
py_scheduling_params = collect_py_objects_from_requests(
new_requests, "py_scheduling_params"
)
py_num_logprobs = collect_py_objects_from_requests(new_requests, "py_num_logprobs")
py_disaggregated_params = collect_py_objects_from_requests(
new_requests, "py_disaggregated_params"
)
return tuple(
filter(
None,
[
py_logits_post_processors,
py_multimodal_data,
py_scheduling_params,
py_num_logprobs,
py_disaggregated_params,
],
)
)
@nvtx_range("broadcast_requests")
def _broadcast_requests(
self, new_requests: List, py_request_objects
) -> Tuple[List, Optional[Tuple]]:
"""Broadcast requests across pipeline stages."""
payloads = (new_requests, py_request_objects)
if not self.dist.has_pp:
return self.dist.broadcast(payloads, root=0)
# Broadcast within first PP stage before send/recv chain to other PP stages.
# This needs to cover both TP and CP ranks within the first PP stage.
if self.dist.is_first_pp_rank:
payloads = self.dist.tp_cp_broadcast(payloads, root=0)
# Tag for communication
tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts
# Send payloads
if not self.dist.is_first_pp_rank:
with nvtx_range("recv_requests_from_prev_pp"):
payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag)
# Sending new requests with isend can deadlock when CUDA_LAUNCH_BLOCKING=1
# or when PP microbatches cannot overlap:
# 1. rank1 waits on nccl.send(rank2) without calling mpi.wait() on the isend handle.
# 2. rank2 waits on mpi.recv(rank1) and never receives the new requests.
# 3. rank1 hangs on nccl.send because rank2 never reaches nccl.recv(rank1).
pp_send_func = (
self.dist.isend_object
if os.environ.get("TRTLLM_PP_REQ_SEND_ASYNC", "0") == "1"
else self.dist.send_object
)
if not self.dist.is_last_pp_rank:
if self.send_requests_handler is not None:
with nvtx_range("wait_prev_send_requests_handler"):
self.send_requests_handler.wait()
with nvtx_range("send_requests_to_next_pp"):
self.send_requests_handler = pp_send_func(payloads, self.dist.next_pp_rank, tag)
return payloads
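# Sketch of opting into async PP forwarding (the env var is read on every call;
# the default "0" keeps the blocking send_object path described above):
#
#     import os
#     os.environ["TRTLLM_PP_REQ_SEND_ASYNC"] = "1"  # use isend_object for PP sends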

File diff suppressed because it is too large

View File

@ -0,0 +1,182 @@
"""Tests for PyExecutor request handling functionality.
This module tests the request handling logic that was moved from ExecutorRequestQueue
to PyExecutor, including:
- _handle_special_queue_items method
- canceled_req_ids management
- waiting_queue management
- is_shutdown state management
- expected_num_active_requests tracking
"""
from collections import deque
from unittest.mock import Mock
import pytest
from tensorrt_llm._torch.pyexecutor.executor_request_queue import (
SHUTDOWN_REQUEST_ID,
RequestQueueItem,
)
class MockPyExecutor:
"""A mock PyExecutor class for testing request handling logic.
This mock contains only the attributes and methods needed to test
the _handle_special_queue_items functionality.
"""
def __init__(self, dist):
self.dist = dist
self.canceled_req_ids = []
self.control_requests = []
self.request_accumulated = []
self.is_shutdown = False
self.expected_num_active_requests = 0
self.new_active_requests_queue_latency_ms = 0.0
self.waiting_queue = deque()
def _handle_special_queue_items(self, new_requests):
"""Handle special signals.
This method mirrors PyExecutor._handle_special_queue_items.
"""
accepted_new_requests = []
for idx, req_item in enumerate(new_requests):
if req_item.is_shutdown_request:
self.is_shutdown = True
break
elif req_item.is_canceled_request:
self.canceled_req_ids.append(req_item.id)
elif req_item.is_control_request:
self.control_requests.append(req_item)
if self.dist.rank == 0:
self.request_accumulated.extend(new_requests[idx + 1 :])
break
else:
accepted_new_requests.append(req_item)
return accepted_new_requests
def update_waiting_queue(self):
"""Update waiting queue to remove canceled requests.
This method mirrors PyExecutor.update_waiting_queue.
"""
if self.canceled_req_ids:
canceled_set = set(self.canceled_req_ids)
self.waiting_queue = deque(
item for item in self.waiting_queue if item.id not in canceled_set
)
def clear_canceled_req_ids(self):
"""Clear the list of canceled request IDs."""
self.canceled_req_ids.clear()
def get_canceled_req_ids(self):
"""Get the list of canceled request IDs."""
return self.canceled_req_ids
def get_canceled_req_ids_size(self):
"""Get the number of canceled request IDs."""
return len(self.canceled_req_ids)
def get_expected_num_active_requests(self):
"""Get the expected number of active requests."""
return self.expected_num_active_requests
def get_waiting_queue_size(self):
"""Get the size of the waiting queue."""
return len(self.waiting_queue)
def _get_new_active_requests_queue_latency(self):
"""Get the queue latency for new active requests."""
return self.new_active_requests_queue_latency_ms
@pytest.fixture
def mock_dist():
"""Create a mock Distributed instance for testing."""
mock_dist = Mock()
mock_dist.rank = 0
mock_dist.tp_size = 1
return mock_dist
@pytest.fixture
def mock_executor(mock_dist):
"""Create a MockPyExecutor instance for testing."""
return MockPyExecutor(dist=mock_dist)
def test_handle_special_queue_items(mock_executor):
"""Test special queue item handling."""
# Create a mock request. Mock auto-creates attributes, so delete sampling_config
# below so code probing it via hasattr() treats it as absent.
mock_request = Mock()
if hasattr(mock_request, "sampling_config"):
delattr(mock_request, "sampling_config")
normal_req = RequestQueueItem(1, mock_request)
cancel_req = RequestQueueItem(2, is_canceled_request=True)
shutdown_req = RequestQueueItem(SHUTDOWN_REQUEST_ID)
requests = [normal_req, cancel_req, shutdown_req]
valid_requests = mock_executor._handle_special_queue_items(requests)
assert len(valid_requests) == 1
assert valid_requests[0] == normal_req
assert mock_executor.is_shutdown
assert 2 in mock_executor.canceled_req_ids
def test_clear_canceled_req_ids(mock_executor):
"""Test clearing canceled request IDs."""
mock_executor.canceled_req_ids = [1, 2, 3]
assert len(mock_executor.canceled_req_ids) == 3
mock_executor.clear_canceled_req_ids()
assert len(mock_executor.canceled_req_ids) == 0
def test_update_waiting_queue(mock_executor):
"""Test updating waiting queue to remove canceled requests."""
items = [
RequestQueueItem(1, Mock()),
RequestQueueItem(2, Mock()),
RequestQueueItem(3, Mock()),
]
mock_executor.waiting_queue.extend(items)
mock_executor.canceled_req_ids = [2]
mock_executor.update_waiting_queue()
assert len(mock_executor.waiting_queue) == 2
remaining_ids = [item.id for item in mock_executor.waiting_queue]
assert 1 in remaining_ids
assert 3 in remaining_ids
assert 2 not in remaining_ids
def test_getter_methods(mock_executor):
"""Test various getter methods."""
# Test initial values
assert mock_executor._get_new_active_requests_queue_latency() == 0
assert mock_executor.get_expected_num_active_requests() == 0
assert mock_executor.get_canceled_req_ids_size() == 0
assert mock_executor.get_canceled_req_ids() == []
assert mock_executor.get_waiting_queue_size() == 0
# Add some data and test
mock_executor.canceled_req_ids = [3, 4]
mock_executor.expected_num_active_requests = 5
mock_executor.new_active_requests_queue_latency_ms = 10.5
mock_executor.waiting_queue.append(RequestQueueItem(1, Mock()))
assert mock_executor.get_canceled_req_ids_size() == 2
assert mock_executor.get_canceled_req_ids() == [3, 4]
assert mock_executor.get_expected_num_active_requests() == 5
assert mock_executor._get_new_active_requests_queue_latency() == 10.5
assert mock_executor.get_waiting_queue_size() == 1

File diff suppressed because it is too large