diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index dd8633411c..f7286f2a2c 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -1,24 +1,16 @@ import dataclasses import datetime -import heapq -import os import queue import threading import time -from collections import deque, namedtuple from itertools import repeat -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple -import torch - -from tensorrt_llm._utils import mpi_disabled, nvtx_range from tensorrt_llm.llmapi.disagg_utils import get_local_request_id -from tensorrt_llm.mapping import CpType from ..distributed import Distributed -from .hang_detector import HangDetector -from .llm_request import (ExecutorRequest, LlmRequest, - executor_request_to_llm_request) +from .llm_request import ExecutorRequest +from .request_utils import get_num_child_requests SHUTDOWN_REQUEST_ID = -1 CONTROL_REQUEST_ID = -2 @@ -48,165 +40,24 @@ class RequestQueueItem: class ExecutorRequestQueue: - """Handles fetching and processing of new requests from the request queue.""" + """Handles basic queue operations for executor requests.""" def __init__( self, dist: Distributed, - enable_attention_dp: bool, max_batch_size: int, - max_beam_width: int, - max_num_active_requests: int, enable_iter_perf_stats: bool, batch_wait_timeout_ms: float, - hang_detector: Optional[HangDetector] = None, ): self.dist = dist self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue() - self.waiting_queue: deque[RequestQueueItem] = deque() - self.canceled_req_ids = [] - self.enable_attention_dp = enable_attention_dp self.max_batch_size = max_batch_size - self.max_beam_width = max_beam_width - self.max_num_active_requests = max_num_active_requests self.enqueue_lock = threading.Lock() self.next_request_id = max_batch_size self.enable_iter_perf_stats = enable_iter_perf_stats self.start_times = {} self.active = True self.batch_wait_timeout_ms = batch_wait_timeout_ms - self.send_requests_handler = None - self.hang_detector = hang_detector or HangDetector() - - # State tracking - self.num_fetch_requests = 0 - self.num_fetch_requests_cur_rank = 0 - self.expected_num_active_requests = 0 - self.new_active_requests_queue_latency_ms = 0 - self.is_shutdown = False - self.should_exclude_last_generation_logits = False - self.control_requests: List[RequestQueueItem] = [] - self.request_accumulated: List[RequestQueueItem] = [] - - self._disable_mpi = mpi_disabled() - - def _get_from_request_queue( - self, - timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]: - - items = [] - timeout_secs = timeout.total_seconds() if timeout is not None else None - - try: - if self.request_queue.empty() and (timeout_secs is None - or timeout_secs > 0): - # if queue is empty and want to wait, wait - items.append(self.request_queue.get(timeout=timeout_secs)) - else: - # if not empty or don't want to wait, just return all items in queue - while True: - queue_item = self.request_queue.get_nowait() - items.append(queue_item) - except queue.Empty: - pass - - if self.batch_wait_timeout_ms == 0: - return items - - if len(items) >= self.max_batch_size: - return items - - deadline = time.monotonic() + self.batch_wait_timeout_ms / 1000.0 - while len(items) < self.max_batch_size: - remaining_timeout = deadline - time.monotonic() - - if remaining_timeout <= 0: - break - - try: - item = 
self.request_queue.get(timeout=remaining_timeout) - items.append(item) - except queue.Empty: - break - - return items - - @staticmethod - def _get_num_child_requests(request: ExecutorRequest) -> int: - sampling_config = request.sampling_config - return 0 if sampling_config.beam_width > 1 else ( - sampling_config.num_return_sequences or 1) - 1 - - def _get_from_waiting_queue( - self, - waiting_queue: deque[RequestQueueItem], - max_req_count: int, - enable_attention_dp: bool, - all_ranks_num_active_requests: Optional[List[int]] = None, - ) -> List[RequestQueueItem]: - """ - Args: - waiting_queue: The queue to pop items from. - max_req_count: Maximum items to retrieve. Returns empty list if <=0. - enable_attention_dp: Whether to enable attention DP scheduling. - all_ranks_num_active_requests: Number of active requests for each rank. - Returns: - List of requests that can be processed. - """ - - if max_req_count <= 0: - return [] - - req_count = 0 - items = [] - pending_requests = [] - - # Track the request with strict requirements - scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy( - ) if enable_attention_dp else None - while req_count < max_req_count and waiting_queue: - req_item = waiting_queue[0] - num_children = len( - req_item.child_req_ids) if req_item.child_req_ids else 0 - if (req_count + 1 + num_children) > max_req_count: - break - req_item = waiting_queue.popleft() - can_process = self._can_process_attention_dp_request( - req_item, scheduling_all_ranks_num_active_requests - ) if enable_attention_dp else True - - if can_process: - items.append(req_item) - req_count += 1 + num_children - else: - pending_requests.append(req_item) - - # Put the pending requests back to the waiting queue - # All ranks should have the same waiting queue - waiting_queue.extendleft(reversed(pending_requests)) - - return items - - def _can_process_attention_dp_request( - self, req_item: RequestQueueItem, - all_ranks_num_active_requests: List[int]) -> bool: - """Return True if the request can be processed immediately, else False.""" - - scheduling_params = getattr(req_item.request, 'py_scheduling_params', - None) - if scheduling_params is None: - return True - - target_dp_rank = scheduling_params.attention_dp_rank - if target_dp_rank is None or scheduling_params.attention_dp_relax: - return True - - if all_ranks_num_active_requests[ - target_dp_rank] < self.max_num_active_requests: - all_ranks_num_active_requests[target_dp_rank] += 1 - return True - - return False def _get_request_id(self, request: Optional[ExecutorRequest] = None): # if request has a disagg_request_id, use it as request id so that @@ -223,7 +74,7 @@ class ExecutorRequestQueue: self, request: ExecutorRequest) -> List[int] | None: """ Generate child request IDs if needed. """ child_req_ids = None - num_children = self._get_num_child_requests(request) + num_children = get_num_child_requests(request) if num_children > 0: child_req_ids = [] for _ in range(num_children): @@ -288,602 +139,53 @@ class ExecutorRequestQueue: with self.enqueue_lock: return self.active and self.dist.rank == 0 - def _fetch_and_process_requests( - self, - total_num_active_requests: int, - total_max_num_active_requests: int, - enable_attention_dp: bool, - all_ranks_num_active_requests: Optional[List[int]] = None - ) -> List[RequestQueueItem]: - """Common logic for fetching and processing requests from the queue.""" - # Block new request processing while control requests are pending. 
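Note: the child-request accounting used by _create_child_req_ids above is easy to check in isolation. A minimal sketch, assuming a hypothetical stand-in for the SamplingConfig binding (the real object comes from the executor bindings):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeSamplingConfig:  # hypothetical stand-in for the SamplingConfig binding
        beam_width: int = 1
        num_return_sequences: Optional[int] = None

    def num_children(cfg: FakeSamplingConfig) -> int:
        # Mirrors get_num_child_requests: beam search keeps all sequences inside one
        # request, so only sampling with num_return_sequences > 1 spawns children.
        return 0 if cfg.beam_width > 1 else (cfg.num_return_sequences or 1) - 1

    assert num_children(FakeSamplingConfig(num_return_sequences=3)) == 2
    assert num_children(FakeSamplingConfig(beam_width=4, num_return_sequences=3)) == 0
    assert num_children(FakeSamplingConfig()) == 0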
- # Control requests must be handled exclusively to ensure proper synchronization. - if len(self.control_requests) != 0: - return [] - # Calculate timeout - idle = (total_num_active_requests == 0) and len(self.waiting_queue) == 0 - if idle: - # In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0 - # reaches the broadcast path regularly to prevent trtllm-serve timeout when idle. - timeout = datetime.timedelta( - seconds=1200) if self._disable_mpi else None - else: - timeout = datetime.timedelta(0) - - # Fetch requests from rank 0 - new_requests = [] - if self.dist.rank == 0: - # Process accumulated requests that were queued during control request handling. - if len(self.request_accumulated) != 0: - new_requests.extend(self.request_accumulated) - self.request_accumulated.clear() - # Reset timeout to 0 to avoid hanging when no new requests are available - timeout = datetime.timedelta(0) - with self.hang_detector.pause(): - new_requests.extend(self._get_from_request_queue(timeout)) - - # Broadcast requests and handle Python objects - new_requests, py_request_objects = self._handle_request_broadcasting( - new_requests) - - # Validate and filter requests - new_requests = self._handle_special_queue_items(new_requests) - - # Attach Python objects to requests - if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp - or self.dist.cp_size - > 1) and self.dist.rank > 0: - self._attach_py_objects_to_requests(new_requests, - py_request_objects) - self.waiting_queue.extend(new_requests) - - new_requests = self._get_from_waiting_queue( - self.waiting_queue, - total_max_num_active_requests - total_num_active_requests, - enable_attention_dp, all_ranks_num_active_requests) - - # Update performance metrics - if self.enable_iter_perf_stats and self.dist.rank == 0: - self._update_new_active_requests_queue_latency(new_requests) - - return new_requests - - @nvtx_range("_fetch_new_requests") - def fetch_new_requests( - self, activate_requests: List[LlmRequest]) -> List[LlmRequest]: - - if self.enable_attention_dp: - return self._fetch_new_requests_attention_dp(activate_requests) - else: - return self._fetch_new_requests_attention_tp(len(activate_requests)) - - def _fetch_new_requests_attention_tp( - self, num_active_requests: int) -> List[LlmRequest]: - """Handle standard (non-attention DP) request fetching.""" - total_num_active_requests = num_active_requests - total_max_num_active_requests = self.max_num_active_requests - - # fetch and process requests into waiting queue - new_requests = self._fetch_and_process_requests( - total_num_active_requests, - total_max_num_active_requests, - enable_attention_dp=False) - - # Merge requests and add to active list - merged_requests = self._merge_requests(new_requests) - return merged_requests - - def _fetch_new_requests_attention_dp( - self, activate_requests: List[LlmRequest]) -> List[LlmRequest]: - """Handle attention DP request fetching with load balancing.""" - # Get active request counts across all ranks. - all_ranks_num_active_requests = [] - all_ranks_num_active_tokens = [] - - if self.dist.has_cp_helix: - num_active_tokens = sum( - [req.total_input_len_cp for req in activate_requests]) - else: - num_active_tokens = sum( - [req.py_orig_prompt_len for req in activate_requests]) - - # Note: We use tp_allgather even for CP assuming that all CP ranks with the - # same dp_rank have the same num_active_tokens and num_active_requests. 
- responses_list = self.dist.tp_allgather( - [len(activate_requests), num_active_tokens]) - - for num_active_requests, num_active_tokens in responses_list: - all_ranks_num_active_requests.append(num_active_requests) - all_ranks_num_active_tokens.append(num_active_tokens) - - total_num_active_requests = sum(all_ranks_num_active_requests) - total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests - - # fetch and process requests into waiting queue - new_requests = self._fetch_and_process_requests( - total_num_active_requests, - total_max_num_active_requests, - enable_attention_dp=True, - all_ranks_num_active_requests=all_ranks_num_active_requests) - - # Schedule attention dp requests - all_ranks_new_requests = self._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank] - - # Update performance metrics - if self.enable_iter_perf_stats and self.start_times: - self._update_new_active_requests_queue_latency( - new_requests_cur_rank) - - # Update counters - self.num_fetch_requests += len(new_requests) - self.num_fetch_requests_cur_rank += len(new_requests_cur_rank) - - # Merge requests and add to active list - new_requests_cur_rank = self._merge_requests(new_requests_cur_rank) - return new_requests_cur_rank - - def _schedule_attention_dp_requests( - self, new_requests: List[RequestQueueItem], - all_ranks_num_active_requests: List[int], - all_ranks_num_active_tokens: List[int]) -> List[RequestQueueItem]: - """Schedule attention dp requests.""" - - # Map from ranks to new requests - all_ranks_new_requests = { - tp_rank: [] - for tp_rank in range(self.dist.tp_size) - } - - # Prioritize the requests that are not in relax mode - def get_relax_value(req_item): - scheduling_params = getattr(req_item.request, - 'py_scheduling_params', None) - if scheduling_params is None: - return True - return scheduling_params.attention_dp_relax - - new_requests = sorted(new_requests, key=get_relax_value) - - # Try to put the requests to the target dp rank until the max_num_active_requests is reached - remaining_unscheduled = [] - for req_item in new_requests: - scheduled = False - scheduling_params = getattr(req_item.request, - 'py_scheduling_params', None) - if scheduling_params is not None: - target_dp_rank = scheduling_params.attention_dp_rank - if target_dp_rank is not None and all_ranks_num_active_requests[ - target_dp_rank] < self.max_num_active_requests: - all_ranks_num_active_requests[target_dp_rank] += 1 - scheduled = True - all_ranks_new_requests[target_dp_rank].append(req_item) - - if not scheduled: - remaining_unscheduled.append(req_item) - - # Balance the remaining unscheduled requests across ranks - num_new_requests_all_ranks = len(remaining_unscheduled) - total_num_active_requests = sum(all_ranks_num_active_requests) - self.expected_num_active_requests = max( - (total_num_active_requests + num_new_requests_all_ranks + - self.dist.tp_size - 1) // self.dist.tp_size, - max(all_ranks_num_active_requests), - ) - - all_ranks_new_requests = self._balance_requests_across_ranks( - remaining_unscheduled, all_ranks_new_requests, - all_ranks_num_active_requests, all_ranks_num_active_tokens) - - return all_ranks_new_requests - - def _handle_request_broadcasting(self, - new_requests: List[RequestQueueItem]): - """Handle broadcasting of requests and Python objects across ranks.""" - if self.dist.rank == 0: - py_logits_post_processors = 
self._collect_py_objects_from_requests( - new_requests, "py_logits_post_processors") - py_multimodal_data = self._collect_py_objects_from_requests( - new_requests, "py_multimodal_data") - py_scheduling_params = self._collect_py_objects_from_requests( - new_requests, "py_scheduling_params") - py_num_logprobs = self._collect_py_objects_from_requests( - new_requests, "py_num_logprobs") - py_disaggregated_params = self._collect_py_objects_from_requests( - new_requests, "py_disaggregated_params") - py_request_objects = tuple( - filter(None, [ - py_logits_post_processors, py_multimodal_data, - py_scheduling_params, py_num_logprobs, - py_disaggregated_params - ])) - else: - py_request_objects = None - - if self.dist.rank == 0: - # Preserve original `new_requests` on rank 0 - _ = self._broadcast_new_requests(new_requests, py_request_objects) - else: - with self.hang_detector.pause(): - new_requests, py_request_objects = self._broadcast_new_requests( - new_requests, py_request_objects) - - return new_requests, py_request_objects - - def _handle_special_queue_items( + def get_from_request_queue( self, - new_requests: List[RequestQueueItem]) -> List[RequestQueueItem]: - """Handle special signals.""" - accepted_new_requests = [] - for idx, req_item in enumerate(new_requests): - if req_item.is_shutdown_request: - self.is_shutdown = True + timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]: + """Fetch requests from the queue with optional timeout. + + Args: + timeout: Optional timeout for waiting on queue. + + Returns: + List of RequestQueueItem fetched from the queue. + """ + items = [] + timeout_secs = timeout.total_seconds() if timeout is not None else None + + try: + if self.request_queue.empty() and (timeout_secs is None + or timeout_secs > 0): + # if queue is empty and want to wait, wait + items.append(self.request_queue.get(timeout=timeout_secs)) + else: + # if not empty or don't want to wait, just return all items in queue + while True: + queue_item = self.request_queue.get_nowait() + items.append(queue_item) + except queue.Empty: + pass + + if self.batch_wait_timeout_ms == 0: + return items + + if len(items) >= self.max_batch_size: + return items + + deadline = time.monotonic() + self.batch_wait_timeout_ms / 1000.0 + while len(items) < self.max_batch_size: + remaining_timeout = deadline - time.monotonic() + + if remaining_timeout <= 0: break - elif req_item.is_canceled_request: - self.canceled_req_ids.append(req_item.id) - elif req_item.is_control_request: - self.control_requests.append(req_item) - if self.dist.rank == 0: - self.request_accumulated.extend(new_requests[idx + 1:]) + + try: + item = self.request_queue.get(timeout=remaining_timeout) + items.append(item) + except queue.Empty: break - else: - accepted_new_requests.append(req_item) - return accepted_new_requests - - def _balance_requests_across_ranks( - self, new_requests: List[RequestQueueItem], - all_ranks_new_requests: Dict[int, List[RequestQueueItem]], - all_ranks_num_active_requests: List[int], - all_ranks_num_active_tokens: List[int]) -> List[RequestQueueItem]: - """Balance requests across ranks for attention DP.""" - if new_requests: - # Balance context tokens across ranks using heap - HeapVal = namedtuple( - 'HeapVal', - ['num_tokens', 'num_requests', 'rank', 'request_list']) - - all_ranks_new_requests_heap = [ - HeapVal(all_ranks_num_active_tokens[tp_rank], val, tp_rank, []) - for tp_rank, val in enumerate(all_ranks_num_active_requests) - ] - - all_ranks_new_requests_heap = [ - val for val in 
all_ranks_new_requests_heap - if val.num_requests < self.expected_num_active_requests - ] - - all_ranks_new_scheduled_requests = { - val.rank: val.request_list - for val in all_ranks_new_requests_heap - } - - heapq.heapify(all_ranks_new_requests_heap) - - # Sort by token count (descending) for better load balancing - new_requests = sorted( - new_requests, - key=lambda x: len(getattr(x.request, 'input_token_ids', [])) - if x.request else 0, - reverse=True) - - # Distribute requests across ranks - for req_item in new_requests: - - val = heapq.heappop(all_ranks_new_requests_heap) - token_count = len( - getattr(req_item.request, 'input_token_ids', - [])) if req_item.request else 0 - # Update the heap value with the new request - val = val._replace( - num_tokens=val.num_tokens + token_count, - num_requests=val.num_requests + 1, - ) - - val.request_list.append(req_item) - # If rank still has room for new requests, push back into heap - if val.num_requests < self.expected_num_active_requests: - heapq.heappush(all_ranks_new_requests_heap, val) - - # Extend all_ranks_new_requests with the new requests that have been scheduled - for rank, reqs in all_ranks_new_scheduled_requests.items(): - all_ranks_new_requests[rank].extend(reqs) - - return all_ranks_new_requests - - def _collect_py_objects_from_requests( - self, requests: List[RequestQueueItem], - attribute_name: str) -> Optional[Tuple[str, Dict]]: - """Collect Python-only objects from requests.""" - req_id_to_obj = {} - for item in requests: - if not item.is_normal_request: - continue - if item.request: - obj = getattr(item.request, attribute_name, None) - if obj is not None: - req_id_to_obj[item.id] = obj - return None if not req_id_to_obj else (attribute_name, req_id_to_obj) - - @nvtx_range("_broadcast_new_requests") - def _broadcast_new_requests( - self, new_requests: List[RequestQueueItem], py_request_objects - ) -> Tuple[List[RequestQueueItem], Optional[Dict]]: - """Broadcast new_requests and optional Python-only metadata across pipeline stages.""" - payloads = (new_requests, py_request_objects) - - if not self.dist.has_pp: - return self.dist.broadcast(payloads, root=0) - - # Broadcast within first PP stage before send/recv chain to other PP stages. - # This needs to cover both TP and CP ranks within the first PP stage. - if self.dist.is_first_pp_rank: - payloads = self.dist.tp_cp_broadcast(payloads, root=0) - - # Tag for communication - tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts - - # Send payloads - if not self.dist.is_first_pp_rank: - with nvtx_range("recv_requests_from_prev_pp"): - payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag) - - # isend new requests may cause deadlock, when CUDA_LAUNCH_BLOCKING=1 or PP microbatches can't overlap, - # the deadlock will happen deterministicly: - # 1. rank1 will wait on nccl.send(rank2), without invoking mpi.wait(isend-handle) - # 2. rank2 will wait on mpi.recv(rank1) but never receive the new requests. - # 3. rank1 will hang on nccl.send because rank2 will never reach nccl.recv(rank1). 
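Note: the heap-based balancing above (and its relocated copy, request_utils.balance_requests_across_ranks) is a greedy bin-packing: ranks become heap entries keyed by active token count, incoming requests are sorted by prompt length descending, and each request goes to the currently lightest rank that still has capacity. A simplified, self-contained sketch of the same idea, using plain token counts instead of RequestQueueItem objects:

    import heapq
    from collections import namedtuple

    HeapVal = namedtuple("HeapVal", ["num_tokens", "num_requests", "rank", "items"])

    def balance(prompt_lens, active_tokens, active_requests, cap):
        # One heap entry per rank that still has room; tuple order means the rank
        # with the fewest tokens (then fewest requests) pops first.
        heap = [HeapVal(tok, req, rank, [])
                for rank, (tok, req) in enumerate(zip(active_tokens, active_requests))
                if req < cap]
        assigned = {v.rank: v.items for v in heap}  # keep list refs before popping
        heapq.heapify(heap)
        for length in sorted(prompt_lens, reverse=True):  # biggest prompts first
            val = heapq.heappop(heap)
            val.items.append(length)
            val = val._replace(num_tokens=val.num_tokens + length,
                               num_requests=val.num_requests + 1)
            if val.num_requests < cap:  # rank still has room, keep it in the heap
                heapq.heappush(heap, val)
        return assigned

    # Rank 1 is already loaded with 1000 tokens, so all three new prompts land on rank 0.
    print(balance([512, 64, 8], active_tokens=[0, 1000], active_requests=[0, 2], cap=8))
    # {0: [512, 64, 8], 1: []}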
- pp_send_func = self.dist.isend_object if os.environ.get( - "TRTLLM_PP_REQ_SEND_ASYNC", "0") == "1" else self.dist.send_object - - if not self.dist.is_last_pp_rank: - if self.send_requests_handler is not None: - with nvtx_range("wait_prev_send_requests_handler"): - self.send_requests_handler.wait() - with nvtx_range("send_requests_to_next_pp"): - self.send_requests_handler = pp_send_func( - payloads, self.dist.next_pp_rank, tag) - - return payloads - - def _attach_py_objects_to_requests(self, requests: List[RequestQueueItem], - py_request_objects) -> None: - """Attach Python-only objects to each request.""" - for attr_name, req_obj_dict in py_request_objects: - for item in requests: - if item.request: - py_obj = req_obj_dict.get(item.id) - if py_obj is not None: - setattr(item.request, attr_name, py_obj) - - def _update_new_active_requests_queue_latency( - self, new_requests: List[RequestQueueItem]): - """Update queue latency metrics for new requests.""" - now = time.time() - for req_item in new_requests: - if req_item.id in self.start_times: - self.new_active_requests_queue_latency_ms += now - self.start_times.pop( - req_item.id) - if req_item.child_req_ids: - for child_id in req_item.child_req_ids: - self.new_active_requests_queue_latency_ms += now - self.start_times.pop( - child_id) - - # Note: Helix parallelism is a decode-only feature run with disaggregated serving. This function gets called on gen server - # during initialization of a new request. - def _merge_helix_requests(self, new_requests: list[RequestQueueItem], - tokens_per_block: int): - req_with_children = [] - num_cp_ranks = self.dist.cp_size - curr_cp_rank = self.dist.cp_rank - - # For each request, partition the input_token_ids into blocks and then partition blocks across CP ranks. - # Currently, the partitioning is such that contiguous blocks are assigned to the same CP rank (as opposed - # to round-robin). - for req_item in new_requests: - all_input_ids = torch.tensor(req_item.request.input_token_ids, - dtype=torch.int64).unsqueeze(0) - input_len = all_input_ids.shape[-1] - - num_total_blocks = (input_len + tokens_per_block - - 1) // tokens_per_block - if num_total_blocks < num_cp_ranks: - raise ValueError( - f"There aren't enough tokens to get at least one block per CP rank. num_total_blocks {num_total_blocks} < num_cp_ranks {num_cp_ranks}. Please use smaller tokens_per_block for KV cache or reduce the number of CP ranks." - ) - - # Padding to ensure torch.stack used with torch.tensor_split works properly. - padding_len = 0 - if input_len % tokens_per_block != 0: - padding_len = tokens_per_block - (input_len % tokens_per_block) - padding_ids = torch.zeros([1, padding_len], dtype=torch.int64) - all_input_ids = torch.cat((all_input_ids, padding_ids), dim=-1) - all_position_ids = torch.arange(0, - input_len + padding_len, - dtype=torch.int64).unsqueeze(0) - - input_id_blocks_per_rank = torch.tensor_split( - torch.stack(all_input_ids.split(tokens_per_block, dim=-1)), - num_cp_ranks) - position_id_blocks_per_rank = torch.tensor_split( - torch.stack(all_position_ids.split(tokens_per_block, dim=-1)), - num_cp_ranks) - - # Get the input_ids and position_ids for this rank. - input_ids_this_rank = input_id_blocks_per_rank[ - curr_cp_rank].flatten().tolist() - position_ids_this_rank = position_id_blocks_per_rank[ - curr_cp_rank].flatten().tolist() - - # Undo the padding. Only last rank's last block will be padded right now - # given contiguous block assignment. 
- if curr_cp_rank == num_cp_ranks - 1 and padding_len > 0: - input_ids_this_rank = input_ids_this_rank[:-padding_len] - position_ids_this_rank = position_ids_this_rank[:-padding_len] - - req = executor_request_to_llm_request( - req_id=req_item.id, - executor_request=req_item.request, - child_req_ids=req_item.child_req_ids, - exclude_last_generation_logits=self. - _should_exclude_last_generation_logits(), - input_token_ids=input_ids_this_rank, - position_ids=position_ids_this_rank, - ) - req.total_input_len_cp = input_len - req.seqlen_this_rank_cp = len(input_ids_this_rank) - req_with_children.append(req) - if req.child_requests: - req_with_children.extend(req.child_requests) - return req_with_children - - @nvtx_range("_merge_requests") - def _merge_requests( - self, new_requests: list[RequestQueueItem]) -> List[LlmRequest]: - cp_config = self.dist.cp_config - if 'cp_type' in cp_config: - cp_type = cp_config['cp_type'] - if cp_type == CpType.STAR: - return self._merge_star_attention_requests(new_requests) - elif cp_type == CpType.HELIX: - return self._merge_helix_requests( - new_requests, - tokens_per_block=cp_config['tokens_per_block']) - else: - raise NotImplementedError( - f'Unsupported cp type {cp_type.name}.') - - req_with_children = [] - for req_item in new_requests: - req = executor_request_to_llm_request( - req_item.id, req_item.request, req_item.child_req_ids, - self._should_exclude_last_generation_logits()) - req_with_children.append(req) - if req.child_requests: - req_with_children.extend(req.child_requests) - return req_with_children - - def _merge_star_attention_requests( - self, new_requests: list[RequestQueueItem]) -> List[LlmRequest]: - result = [] - for req_item in new_requests: - req_id, exe_req, query_token_ids = req_item.id, req_item.request, req_item.query - ctx_len0 = len(exe_req.input_token_ids) - ctx_blocks, position_blocks, last_block_padding_num = [ - exe_req.input_token_ids - ], [[i for i in range(ctx_len0)]], 0 - ctx_blocks, position_blocks, last_block_padding_num = self._partition_context( - exe_req.input_token_ids) - if self.dist.cp_rank == self.dist.cp_size - 1 and last_block_padding_num > 0: - ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num] - position_blocks[-1] = position_blocks[ - -1][:-last_block_padding_num] - #if has query - if query_token_ids: - ctx_blocks.append(query_token_ids) - position_blocks.append([ - i for i in range(ctx_len0, ctx_len0 + len(query_token_ids)) - ]) - - # insert the dummy block to align the number of ctx iterations of each rank - block_size = self.dist.cp_config['block_size'] - total_blocks = (ctx_len0 + block_size - 1) // block_size - num_blocks_per_rank = ( - total_blocks + self.dist.cp_size - - 1) // self.dist.cp_size + 1 # 1 for query block - if len(ctx_blocks) == num_blocks_per_rank: - ctx_blocks.insert(1, []) - position_blocks.insert(1, []) - elif len(ctx_blocks) == num_blocks_per_rank + 1: - # anchor + ctx_blocks + qry_block - pass - else: - print( - f'rank = {self.dist.cp_rank}, len(ctx_blocks) = {len(ctx_blocks) }, num_blocks_per_rank = {num_blocks_per_rank}' - ) - assert False, f'invalid context partition' - - # fake data for scheduler - ctx_blocks_list = [0] * (block_size + - self.dist.cp_config['cp_anchor_size']) - - req = executor_request_to_llm_request( - req_id, exe_req, self._should_exclude_last_generation_logits(), - ctx_blocks_list) - req.gen_iters = 0 - req.ctx_iters = 0 - req.ctx_blocks = ctx_blocks - req.ctx_position_blocks = position_blocks - req.query_id = query_token_ids - - result.append(req) - - 
return result - - def _partition_context(self, ctx_ids_list): - ctx_ids = torch.tensor(ctx_ids_list).unsqueeze(0) - ctx_len = ctx_ids.shape[-1] - block_size = self.dist.cp_config['block_size'] - if block_size is None: - block_size = ctx_len // self.dist.cp_size - anchor_block_size = self.dist.cp_config['cp_anchor_size'] - if anchor_block_size is None: - anchor_block_size = block_size - - assert anchor_block_size <= block_size, f'cp_anchor_size {anchor_block_size} should be smaller than block_size {block_size}' - padding = 0 - if ctx_len % block_size != 0: - padding = block_size - (ctx_len % block_size) - assert padding <= ctx_len, f'block size is too large for context, please set it smaller' - ctx_ids = torch.cat( - (ctx_ids, torch.zeros_like(ctx_ids)[:, :padding]), dim=-1) - position_ids = torch.arange(0, ctx_ids.shape[-1]).unsqueeze(0) - - ctx_ids_blocks = torch.tensor_split( - torch.stack(ctx_ids.split(block_size, dim=-1)), self.dist.cp_size) - position_ids_blocks = torch.tensor_split( - torch.stack(position_ids.split(block_size, dim=-1)), - self.dist.cp_size) - if self.dist.cp_rank != 0: - ctx_blocks, position_blocks = [ - ctx_ids_blocks[0][0].tolist()[0][:anchor_block_size] - ], [position_ids_blocks[0][0].tolist()[0][:anchor_block_size]] - else: - ctx_blocks, position_blocks = [], [] - - for idx in range(len(ctx_ids_blocks[self.dist.cp_rank])): - ctx_block = ctx_ids_blocks[self.dist.cp_rank][idx] - position_block = position_ids_blocks[self.dist.cp_rank][idx] - ctx_blocks.append(ctx_block.tolist()[0]) - position_blocks.append(position_block.tolist()[0]) - return ctx_blocks, position_blocks, padding - - def set_exclude_last_generation_logits(self, - disable_overlap_scheduler: bool, - pp_size: int) -> None: - # When overlap scheduler is enabled then when starting to handle a new prompt, - # sample_async is called twice before the first call to update_requests: - # - 1st time as a context request that handles on the 1st generated token - # - 2nd time as a generation request that handles on the 2nd generated token. - # and only after these two calls the sampler's update_request method is called. - # So in a sampler that works by the expected flow of handling the logits in - # sample_async, every update_request doesn't handle the newest token, but one - # before it. Since all these calls work on the same request object, then its - # logits storage contains the logits of both the token update_requests should work - # on, and also its next token. Thus, excluding the last generation logits from any - # getter is required. 
- self.should_exclude_last_generation_logits = not disable_overlap_scheduler and pp_size == 1 - - def _should_exclude_last_generation_logits(self) -> bool: - return self.should_exclude_last_generation_logits - - def get_new_active_requests_queue_latency(self) -> float: - return self.new_active_requests_queue_latency_ms - - def get_expected_num_active_requests(self) -> int: - return self.expected_num_active_requests + return items def get_request_queue_size(self) -> int: return self.request_queue.qsize() @@ -891,22 +193,22 @@ class ExecutorRequestQueue: def get_request_queue(self) -> queue.Queue[RequestQueueItem]: return self.request_queue - def get_waiting_queue(self) -> deque[RequestQueueItem]: - return self.waiting_queue + def calculate_queue_latency(self, request_items: List[RequestQueueItem], + now: float) -> float: + if not self.enable_iter_perf_stats: + return 0.0 - def update_waiting_queue(self): - # Remove cancel request in the waiting queue - self.waiting_queue = deque(req for req in self.waiting_queue - if req.id not in self.canceled_req_ids) + total_latency = 0.0 - def get_waiting_queue_size(self) -> int: - return len(self.waiting_queue) + for req_item in request_items: + # Handle parent request + if req_item.id in self.start_times: + total_latency += now - self.start_times.pop(req_item.id) - def get_canceled_req_ids_size(self) -> int: - return len(self.canceled_req_ids) + # Handle child requests + if req_item.child_req_ids: + for child_id in req_item.child_req_ids: + if child_id in self.start_times: + total_latency += now - self.start_times.pop(child_id) - def get_canceled_req_ids(self) -> List[int]: - return self.canceled_req_ids - - def clear_canceled_req_ids(self): - self.canceled_req_ids.clear() + return total_latency diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 44e0761da6..2c0593d651 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -6,6 +6,7 @@ import pickle # nosec B403 import threading import time import traceback +from collections import deque from contextlib import contextmanager from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -22,7 +23,7 @@ except ImportError: from tensorrt_llm._torch.pyexecutor.resource_manager import ( ResourceManagerType, request_context) from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled, - nvtx_range, trace_func) + mpi_disabled, nvtx_range, trace_func) from tensorrt_llm.bindings.executor import (DisServingRequestStats, FinishReason, InflightBatchingStats, IterationStats, KvCacheStats, @@ -53,6 +54,9 @@ from .kv_cache_transceiver import KvCacheTransceiver from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, LlmResponse, get_draft_token_length) from .model_engine import ModelEngine +from .request_utils import (RequestBroadcaster, attach_py_objects_to_requests, + get_from_waiting_queue, merge_requests, + schedule_attention_dp_requests) from .resource_manager import ResourceManager from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState, SampleStateTensors, TRTLLMSampler) @@ -424,16 +428,36 @@ class PyExecutor: self._set_global_steady_clock_offset() self.executor_request_queue = ExecutorRequestQueue( dist=self.dist, - enable_attention_dp=self.enable_attention_dp, max_batch_size=max_batch_size, - max_beam_width=self.max_beam_width, - max_num_active_requests=self.max_num_active_requests, 
enable_iter_perf_stats=self.enable_iter_perf_stats, batch_wait_timeout_ms=self.batch_wait_timeout_ms, - hang_detector=self.hang_detector, ) - self.executor_request_queue.set_exclude_last_generation_logits( - self.disable_overlap_scheduler, self.dist.pp_size) + # When overlap scheduler is enabled then when starting to handle a new prompt, + # sample_async is called twice before the first call to update_requests: + # - 1st time as a context request that handles on the 1st generated token + # - 2nd time as a generation request that handles on the 2nd generated token. + # and only after these two calls the sampler's update_request method is called. + # So in a sampler that works by the expected flow of handling the logits in + # sample_async, every update_request doesn't handle the newest token, but one + # before it. Since all these calls work on the same request object, then its + # logits storage contains the logits of both the token update_requests should work + # on, and also its next token. Thus, excluding the last generation logits from any + # getter is required. + self.should_exclude_last_generation_logits = ( + not self.disable_overlap_scheduler and self.dist.pp_size == 1) + + # Request processing state (managed by executor) + self.canceled_req_ids: List[int] = [] + self.control_requests: List[RequestQueueItem] = [] + self.request_accumulated: List[RequestQueueItem] = [] + self.new_active_requests_queue_latency_ms = 0.0 + self._disable_mpi = mpi_disabled() + self.request_broadcaster = RequestBroadcaster(self.dist, + self.hang_detector) + + # Waiting queue for requests that have been fetched but not yet scheduled + self.waiting_queue: deque[RequestQueueItem] = deque() + self.control_request_barrier = threading.Event() self.control_action_done = threading.Event() @@ -698,7 +722,7 @@ class PyExecutor: @property def should_stop_processing(self): return self.is_shutdown and len(self.active_requests) == 0 and \ - self.executor_request_queue.get_waiting_queue_size() == 0 + len(self.waiting_queue) == 0 @contextmanager def _profiler(self): @@ -778,8 +802,8 @@ class PyExecutor: f"iter = {self.iter_counter}, " f"global_rank = {self.global_rank}, " f"rank = {self.dist.rank}, " - f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/" - f"{self.executor_request_queue.num_fetch_requests}, " + f"currank_total_requests = {self.num_fetch_requests_cur_rank}/" + f"{self.num_fetch_requests}, " f"host_step_time = {host_step_time}ms, " f"prev_device_step_time = {prev_device_step_time}, " f"timestamp = {formatted_timestamp}, " @@ -1143,8 +1167,7 @@ class PyExecutor: if self.enable_iter_perf_stats: iter_stats = self._get_init_iter_stats( len(new_requests), - self.executor_request_queue. - get_new_active_requests_queue_latency()) + self._get_new_active_requests_queue_latency()) self._pad_attention_dp_dummy_request() @@ -1400,8 +1423,7 @@ class PyExecutor: if self.enable_iter_perf_stats: iter_stats = self._get_init_iter_stats( len(new_requests), - self.executor_request_queue. 
- get_new_active_requests_queue_latency()) + self._get_new_active_requests_queue_latency()) self._pad_attention_dp_dummy_request() @@ -1639,14 +1661,14 @@ class PyExecutor: def _handle_control_request(self): if len(self.active_requests) == 0 and \ - self.executor_request_queue.get_waiting_queue_size() == 0 and \ - len(self.executor_request_queue.control_requests) > 0: - assert len(self.executor_request_queue.control_requests) == 1, ( + len(self.waiting_queue) == 0 and \ + len(self.control_requests) > 0: + assert len(self.control_requests) == 1, ( f"Expected exactly one control request to be processed at a time, " - f"but found {len(self.executor_request_queue.control_requests)} control requests. " + f"but found {len(self.control_requests)} control requests. " f"This may indicate a race condition or improper control request handling." ) - self.executor_request_queue.control_requests.pop(0) + self.control_requests.pop(0) self.control_request_barrier.set() self.control_action_done.wait() self.control_action_done.clear() @@ -1704,7 +1726,7 @@ class PyExecutor: # to ensure consistent batch sizes for accurate performance measurement. if not self.is_warmup and not can_forward: if self.enable_attention_dp: - local_can_forward = self.executor_request_queue.num_fetch_requests + \ + local_can_forward = self.num_fetch_requests + \ len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size all_can_forward = self.dist.tp_allgather( local_can_forward) @@ -1714,7 +1736,7 @@ class PyExecutor: else: if self.dist.rank == 0: logger.info( - f"sleep 10 seconds, num_fetched_requests: {self.executor_request_queue.num_fetch_requests}, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}" + f"sleep 10 seconds, num_fetched_requests: {self.num_fetch_requests}, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}" ) time.sleep(10) continue @@ -2026,6 +2048,169 @@ class PyExecutor: self.model_engine.model.lm_head.num_embeddings): raise ValueError("Token ID out of range") + def _fetch_and_enqueue_requests(self, + waiting_queue: deque[RequestQueueItem], + total_num_active_requests: int) -> None: + """Fetch requests from request_queue and enqueue to waiting_queue.""" + # Block new requests while control requests are pending + if len(self.control_requests) != 0: + return + + # Calculate timeout + idle = (total_num_active_requests == 0) and len(waiting_queue) == 0 + if idle: + # In Ray path (TLLM_DISABLE_MPI=1), use a periodic heartbeat timeout so rank 0 + # reaches the broadcast path regularly to prevent trtllm-serve timeout when idle. + timeout = datetime.timedelta( + seconds=1200) if self._disable_mpi else None + else: + timeout = datetime.timedelta(0) + + # Fetch requests from rank 0 + new_requests = [] + if self.dist.rank == 0: + # Process accumulated requests that were queued during control request handling. 
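Note: the timeout selection in _fetch_and_enqueue_requests boils down to a small decision. A condensed sketch of that branching (the rank-0-only reset for accumulated requests is folded into the same helper; names are illustrative, not part of the change):

    import datetime
    from typing import Optional

    def fetch_timeout(num_active: int, num_waiting: int, disable_mpi: bool,
                      has_accumulated: bool) -> Optional[datetime.timedelta]:
        if has_accumulated or num_active > 0 or num_waiting > 0:
            # There is work (or requests buffered during a control-request window):
            # poll without blocking.
            return datetime.timedelta(0)
        # Fully idle: block under MPI (None), or wake every 1200 s in the Ray path
        # so rank 0 still reaches the broadcast regularly.
        return datetime.timedelta(seconds=1200) if disable_mpi else None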
+ if len(self.request_accumulated) != 0: + new_requests.extend(self.request_accumulated) + self.request_accumulated.clear() + # Reset timeout to 0 to avoid hanging when no new requests are available + timeout = datetime.timedelta(0) + with self.hang_detector.pause(): + new_requests.extend( + self.executor_request_queue.get_from_request_queue(timeout)) + + # Broadcast requests and handle Python objects + new_requests, py_request_objects = self.request_broadcaster.broadcast( + new_requests) + + # Validate and filter requests + new_requests = self._handle_special_queue_items(new_requests) + + # Attach Python objects to requests + if py_request_objects and (self.dist.tp_size > 1 or self.dist.has_pp + or self.dist.cp_size + > 1) and self.dist.rank > 0: + attach_py_objects_to_requests(new_requests, py_request_objects) + + waiting_queue.extend(new_requests) + + def _pop_from_waiting_queue( + self, + waiting_queue: deque[RequestQueueItem], + total_num_active_requests: int, + all_ranks_num_active_requests: Optional[List[int]] = None + ) -> List[RequestQueueItem]: + """Pop requests from waiting_queue based on available capacity.""" + if self.enable_attention_dp: + total_max = self.dist.tp_size * self.max_num_active_requests + else: + total_max = self.max_num_active_requests + + max_new_requests = total_max - total_num_active_requests + + return get_from_waiting_queue( + waiting_queue, + max_new_requests, + enable_attention_dp=self.enable_attention_dp, + max_num_active_requests=self.max_num_active_requests, + all_ranks_num_active_requests=all_ranks_num_active_requests) + + @nvtx_range("_fetch_new_requests") + def _fetch_new_requests( + self, waiting_queue: deque[RequestQueueItem], + activate_requests: List[LlmRequest]) -> List[LlmRequest]: + """Fetch new requests and return LlmRequests ready for execution.""" + # 1. Gather info and calculate total_num_active_requests + if self.enable_attention_dp: + all_ranks_num_active_requests = [] + all_ranks_num_active_tokens = [] + if self.dist.has_cp_helix: + num_active_tokens = sum( + [req.total_input_len_cp for req in activate_requests]) + else: + num_active_tokens = sum( + [req.py_orig_prompt_len for req in activate_requests]) + responses_list = self.dist.tp_allgather( + [len(activate_requests), num_active_tokens]) + for num_active_requests, num_active_tokens in responses_list: + all_ranks_num_active_requests.append(num_active_requests) + all_ranks_num_active_tokens.append(num_active_tokens) + total_num_active_requests = sum(all_ranks_num_active_requests) + else: + total_num_active_requests = len(activate_requests) + all_ranks_num_active_requests = None + + # 2. Fetch and enqueue to waiting queue + self._fetch_and_enqueue_requests(waiting_queue, + total_num_active_requests) + + # 3. Pop requests from waiting queue + new_requests = self._pop_from_waiting_queue( + waiting_queue, total_num_active_requests, + all_ranks_num_active_requests) + + # 4. Update performance metrics (before DP scheduling to clear all start_times) + if self.enable_iter_perf_stats and self.dist.rank == 0: + self._update_new_active_requests_queue_latency(new_requests) + + # 5. 
Schedule requests across ranks (DP only) + if self.enable_attention_dp: + all_ranks_new_requests, self.expected_num_active_requests = \ + schedule_attention_dp_requests( + new_requests, all_ranks_num_active_requests, + all_ranks_num_active_tokens, self.dist.tp_size, + self.max_num_active_requests) + new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank] + + # Update counters for DP + self.num_fetch_requests += len(new_requests) + self.num_fetch_requests_cur_rank += len(new_requests_cur_rank) + + new_requests = new_requests_cur_rank + + # 6. Merge requests + return merge_requests(new_requests, + cp_config=self.dist.cp_config, + cp_rank=self.dist.cp_rank, + cp_size=self.dist.cp_size, + exclude_last_generation_logits=self. + _should_exclude_last_generation_logits()) + + def _handle_special_queue_items( + self, + new_requests: List[RequestQueueItem]) -> List[RequestQueueItem]: + """Handle special signals.""" + accepted_new_requests = [] + for idx, req_item in enumerate(new_requests): + if req_item.is_shutdown_request: + self.is_shutdown = True + break + elif req_item.is_canceled_request: + self.canceled_req_ids.append(req_item.id) + elif req_item.is_control_request: + self.control_requests.append(req_item) + if self.dist.rank == 0: + self.request_accumulated.extend(new_requests[idx + 1:]) + break + else: + accepted_new_requests.append(req_item) + + return accepted_new_requests + + def _update_new_active_requests_queue_latency( + self, new_requests: List[RequestQueueItem]): + """Update queue latency metrics for new requests.""" + now = time.time() + latency = self.executor_request_queue.calculate_queue_latency( + new_requests, now) + self.new_active_requests_queue_latency_ms += latency + + def _get_new_active_requests_queue_latency(self) -> float: + return self.new_active_requests_queue_latency_ms + + def _should_exclude_last_generation_logits(self) -> bool: + return self.should_exclude_last_generation_logits + def _fetch_and_activate_new_requests(self) -> List[LlmRequest]: def _respond_if_invalid(request: LlmRequest) -> bool: @@ -2041,11 +2226,8 @@ class PyExecutor: self._handle_errors(str(e), requests=[request]) return True - new_requests_cur_rank = self.executor_request_queue.fetch_new_requests( - self.active_requests) - self.is_shutdown = self.executor_request_queue.is_shutdown - self.expected_num_active_requests = self.executor_request_queue.get_expected_num_active_requests( - ) + new_requests_cur_rank = self._fetch_new_requests( + self.waiting_queue, self.active_requests) validated_requests = [ request for request in new_requests_cur_rank @@ -2647,20 +2829,20 @@ class PyExecutor: @nvtx_range("_handle_canceled_requests") def _handle_canceled_requests(self): - if self.executor_request_queue.get_canceled_req_ids_size() == 0: + if len(self.canceled_req_ids) == 0: return - # Remove cancel request in the waiting queue - self.executor_request_queue.update_waiting_queue() - # Create set from list of canceled request ids to speed up canceled test - canceled_req_ids = set( - self.executor_request_queue.get_canceled_req_ids()) + canceled_req_ids_set = set(self.canceled_req_ids) + + # Remove canceled requests from the waiting queue + self.waiting_queue = deque(req for req in self.waiting_queue + if req.id not in canceled_req_ids_set) still_pending_canceled_ids = [] for request in self.active_requests: req_id = request.py_request_id if not request.is_child else request.parent_request_id - if req_id not in canceled_req_ids: + if req_id not in canceled_req_ids_set: continue is_cancelled = 
self._try_cancel_request(request) @@ -2673,9 +2855,8 @@ class PyExecutor: still_pending_canceled_ids.append(req_id) # Clear list of requests marked for cancellation and add back those that failed to cancel. - self.executor_request_queue.canceled_req_ids.clear() - self.executor_request_queue.canceled_req_ids.extend( - still_pending_canceled_ids) + self.canceled_req_ids.clear() + self.canceled_req_ids.extend(still_pending_canceled_ids) @nvtx_range("_enqueue_responses") def _enqueue_responses(self, responses: Iterable[Tuple[int, LlmResponse]]): diff --git a/tensorrt_llm/_torch/pyexecutor/request_utils.py b/tensorrt_llm/_torch/pyexecutor/request_utils.py new file mode 100644 index 0000000000..cae311a395 --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/request_utils.py @@ -0,0 +1,704 @@ +"""Utility functions for request processing.""" + +import heapq +import os +from collections import deque, namedtuple +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from tensorrt_llm._utils import nvtx_range +from tensorrt_llm.mapping import CpType + +from ..distributed import Distributed +from .hang_detector import HangDetector +from .llm_request import ExecutorRequest, LlmRequest, executor_request_to_llm_request + +# Type alias for request queue items (to avoid circular import) +# The actual RequestQueueItem class is defined in executor_request_queue.py + + +def get_num_child_requests(request: ExecutorRequest) -> int: + """Get the number of child requests for a given request. + + Args: + request: The executor request to check. + + Returns: + Number of child requests (0 if beam search, otherwise num_return_sequences - 1). + """ + sampling_config = request.sampling_config + return 0 if sampling_config.beam_width > 1 else (sampling_config.num_return_sequences or 1) - 1 + + +def collect_py_objects_from_requests( + requests: List, attribute_name: str +) -> Optional[Tuple[str, Dict]]: + """Collect Python-only objects from requests. + + Args: + requests: List of RequestQueueItem objects. + attribute_name: Name of the attribute to collect. + + Returns: + Tuple of (attribute_name, dict mapping request_id to object) or None if empty. + """ + req_id_to_obj = {} + for item in requests: + if not item.is_normal_request: + continue + if item.request: + obj = getattr(item.request, attribute_name, None) + if obj is not None: + req_id_to_obj[item.id] = obj + return None if not req_id_to_obj else (attribute_name, req_id_to_obj) + + +def attach_py_objects_to_requests(requests: List, py_request_objects: Tuple) -> None: + """Attach Python-only objects to each request. + + Args: + requests: List of RequestQueueItem objects. + py_request_objects: Tuple of (attribute_name, dict) pairs. + """ + for attr_name, req_obj_dict in py_request_objects: + for item in requests: + if item.request: + py_obj = req_obj_dict.get(item.id) + if py_obj is not None: + setattr(item.request, attr_name, py_obj) + + +def schedule_attention_dp_requests( + new_requests: List[Any], + all_ranks_num_active_requests: List[int], + all_ranks_num_active_tokens: List[int], + tp_size: int, + max_num_active_requests: int, +) -> Tuple[Dict[int, List[Any]], int]: + """Schedule attention DP requests across ranks. + + This function distributes requests across tensor parallel ranks for attention DP. + It first tries to assign requests to their target dp_rank (if specified and has capacity), + then balances the remaining requests across all ranks. + + Args: + new_requests: List of RequestQueueItem to schedule. 
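Note: collect_py_objects_from_requests / attach_py_objects_to_requests form the round trip that moves Python-only attributes alongside the broadcast requests. A small sketch with hypothetical stand-in items (only .id, .request and .is_normal_request are touched):

    from types import SimpleNamespace

    from tensorrt_llm._torch.pyexecutor.request_utils import (
        attach_py_objects_to_requests, collect_py_objects_from_requests)

    items = [SimpleNamespace(id=7, is_normal_request=True,
                             request=SimpleNamespace(py_num_logprobs=2))]
    collected = collect_py_objects_from_requests(items, "py_num_logprobs")
    assert collected == ("py_num_logprobs", {7: 2})  # what rank 0 ships to other ranks

    # On a receiving rank the attribute is restored onto the matching request.
    attach_py_objects_to_requests(items, (collected,))
    assert items[0].request.py_num_logprobs == 2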
+ all_ranks_num_active_requests: Number of active requests per rank (will be modified). + all_ranks_num_active_tokens: Number of active tokens per rank. + tp_size: Number of tensor parallel ranks. + max_num_active_requests: Maximum number of active requests per rank. + + Returns: + Tuple of: + - all_ranks_new_requests: Dict mapping rank to list of assigned requests. + - expected_num_active_requests: Expected number of active requests per rank. + """ + # Map from ranks to new requests + all_ranks_new_requests = {tp_rank: [] for tp_rank in range(tp_size)} + + # Prioritize the requests that are not in relax mode + def get_relax_value(req_item): + scheduling_params = getattr(req_item.request, "py_scheduling_params", None) + if scheduling_params is None: + return True + return scheduling_params.attention_dp_relax + + new_requests = sorted(new_requests, key=get_relax_value) + + # Try to put the requests to the target dp rank until the max_num_active_requests is reached + remaining_unscheduled = [] + for req_item in new_requests: + scheduled = False + scheduling_params = getattr(req_item.request, "py_scheduling_params", None) + if scheduling_params is not None: + target_dp_rank = scheduling_params.attention_dp_rank + if ( + target_dp_rank is not None + and all_ranks_num_active_requests[target_dp_rank] < max_num_active_requests + ): + all_ranks_num_active_requests[target_dp_rank] += 1 + scheduled = True + all_ranks_new_requests[target_dp_rank].append(req_item) + + if not scheduled: + remaining_unscheduled.append(req_item) + + # Balance the remaining unscheduled requests across ranks + num_new_requests_all_ranks = len(remaining_unscheduled) + total_num_active_requests = sum(all_ranks_num_active_requests) + expected_num_active_requests = max( + (total_num_active_requests + num_new_requests_all_ranks + tp_size - 1) // tp_size, + max(all_ranks_num_active_requests), + ) + + all_ranks_new_requests = balance_requests_across_ranks( + remaining_unscheduled, + all_ranks_new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + expected_num_active_requests, + ) + + return all_ranks_new_requests, expected_num_active_requests + + +def balance_requests_across_ranks( + new_requests: List, + all_ranks_new_requests: Dict[int, List], + all_ranks_num_active_requests: List[int], + all_ranks_num_active_tokens: List[int], + expected_num_active_requests: int, +) -> Dict[int, List]: + """Balance requests across ranks for attention DP. + + Uses a heap-based algorithm to distribute requests evenly across ranks, + prioritizing ranks with fewer tokens for better load balancing. + + Args: + new_requests: List of new requests to distribute. + all_ranks_new_requests: Dict mapping rank to list of already assigned requests. + all_ranks_num_active_requests: Number of active requests per rank. + all_ranks_num_active_tokens: Number of active tokens per rank. + expected_num_active_requests: Target number of active requests per rank. + + Returns: + Updated all_ranks_new_requests dict with new requests distributed. 
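Note: a usage sketch for schedule_attention_dp_requests with hypothetical stand-in items (only .request.py_scheduling_params and .request.input_token_ids are read). The request pinned to dp rank 1 is placed first; the relaxed ones are then balanced across ranks:

    from types import SimpleNamespace

    from tensorrt_llm._torch.pyexecutor.request_utils import schedule_attention_dp_requests

    def make_item(rank=None, relax=True, tokens=16):
        params = SimpleNamespace(attention_dp_rank=rank, attention_dp_relax=relax)
        return SimpleNamespace(request=SimpleNamespace(py_scheduling_params=params,
                                                       input_token_ids=[0] * tokens))

    reqs = [make_item(rank=1, relax=False), make_item(), make_item()]
    per_rank, expected = schedule_attention_dp_requests(
        reqs, all_ranks_num_active_requests=[0, 0], all_ranks_num_active_tokens=[0, 0],
        tp_size=2, max_num_active_requests=8)
    # per_rank[1] holds the pinned request plus one balanced request,
    # per_rank[0] holds the remaining one; expected == 2.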
+ """ + if new_requests: + # Balance context tokens across ranks using heap + HeapVal = namedtuple("HeapVal", ["num_tokens", "num_requests", "rank", "request_list"]) + + all_ranks_new_requests_heap = [ + HeapVal(all_ranks_num_active_tokens[tp_rank], val, tp_rank, []) + for tp_rank, val in enumerate(all_ranks_num_active_requests) + ] + + all_ranks_new_requests_heap = [ + val + for val in all_ranks_new_requests_heap + if val.num_requests < expected_num_active_requests + ] + + all_ranks_new_scheduled_requests = { + val.rank: val.request_list for val in all_ranks_new_requests_heap + } + + heapq.heapify(all_ranks_new_requests_heap) + + # Sort by token count (descending) for better load balancing + new_requests = sorted( + new_requests, + key=lambda x: len(getattr(x.request, "input_token_ids", [])) if x.request else 0, + reverse=True, + ) + + # Distribute requests across ranks + for req_item in new_requests: + val = heapq.heappop(all_ranks_new_requests_heap) + token_count = ( + len(getattr(req_item.request, "input_token_ids", [])) if req_item.request else 0 + ) + # Update the heap value with the new request + val = val._replace( + num_tokens=val.num_tokens + token_count, + num_requests=val.num_requests + 1, + ) + + val.request_list.append(req_item) + # If rank still has room for new requests, push back into heap + if val.num_requests < expected_num_active_requests: + heapq.heappush(all_ranks_new_requests_heap, val) + + # Extend all_ranks_new_requests with the new requests that have been scheduled + for rank, reqs in all_ranks_new_scheduled_requests.items(): + all_ranks_new_requests[rank].extend(reqs) + + return all_ranks_new_requests + + +def can_process_attention_dp_request( + req_item, all_ranks_num_active_requests: List[int], max_num_active_requests: int +) -> bool: + """Check if a request can be processed immediately for attention DP. + + Args: + req_item: The request queue item to check. + all_ranks_num_active_requests: Number of active requests for each rank. + max_num_active_requests: Maximum number of active requests per rank. + + Returns: + True if the request can be processed, False otherwise. + """ + scheduling_params = getattr(req_item.request, "py_scheduling_params", None) + if scheduling_params is None: + return True + + target_dp_rank = scheduling_params.attention_dp_rank + if target_dp_rank is None or scheduling_params.attention_dp_relax: + return True + + if all_ranks_num_active_requests[target_dp_rank] < max_num_active_requests: + all_ranks_num_active_requests[target_dp_rank] += 1 + return True + + return False + + +def get_from_waiting_queue( + waiting_queue: deque, + max_req_count: int, + enable_attention_dp: bool, + max_num_active_requests: int, + all_ranks_num_active_requests: Optional[List[int]] = None, +) -> List: + """Get requests from the waiting queue. + + Args: + waiting_queue: The queue to pop items from. + max_req_count: Maximum items to retrieve. Returns empty list if <=0. + enable_attention_dp: Whether to enable attention DP scheduling. + max_num_active_requests: Maximum number of active requests per rank. + all_ranks_num_active_requests: Number of active requests for each rank. + + Returns: + List of requests that can be processed. 
+ """ + if max_req_count <= 0: + return [] + + req_count = 0 + items = [] + pending_requests = [] + + # Track the request with strict requirements + scheduling_all_ranks_num_active_requests = ( + all_ranks_num_active_requests.copy() if enable_attention_dp else None + ) + + while req_count < max_req_count and waiting_queue: + req_item = waiting_queue[0] + num_children = len(req_item.child_req_ids) if req_item.child_req_ids else 0 + if (req_count + 1 + num_children) > max_req_count: + break + req_item = waiting_queue.popleft() + + can_process = ( + can_process_attention_dp_request( + req_item, scheduling_all_ranks_num_active_requests, max_num_active_requests + ) + if enable_attention_dp + else True + ) + + if can_process: + items.append(req_item) + req_count += 1 + num_children + else: + pending_requests.append(req_item) + + # Put the pending requests back to the waiting queue + # All ranks should have the same waiting queue + waiting_queue.extendleft(reversed(pending_requests)) + + return items + + +def partition_context_for_star_attention( + ctx_ids_list: List[int], cp_rank: int, cp_size: int, block_size: int, anchor_block_size: int +) -> Tuple[List[List[int]], List[List[int]], int]: + """Partition context for Star Attention CP. + + Args: + ctx_ids_list: List of context token IDs. + cp_rank: Current CP rank. + cp_size: Total number of CP ranks. + block_size: Size of each block. + anchor_block_size: Size of anchor block. + + Returns: + Tuple of (ctx_blocks, position_blocks, padding). + """ + ctx_ids = torch.tensor(ctx_ids_list).unsqueeze(0) + ctx_len = ctx_ids.shape[-1] + + if block_size is None: + block_size = ctx_len // cp_size + if anchor_block_size is None: + anchor_block_size = block_size + + assert anchor_block_size <= block_size, ( + f"cp_anchor_size {anchor_block_size} should be smaller than block_size {block_size}" + ) + + padding = 0 + if ctx_len % block_size != 0: + padding = block_size - (ctx_len % block_size) + assert padding <= ctx_len, "block size is too large for context, please set it smaller" + ctx_ids = torch.cat((ctx_ids, torch.zeros_like(ctx_ids)[:, :padding]), dim=-1) + position_ids = torch.arange(0, ctx_ids.shape[-1]).unsqueeze(0) + + ctx_ids_blocks = torch.tensor_split(torch.stack(ctx_ids.split(block_size, dim=-1)), cp_size) + position_ids_blocks = torch.tensor_split( + torch.stack(position_ids.split(block_size, dim=-1)), cp_size + ) + + if cp_rank != 0: + ctx_blocks = [ctx_ids_blocks[0][0].tolist()[0][:anchor_block_size]] + position_blocks = [position_ids_blocks[0][0].tolist()[0][:anchor_block_size]] + else: + ctx_blocks, position_blocks = [], [] + + for idx in range(len(ctx_ids_blocks[cp_rank])): + ctx_block = ctx_ids_blocks[cp_rank][idx] + position_block = position_ids_blocks[cp_rank][idx] + ctx_blocks.append(ctx_block.tolist()[0]) + position_blocks.append(position_block.tolist()[0]) + + return ctx_blocks, position_blocks, padding + + +def partition_context_for_helix( + input_token_ids: List[int], cp_rank: int, cp_size: int, tokens_per_block: int +) -> Tuple[List[int], List[int], int, int]: + """Partition context for Helix CP. + + Args: + input_token_ids: List of input token IDs. + cp_rank: Current CP rank. + cp_size: Total number of CP ranks. + tokens_per_block: Number of tokens per block. + + Returns: + Tuple of (input_ids_this_rank, position_ids_this_rank, input_len, padding_len). + + Raises: + ValueError: If there aren't enough tokens for at least one block per CP rank. 
+ """ + all_input_ids = torch.tensor(input_token_ids, dtype=torch.int64).unsqueeze(0) + input_len = all_input_ids.shape[-1] + + num_total_blocks = (input_len + tokens_per_block - 1) // tokens_per_block + if num_total_blocks < cp_size: + raise ValueError( + f"There aren't enough tokens to get at least one block per CP rank. " + f"num_total_blocks {num_total_blocks} < num_cp_ranks {cp_size}. " + f"Please use smaller tokens_per_block for KV cache or reduce the number of CP ranks." + ) + + # Padding to ensure torch.stack used with torch.tensor_split works properly. + padding_len = 0 + if input_len % tokens_per_block != 0: + padding_len = tokens_per_block - (input_len % tokens_per_block) + padding_ids = torch.zeros([1, padding_len], dtype=torch.int64) + all_input_ids = torch.cat((all_input_ids, padding_ids), dim=-1) + all_position_ids = torch.arange(0, input_len + padding_len, dtype=torch.int64).unsqueeze(0) + + input_id_blocks_per_rank = torch.tensor_split( + torch.stack(all_input_ids.split(tokens_per_block, dim=-1)), cp_size + ) + position_id_blocks_per_rank = torch.tensor_split( + torch.stack(all_position_ids.split(tokens_per_block, dim=-1)), cp_size + ) + + # Get the input_ids and position_ids for this rank. + input_ids_this_rank = input_id_blocks_per_rank[cp_rank].flatten().tolist() + position_ids_this_rank = position_id_blocks_per_rank[cp_rank].flatten().tolist() + + # Undo the padding. Only last rank's last block will be padded right now + # given contiguous block assignment. + if cp_rank == cp_size - 1 and padding_len > 0: + input_ids_this_rank = input_ids_this_rank[:-padding_len] + position_ids_this_rank = position_ids_this_rank[:-padding_len] + + return input_ids_this_rank, position_ids_this_rank, input_len, padding_len + + +def merge_requests_to_llm_requests( + new_requests: List, exclude_last_generation_logits: bool +) -> List[LlmRequest]: + """Merge RequestQueueItems to LlmRequests (basic case without CP). + + Args: + new_requests: List of RequestQueueItem objects. + exclude_last_generation_logits: Whether to exclude last generation logits. + + Returns: + List of LlmRequest objects including child requests. + """ + req_with_children = [] + for req_item in new_requests: + req = executor_request_to_llm_request( + req_item.id, req_item.request, req_item.child_req_ids, exclude_last_generation_logits + ) + req_with_children.append(req) + if req.child_requests: + req_with_children.extend(req.child_requests) + return req_with_children + + +def merge_helix_requests( + new_requests: List, + cp_rank: int, + cp_size: int, + tokens_per_block: int, + exclude_last_generation_logits: bool, +) -> List[LlmRequest]: + """Merge requests for Helix CP. + + Note: Helix parallelism is a decode-only feature run with disaggregated serving. + This function gets called on gen server during initialization of a new request. + + Args: + new_requests: List of RequestQueueItem objects. + cp_rank: Current CP rank. + cp_size: Total number of CP ranks. + tokens_per_block: Number of tokens per block. + exclude_last_generation_logits: Whether to exclude last generation logits. + + Returns: + List of LlmRequest objects including child requests. 
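+
+    Example:
+        A sketch of the intended call, as exercised by the unit tests; the
+        request items are assumed to wrap trtllm.Request objects with a
+        13-token prompt:
+
+            reqs = merge_helix_requests(
+                [request_item], cp_rank=0, cp_size=4, tokens_per_block=2,
+                exclude_last_generation_logits=False)
+            # On CP rank 0, reqs[0].get_tokens(0) == [1, 2, 3, 4]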
+    """
+    req_with_children = []
+
+    for req_item in new_requests:
+        input_ids_this_rank, position_ids_this_rank, input_len, _ = partition_context_for_helix(
+            req_item.request.input_token_ids, cp_rank, cp_size, tokens_per_block
+        )
+
+        req = executor_request_to_llm_request(
+            req_id=req_item.id,
+            executor_request=req_item.request,
+            child_req_ids=req_item.child_req_ids,
+            exclude_last_generation_logits=exclude_last_generation_logits,
+            input_token_ids=input_ids_this_rank,
+            position_ids=position_ids_this_rank,
+        )
+        req.total_input_len_cp = input_len
+        req.seqlen_this_rank_cp = len(input_ids_this_rank)
+        req_with_children.append(req)
+        if req.child_requests:
+            req_with_children.extend(req.child_requests)
+
+    return req_with_children
+
+
+def merge_star_attention_requests(
+    new_requests: List,
+    cp_rank: int,
+    cp_size: int,
+    cp_config: dict,
+    exclude_last_generation_logits: bool,
+) -> List[LlmRequest]:
+    """Merge requests for Star Attention CP.
+
+    Args:
+        new_requests: List of RequestQueueItem objects.
+        cp_rank: Current CP rank.
+        cp_size: Total number of CP ranks.
+        cp_config: CP configuration dict containing 'block_size' and 'cp_anchor_size'.
+        exclude_last_generation_logits: Whether to exclude last generation logits.
+
+    Returns:
+        List of LlmRequest objects.
+    """
+    result = []
+    block_size = cp_config["block_size"]
+    anchor_block_size = cp_config["cp_anchor_size"]
+
+    for req_item in new_requests:
+        req_id, exe_req, query_token_ids = req_item.id, req_item.request, req_item.query
+        ctx_len0 = len(exe_req.input_token_ids)
+
+        ctx_blocks, position_blocks, last_block_padding_num = partition_context_for_star_attention(
+            exe_req.input_token_ids, cp_rank, cp_size, block_size, anchor_block_size
+        )
+
+        if cp_rank == cp_size - 1 and last_block_padding_num > 0:
+            ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
+            position_blocks[-1] = position_blocks[-1][:-last_block_padding_num]
+
+        # Append the query tokens as an extra block if present
+        if query_token_ids:
+            ctx_blocks.append(query_token_ids)
+            position_blocks.append([i for i in range(ctx_len0, ctx_len0 + len(query_token_ids))])
+
+        # Insert a dummy block so every rank runs the same number of context iterations
+        total_blocks = (ctx_len0 + block_size - 1) // block_size
+        num_blocks_per_rank = (total_blocks + cp_size - 1) // cp_size + 1  # 1 for query block
+        if len(ctx_blocks) == num_blocks_per_rank:
+            ctx_blocks.insert(1, [])
+            position_blocks.insert(1, [])
+        elif len(ctx_blocks) == num_blocks_per_rank + 1:
+            # Already anchor + ctx blocks + query block
+            pass
+        else:
+            raise ValueError(
+                f"Invalid context partition: rank = {cp_rank}, "
+                f"len(ctx_blocks) = {len(ctx_blocks)}, "
+                f"num_blocks_per_rank = {num_blocks_per_rank}"
+            )
+
+        # Placeholder token data for the scheduler
+        ctx_blocks_list = [0] * (block_size + anchor_block_size)
+
+        req = executor_request_to_llm_request(
+            req_id, exe_req, exclude_last_generation_logits, ctx_blocks_list
+        )
+        req.gen_iters = 0
+        req.ctx_iters = 0
+        req.ctx_blocks = ctx_blocks
+        req.ctx_position_blocks = position_blocks
+        req.query_id = query_token_ids
+
+        result.append(req)
+
+    return result
+
+
+@nvtx_range("merge_requests")
+def merge_requests(
+    new_requests: List,
+    cp_config: dict,
+    cp_rank: int,
+    cp_size: int,
+    exclude_last_generation_logits: bool,
+) -> List[LlmRequest]:
+    """Merge RequestQueueItems to LlmRequests based on CP configuration.
+
+    This is a router function that dispatches to the appropriate merge function
+    based on the CP (Context Parallelism) configuration.
+
+    Args:
+        new_requests: List of RequestQueueItem objects.
+ cp_config: CP configuration dict. May contain 'cp_type', 'tokens_per_block', + 'block_size', 'cp_anchor_size'. + cp_rank: Current CP rank. + cp_size: Total number of CP ranks. + exclude_last_generation_logits: Whether to exclude last generation logits. + + Returns: + List of LlmRequest objects. + + Raises: + NotImplementedError: If cp_type is not supported. + """ + if "cp_type" in cp_config: + cp_type = cp_config["cp_type"] + if cp_type == CpType.STAR: + return merge_star_attention_requests( + new_requests, + cp_rank=cp_rank, + cp_size=cp_size, + cp_config=cp_config, + exclude_last_generation_logits=exclude_last_generation_logits, + ) + elif cp_type == CpType.HELIX: + return merge_helix_requests( + new_requests, + cp_rank=cp_rank, + cp_size=cp_size, + tokens_per_block=cp_config["tokens_per_block"], + exclude_last_generation_logits=exclude_last_generation_logits, + ) + else: + raise NotImplementedError(f"Unsupported cp type {cp_type.name}.") + + return merge_requests_to_llm_requests(new_requests, exclude_last_generation_logits) + + +class RequestBroadcaster: + """Broadcasts requests across distributed ranks (TP, PP, CP).""" + + def __init__(self, dist: Distributed, hang_detector: HangDetector): + self.dist = dist + self.hang_detector = hang_detector + self.send_requests_handler = None + + def broadcast(self, new_requests: List) -> Tuple[List, Optional[Tuple]]: + """Broadcast requests and Python objects across ranks.""" + if self.dist.rank == 0: + py_request_objects = self._collect_py_objects(new_requests) + else: + py_request_objects = None + + if self.dist.rank == 0: + # Preserve original `new_requests` on rank 0 + _ = self._broadcast_requests(new_requests, py_request_objects) + else: + with self.hang_detector.pause(): + new_requests, py_request_objects = self._broadcast_requests( + new_requests, py_request_objects + ) + + return new_requests, py_request_objects + + def _collect_py_objects(self, new_requests: List) -> Tuple: + """Collect Python-only objects from requests.""" + py_logits_post_processors = collect_py_objects_from_requests( + new_requests, "py_logits_post_processors" + ) + py_multimodal_data = collect_py_objects_from_requests(new_requests, "py_multimodal_data") + py_scheduling_params = collect_py_objects_from_requests( + new_requests, "py_scheduling_params" + ) + py_num_logprobs = collect_py_objects_from_requests(new_requests, "py_num_logprobs") + py_disaggregated_params = collect_py_objects_from_requests( + new_requests, "py_disaggregated_params" + ) + + return tuple( + filter( + None, + [ + py_logits_post_processors, + py_multimodal_data, + py_scheduling_params, + py_num_logprobs, + py_disaggregated_params, + ], + ) + ) + + @nvtx_range("broadcast_requests") + def _broadcast_requests( + self, new_requests: List, py_request_objects + ) -> Tuple[List, Optional[Dict]]: + """Broadcast requests across pipeline stages.""" + payloads = (new_requests, py_request_objects) + + if not self.dist.has_pp: + return self.dist.broadcast(payloads, root=0) + + # Broadcast within first PP stage before send/recv chain to other PP stages. + # This needs to cover both TP and CP ranks within the first PP stage. 
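+        # The resulting flow: rank 0 broadcasts to its TP/CP peers within the
+        # first PP stage, then payloads are relayed point-to-point through the
+        # remaining PP stages (recv from prev_pp_rank, send to next_pp_rank).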
+ if self.dist.is_first_pp_rank: + payloads = self.dist.tp_cp_broadcast(payloads, root=0) + + # Tag for communication + tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts + + # Send payloads + if not self.dist.is_first_pp_rank: + with nvtx_range("recv_requests_from_prev_pp"): + payloads = self.dist.recv_object(self.dist.prev_pp_rank, tag) + + # isend new requests may cause deadlock, when CUDA_LAUNCH_BLOCKING=1 + # or PP microbatches can't overlap, the deadlock will happen: + # 1. rank1 will wait on nccl.send(rank2), without invoking mpi.wait(isend-handle) + # 2. rank2 will wait on mpi.recv(rank1) but never receive the new requests. + # 3. rank1 will hang on nccl.send because rank2 will never reach nccl.recv(rank1). + pp_send_func = ( + self.dist.isend_object + if os.environ.get("TRTLLM_PP_REQ_SEND_ASYNC", "0") == "1" + else self.dist.send_object + ) + + if not self.dist.is_last_pp_rank: + if self.send_requests_handler is not None: + with nvtx_range("wait_prev_send_requests_handler"): + self.send_requests_handler.wait() + with nvtx_range("send_requests_to_next_pp"): + self.send_requests_handler = pp_send_func(payloads, self.dist.next_pp_rank, tag) + + return payloads diff --git a/tests/unittest/_torch/executor/test_executor_request_queue.py b/tests/unittest/_torch/executor/test_executor_request_queue.py index a7645c93a0..6874d578f6 100644 --- a/tests/unittest/_torch/executor/test_executor_request_queue.py +++ b/tests/unittest/_torch/executor/test_executor_request_queue.py @@ -1,16 +1,22 @@ +"""Tests for ExecutorRequestQueue class. + +This module tests the ExecutorRequestQueue class functionality including: +- Queue initialization +- Request enqueuing (single, multiple, cancel, shutdown) +- Queue operations (get from request queue, timeout behavior) +- RequestQueueItem special types +""" + import datetime import queue import threading import time -from collections import deque from unittest.mock import Mock, patch import pytest from tensorrt_llm._torch.pyexecutor.executor_request_queue import ( SHUTDOWN_REQUEST_ID, ExecutorRequestQueue, RequestQueueItem) -from tensorrt_llm.bindings import executor as trtllm -from tensorrt_llm.mapping import CpType @pytest.fixture @@ -37,10 +43,7 @@ def mock_dist(): def executor_queue(mock_dist): """Create an ExecutorRequestQueue instance for testing.""" return ExecutorRequestQueue(dist=mock_dist, - enable_attention_dp=False, max_batch_size=8, - max_beam_width=1, - max_num_active_requests=16, enable_iter_perf_stats=True, batch_wait_timeout_ms=0.0) @@ -49,10 +52,7 @@ def executor_queue(mock_dist): def integration_queue(mock_dist): """Create an ExecutorRequestQueue instance for integration testing.""" return ExecutorRequestQueue(dist=mock_dist, - enable_attention_dp=True, max_batch_size=4, - max_beam_width=2, - max_num_active_requests=8, enable_iter_perf_stats=True, batch_wait_timeout_ms=0.0) @@ -60,15 +60,10 @@ def integration_queue(mock_dist): def test_executor_queue_init(executor_queue, mock_dist): """Test ExecutorRequestQueue initialization.""" assert executor_queue.dist == mock_dist - assert not executor_queue.enable_attention_dp - assert executor_queue.max_beam_width == 1 - assert executor_queue.max_num_active_requests == 16 assert executor_queue.next_request_id == 8 assert executor_queue.enable_iter_perf_stats assert executor_queue.active assert isinstance(executor_queue.request_queue, queue.Queue) - assert isinstance(executor_queue.waiting_queue, deque) - assert len(executor_queue.canceled_req_ids) == 0 assert 
isinstance(executor_queue.enqueue_lock, type(threading.Lock())) @@ -90,157 +85,6 @@ def test_enqueue_requests(executor_queue): assert executor_queue.start_times[req_id] == 1234.5 -def test_merge_helix_requests_with_padding(mock_dist): - """Test _merge_helix_requests with basic valid input.""" - - tokens_per_block = 2 - - # Create request item with 13 tokens to get exactly 7 blocks for 4 CP ranks. - input_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] - executor_request = trtllm.Request(input_token_ids=input_tokens, - max_tokens=5, - streaming=False, - sampling_config=trtllm.SamplingConfig(), - output_config=trtllm.OutputConfig()) - request_item = RequestQueueItem( - id=1, - request=executor_request, - ) - - for rank in [0, 1, 2, 3]: - # Create executor queue for helix with 4 CP ranks. - mock_dist.cp_size = 4 - mock_dist.cp_rank = rank - mock_dist.cp_config = { - 'cp_type': CpType.HELIX, - 'tokens_per_block': tokens_per_block, - } - executor_queue = ExecutorRequestQueue(dist=mock_dist, - enable_attention_dp=False, - max_batch_size=8, - max_beam_width=1, - max_num_active_requests=16, - enable_iter_perf_stats=True, - batch_wait_timeout_ms=0.0) - - # Mock _should_exclude_last_generation_logits. - with patch.object(executor_queue, - '_should_exclude_last_generation_logits', - return_value=False): - result = executor_queue._merge_helix_requests([request_item], - tokens_per_block) - - # Verify the result. - assert len(result) == 1 - llm_request = result[0] - from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest - assert isinstance(llm_request, LlmRequest) - assert llm_request.request_id == 1 - if rank == 0: - assert llm_request.get_tokens(0) == [1, 2, 3, 4] - elif rank == 1: - assert llm_request.get_tokens(0) == [5, 6, 7, 8] - elif rank == 2: - assert llm_request.get_tokens(0) == [9, 10, 11, 12] - else: - assert llm_request.get_tokens(0) == [13] - - -def test_merge_helix_requests_without_padding(mock_dist): - """Test _merge_helix_requests with evenly divisible tokens (no padding).""" - - tokens_per_block = 4 - - # Create request item with 12 tokens to get exactly 3 blocks for 2 CP ranks. - input_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] - executor_request = trtllm.Request(input_token_ids=input_tokens, - max_tokens=5, - streaming=False, - sampling_config=trtllm.SamplingConfig(), - output_config=trtllm.OutputConfig()) - request_item = RequestQueueItem( - id=1, - request=executor_request, - ) - - for rank in [0, 1]: - # Create executor queue for helix with 2 CP ranks. - mock_dist.cp_size = 2 - mock_dist.cp_rank = rank - mock_dist.cp_config = { - 'cp_type': CpType.HELIX, - 'tokens_per_block': tokens_per_block, - } - executor_queue = ExecutorRequestQueue(dist=mock_dist, - enable_attention_dp=False, - max_batch_size=8, - max_beam_width=1, - max_num_active_requests=16, - enable_iter_perf_stats=True, - batch_wait_timeout_ms=0.0) - - # Mock _should_exclude_last_generation_logits. - with patch.object(executor_queue, - '_should_exclude_last_generation_logits', - return_value=False): - result = executor_queue._merge_helix_requests([request_item], - tokens_per_block) - - # Verify the result. 
- assert len(result) == 1 - llm_request = result[0] - from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest - assert isinstance(llm_request, LlmRequest) - assert llm_request.request_id == 1 - if rank == 0: - assert llm_request.get_tokens(0) == [1, 2, 3, 4, 5, 6, 7, 8] - else: - assert llm_request.get_tokens(0) == [9, 10, 11, 12] - - -def test_merge_helix_requests_insufficient_blocks_error(mock_dist): - """Test _merge_helix_requests raises error when insufficient blocks.""" - mock_dist.cp_size = 4 - - tokens_per_block = 4 - mock_dist.cp_config = { - 'cp_type': CpType.HELIX, - 'tokens_per_block': tokens_per_block, - } - - # Create input with only 12 tokens. This creates 3 blocks which is fewer than 4 CP ranks. - executor_request = trtllm.Request( - input_token_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], - max_tokens=12, - streaming=False, - sampling_config=trtllm.SamplingConfig(), - output_config=trtllm.OutputConfig()) - request_item = RequestQueueItem( - id=1, - request=executor_request, - ) - - # Loop over ranks 0, 1, 2, 3 and verify that all ranks throw assertion. - for rank in range(4): - mock_dist.cp_rank = rank - - executor_queue = ExecutorRequestQueue(dist=mock_dist, - enable_attention_dp=False, - max_batch_size=8, - max_beam_width=1, - max_num_active_requests=16, - enable_iter_perf_stats=True, - batch_wait_timeout_ms=0.0) - - with pytest.raises( - ValueError, - match= - "There aren't enough tokens to get at least one block per CP rank" - ): - executor_queue._merge_helix_requests([request_item], - tokens_per_block) - - def test_enqueue_request_single(executor_queue): """Test enqueuing a single request.""" mock_request = Mock() @@ -274,8 +118,9 @@ def test_enqueue_request_with_child_ids(executor_queue, n_children): """Test enqueuing a request with query data.""" mock_request = Mock() query_data = [1, 2, 3, 4] - with patch.object(executor_queue, - '_get_num_child_requests') as mock_children: + with patch( + 'tensorrt_llm._torch.pyexecutor.executor_request_queue.get_num_child_requests' + ) as mock_children: mock_children.return_value = n_children req_id = executor_queue.enqueue_request(mock_request, query=query_data) @@ -347,7 +192,7 @@ def test_get_from_request_queue_no_timeout(executor_queue): executor_queue.request_queue.put(item1) executor_queue.request_queue.put(item2) - items = executor_queue._get_from_request_queue(None) + items = executor_queue.get_from_request_queue(None) assert len(items) == 2 assert items[0] == item1 @@ -360,7 +205,7 @@ def test_get_from_request_queue_with_timeout(executor_queue): # Empty queue should return empty list quickly start_time = time.time() - items = executor_queue._get_from_request_queue(timeout) + items = executor_queue.get_from_request_queue(timeout) elapsed = time.time() - start_time assert len(items) == 0 @@ -391,7 +236,7 @@ def test_get_from_request_queue_async_behavior(executor_queue): # Get requests immediately - should only get the initial ones start_time = time.time() - items = executor_queue._get_from_request_queue(None) + items = executor_queue.get_from_request_queue(None) elapsed = time.time() - start_time assert len(items) == initial_requests @@ -420,7 +265,7 @@ def test_get_from_request_queue_async_behavior(executor_queue): # Get requests with batch_wait_timeout_ms - should wait and get all start_time = time.time() - items = executor_queue._get_from_request_queue(None) + items = executor_queue.get_from_request_queue(None) elapsed = time.time() - start_time # Should wait and return all requests @@ -436,47 +281,8 @@ def 
test_get_from_request_queue_async_behavior(executor_queue): thread.join() -def test_get_from_waiting_queue(executor_queue): - """Test getting items from waiting queue.""" - # Add items to waiting queue - items = [RequestQueueItem(i, Mock()) for i in range(5)] - executor_queue.waiting_queue.extend(items) - - # Get 3 items - result = executor_queue._get_from_waiting_queue( - executor_queue.waiting_queue, 3, enable_attention_dp=False) - - assert len(result) == 3 - assert result == items[:3] - assert len(executor_queue.waiting_queue) == 2 - - -@pytest.mark.parametrize( - "queue_size,request_count,expected_result,expected_remaining", - [ - (0, 5, 0, 0), # Empty queue - (3, -1, 0, 3), # Negative count - (3, 0, 0, 3), # Zero count - (3, 10, 3, 0), # Request more than available - ]) -def test_get_from_waiting_queue_edge_cases(executor_queue, queue_size, - request_count, expected_result, - expected_remaining): - """Test edge cases for getting items from waiting queue.""" - # Setup queue - if queue_size > 0: - items = [RequestQueueItem(i, Mock()) for i in range(queue_size)] - executor_queue.waiting_queue.extend(items) - - result = executor_queue._get_from_waiting_queue( - executor_queue.waiting_queue, request_count, enable_attention_dp=False) - - assert len(result) == expected_result - assert len(executor_queue.waiting_queue) == expected_remaining - - -def test_handle_special_queue_items(executor_queue): - """Test special queue item handling.""" +def test_request_queue_item_special_types(): + """Test RequestQueueItem special type detection.""" # Create a mock request without sampling_config to avoid beam validation mock_request = Mock() delattr(mock_request, 'sampling_config') if hasattr( @@ -486,881 +292,27 @@ def test_handle_special_queue_items(executor_queue): cancel_req = RequestQueueItem(2, is_canceled_request=True) shutdown_req = RequestQueueItem(SHUTDOWN_REQUEST_ID) - requests = [normal_req, cancel_req, shutdown_req] + # Test normal request + assert normal_req.is_normal_request + assert not normal_req.is_shutdown_request + assert not normal_req.is_canceled_request - valid_requests = executor_queue._handle_special_queue_items(requests) + # Test cancel request + assert cancel_req.is_canceled_request + assert not cancel_req.is_shutdown_request + assert not cancel_req.is_normal_request - assert len(valid_requests) == 1 - assert valid_requests[0] == normal_req - assert executor_queue.is_shutdown - assert 2 in executor_queue.canceled_req_ids + # Test shutdown request + assert shutdown_req.is_shutdown_request + assert not shutdown_req.is_canceled_request + assert not shutdown_req.is_normal_request -@patch( - 'tensorrt_llm._torch.pyexecutor.executor_request_queue.executor_request_to_llm_request' -) -def test_merge_requests_default(mock_convert, executor_queue): - """Test merging requests with default configuration.""" - mock_llm_request = Mock(child_requests=[]) - mock_convert.return_value = mock_llm_request - - requests = [RequestQueueItem(1, Mock()), RequestQueueItem(2, Mock())] - result = executor_queue._merge_requests(requests) - - assert len(result) == 2 - assert mock_convert.call_count == 2 - - -def test_update_waiting_queue(executor_queue): - """Test updating waiting queue to remove canceled requests.""" - items = [ - RequestQueueItem(1, Mock()), - RequestQueueItem(2, Mock()), - RequestQueueItem(3, Mock()), - ] - executor_queue.waiting_queue.extend(items) - executor_queue.canceled_req_ids = [2] - - executor_queue.update_waiting_queue() - - assert len(executor_queue.waiting_queue) == 2 - 
remaining_ids = [item.id for item in executor_queue.waiting_queue] - assert 1 in remaining_ids - assert 3 in remaining_ids - assert 2 not in remaining_ids - - -def test_performance_metrics_methods(executor_queue): - """Test various performance metrics getter methods.""" +def test_queue_size_methods(executor_queue): + """Test queue size getter methods.""" # Test initial values - assert executor_queue.get_new_active_requests_queue_latency() == 0 - assert executor_queue.get_expected_num_active_requests() == 0 assert executor_queue.get_request_queue_size() == 0 - assert executor_queue.get_waiting_queue_size() == 0 - assert executor_queue.get_canceled_req_ids_size() == 0 - assert executor_queue.get_canceled_req_ids() == [] # Add some data and test executor_queue.request_queue.put(RequestQueueItem(1, Mock())) - executor_queue.waiting_queue.append(RequestQueueItem(2, Mock())) - executor_queue.canceled_req_ids = [3, 4] - executor_queue.expected_num_active_requests = 5 - assert executor_queue.get_request_queue_size() == 1 - assert executor_queue.get_waiting_queue_size() == 1 - assert executor_queue.get_canceled_req_ids_size() == 2 - assert executor_queue.get_canceled_req_ids() == [3, 4] - assert executor_queue.get_expected_num_active_requests() == 5 - - -def test_clear_canceled_req_ids(executor_queue): - """Test clearing canceled request IDs.""" - executor_queue.canceled_req_ids = [1, 2, 3] - assert len(executor_queue.canceled_req_ids) == 3 - - executor_queue.clear_canceled_req_ids() - - assert len(executor_queue.canceled_req_ids) == 0 - - -@pytest.fixture -def mock_dist_attention_dp(): - """Create a mock Distributed instance for testing.""" - mock_dist = Mock() - mock_dist.rank = 0 - mock_dist.tp_size = 4 - mock_dist.pp_size = 1 - mock_dist.has_pp = False - mock_dist.tp_rank = 0 - mock_dist.cp_rank = 0 - mock_dist.cp_size = 1 - mock_dist.cp_config = {} - mock_dist.is_first_pp_rank = True - mock_dist.is_last_pp_rank = True - mock_dist.next_pp_rank = 1 - mock_dist.prev_pp_rank = 0 - mock_dist.broadcast = Mock(return_value=([], None)) - return mock_dist - - -@pytest.fixture -def attention_dp_queue(mock_dist_attention_dp): - """Create an ExecutorRequestQueue instance for attention DP testing.""" - queue = ExecutorRequestQueue(dist=mock_dist_attention_dp, - enable_attention_dp=True, - max_batch_size=4, - max_beam_width=2, - max_num_active_requests=8, - enable_iter_perf_stats=True, - batch_wait_timeout_ms=0.0) - # Initialize all_ranks_num_active_requests - return queue - - -@pytest.fixture -def all_ranks_num_active_requests(): - return [2, 1, 3, 0] # 4 ranks - - -@pytest.fixture -def all_ranks_num_active_tokens(): - return [10, 5, 15, 8] # 4 ranks - - -def create_mock_request_with_py_schedule_params(attention_dp_rank=None, - attention_dp_relax=False): - mock_request = Mock() - - if attention_dp_rank is not None: - mock_schedule_params = Mock() - mock_schedule_params.attention_dp_rank = attention_dp_rank - mock_schedule_params.attention_dp_relax = attention_dp_relax - - mock_schedule_params.configure_mock( - attention_dp_rank=attention_dp_rank, - attention_dp_relax=attention_dp_relax) - - mock_request.py_scheduling_params = mock_schedule_params - else: - mock_request.py_scheduling_params = None - - mock_request.input_token_ids = [1, 2, 3] - - return mock_request - - -# Unit tests for _schedule_attention_dp_requests -def test_schedule_attention_dp_requests_scheduled_requests( - attention_dp_queue, all_ranks_num_active_requests, - all_ranks_num_active_tokens): - req1 = RequestQueueItem( - 1, - 
create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=False)) - req2 = RequestQueueItem( - 2, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=False)) - - new_requests = [req1, req2] - - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - result = all_ranks_new_requests[0] - - assert len(result) == 2 - assert req1 in result - assert req2 in result - - assert all_ranks_num_active_requests[0] == 4 - - -def test_schedule_attention_dp_requests_scheduled_requests_other_ranks( - attention_dp_queue, all_ranks_num_active_requests, - all_ranks_num_active_tokens): - req1 = RequestQueueItem( - 1, - create_mock_request_with_py_schedule_params(attention_dp_rank=1, - attention_dp_relax=False)) - req2 = RequestQueueItem( - 2, - create_mock_request_with_py_schedule_params(attention_dp_rank=2, - attention_dp_relax=False)) - - new_requests = [req1, req2] - - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - - result = all_ranks_new_requests[0] - assert len(result) == 0 - - assert all_ranks_num_active_requests[1] == 2 - assert all_ranks_num_active_requests[2] == 4 - - -def test_schedule_attention_dp_requests_unscheduled_requests( - attention_dp_queue, all_ranks_num_active_requests, - all_ranks_num_active_tokens): - req1 = RequestQueueItem( - 1, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=True)) - req2 = RequestQueueItem( - 2, - create_mock_request_with_py_schedule_params(attention_dp_rank=1, - attention_dp_relax=True)) - - new_requests = [req1, req2] - - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - result = all_ranks_new_requests[0] - - assert len(result) == 1 # Only req1 for current rank - assert req1 in result - - -def test_schedule_attention_dp_requests_unscheduled_no_capacity( - attention_dp_queue, all_ranks_num_active_requests, - all_ranks_num_active_tokens): - all_ranks_num_active_requests[0] = 8 - - req1 = RequestQueueItem( - 1, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=True)) - - new_requests = [req1] - - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - result = all_ranks_new_requests[0] - - assert len(result) == 0 # No capacity - - -def test_schedule_attention_dp_requests_mixed_scenarios( - attention_dp_queue, all_ranks_num_active_requests, - all_ranks_num_active_tokens): - req_scheduled_current = RequestQueueItem( - 1, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=False)) - req_scheduled_other = RequestQueueItem( - 2, - create_mock_request_with_py_schedule_params(attention_dp_rank=1, - attention_dp_relax=False)) - req_unscheduled_current = RequestQueueItem( - 3, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=True)) - req_unscheduled_other = RequestQueueItem( - 4, - create_mock_request_with_py_schedule_params(attention_dp_rank=2, - attention_dp_relax=True)) - - new_requests = [ - req_scheduled_current, req_scheduled_other, req_unscheduled_current, - req_unscheduled_other - ] - - all_ranks_new_requests = 
attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - result = all_ranks_new_requests[0] - - assert len(result) == 2 - assert req_scheduled_current in result - assert req_unscheduled_current in result - - -def test_schedule_attention_dp_requests_empty_lists( - attention_dp_queue, all_ranks_num_active_requests, - all_ranks_num_active_tokens): - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - [], all_ranks_num_active_requests, all_ranks_num_active_tokens) - result = all_ranks_new_requests[0] - - assert len(result) == 0 - - -def test_schedule_attention_dp_requests_expected_num_active_calculation( - attention_dp_queue, all_ranks_num_active_requests, - all_ranks_num_active_tokens): - req1 = RequestQueueItem( - 1, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=True)) - req2 = RequestQueueItem( - 2, - create_mock_request_with_py_schedule_params(attention_dp_rank=1, - attention_dp_relax=True)) - - new_requests = [req1, req2] - - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - all_ranks_new_requests[0] - - # 2 + 1 + 3 + 0 = 6, 6 + 2 = 8, (8 + 3) // 4 = 2, max(2, 2, 1, 3, 0) = 3 - # expected_num_active_requests = max((6 + 2 + 3) // 4, 3) = max(2, 3) = 3 - assert attention_dp_queue.expected_num_active_requests == 3 - - -def test_schedule_attention_dp_requests_balance_requests_called( - attention_dp_queue, all_ranks_num_active_requests, - all_ranks_num_active_tokens): - req1 = RequestQueueItem( - 1, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=True)) - - new_requests = [req1] - - with patch.object(attention_dp_queue, - '_balance_requests_across_ranks') as mock_balance: - mock_balance.return_value = {0: req1} - - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - all_ranks_new_requests[0] - - # Check that _balance_requests_across_ranks was called - mock_balance.assert_called_once() - call_args = mock_balance.call_args[0] - assert isinstance(call_args[0], list) - assert isinstance(call_args[1], dict) - assert call_args[2] == all_ranks_num_active_requests # Third arg - assert call_args[3] == all_ranks_num_active_tokens # Fourth arg - - -def test_schedule_attention_dp_requests_no_scheduling_when_capacity_exceeded( - attention_dp_queue, all_ranks_num_active_requests, - all_ranks_num_active_tokens): - all_ranks_num_active_requests[0] = 8 - - req1 = RequestQueueItem( - 1, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=False)) - - new_requests = [req1] - - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - result = all_ranks_new_requests[0] - - assert len(result) == 0 # No requests scheduled - assert all_ranks_num_active_requests[0] == 8 # Capacity unchanged - - -# Integration tests combining both methods -def test_filter_and_schedule_integration(attention_dp_queue, - all_ranks_num_active_requests, - all_ranks_num_active_tokens): - req_schedulable = RequestQueueItem( - 1, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=False)) - req_schedulable.request.input_token_ids = [1, 2, 3, 4] - req_relax = RequestQueueItem( - 2, - 
create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=True)) - req_relax.request.input_token_ids = [1, 2] - - req_no_params = RequestQueueItem( - 3, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - - new_requests = [req_schedulable, req_relax, req_no_params] - - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - result = all_ranks_new_requests[0] - - assert len(result) == 2 - assert req_schedulable in result - assert req_relax in result - - -def test_filter_and_schedule_with_capacity_limits(attention_dp_queue, - all_ranks_num_active_requests, - all_ranks_num_active_tokens): - all_ranks_num_active_requests[0] = 7 - - req1 = RequestQueueItem( - 1, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=False)) - req1.request.input_token_ids = [1, 2, 3, 4] - req2 = RequestQueueItem( - 2, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=False)) - req2.request.input_token_ids = [1, 2, 3] - - new_requests = [req1, req2] - - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - result = all_ranks_new_requests[0] - - assert len(result) == 1 - assert req1 in result - - -def test_get_from_waiting_queue_with_attention_dp( - attention_dp_queue, all_ranks_num_active_requests): - items = [RequestQueueItem(i, Mock()) for i in range(5)] - attention_dp_queue.waiting_queue.extend(items) - - result = attention_dp_queue._get_from_waiting_queue( - attention_dp_queue.waiting_queue, 3, True, - all_ranks_num_active_requests) - - assert len(result) == 3 - assert result == items[:3] - assert len(attention_dp_queue.waiting_queue) == 2 - - -def test_get_from_waiting_queue_with_attention_dp_filtering( - attention_dp_queue, all_ranks_num_active_requests): - req1 = RequestQueueItem( - 1, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=False)) - req2 = RequestQueueItem( - 2, - create_mock_request_with_py_schedule_params(attention_dp_rank=1, - attention_dp_relax=True)) - req3 = RequestQueueItem(3, - create_mock_request_with_py_schedule_params( - attention_dp_rank=None)) # No scheduling params - - attention_dp_queue.waiting_queue.extend([req1, req2, req3]) - - # Set rank 0 to full capacity to test filtering - all_ranks_num_active_requests[0] = 8 - - result = attention_dp_queue._get_from_waiting_queue( - attention_dp_queue.waiting_queue, 3, True, - all_ranks_num_active_requests) - - assert len(result) == 2 - assert req2 in result - assert req3 in result - assert req1 not in result - - -def test_can_process_attention_dp_request(attention_dp_queue): - req_no_params = RequestQueueItem(1, Mock()) - assert attention_dp_queue._can_process_attention_dp_request( - req_no_params, [0, 0, 0, 0]) == True - - req_relax = RequestQueueItem( - 2, - create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=True)) - assert attention_dp_queue._can_process_attention_dp_request( - req_relax, [0, 0, 0, 0]) == True - - req_target = RequestQueueItem( - 3, - create_mock_request_with_py_schedule_params(attention_dp_rank=1, - attention_dp_relax=False)) - all_ranks = [0, 0, 0, 0] - assert attention_dp_queue._can_process_attention_dp_request( - req_target, all_ranks) == True - assert all_ranks[1] == 1 - - req_no_capacity = RequestQueueItem( - 4, - 
create_mock_request_with_py_schedule_params(attention_dp_rank=0, - attention_dp_relax=False)) - all_ranks_full = [8, 0, 0, 0] # Rank 0 is at capacity - assert attention_dp_queue._can_process_attention_dp_request( - req_no_capacity, all_ranks_full) == False - - -def test_achieve_max_num_active_requests(attention_dp_queue): - req_list = [] - req_id = 0 - for rank in range(4): - for i in range(5): - req_list.append( - RequestQueueItem( - req_id, - create_mock_request_with_py_schedule_params( - attention_dp_rank=rank, attention_dp_relax=False))) - req_id += 1 - req_list.append( - RequestQueueItem( - req_id, - create_mock_request_with_py_schedule_params( - attention_dp_rank=rank, attention_dp_relax=True))) - req_id += 1 - - all_ranks_num_active_requests = [5, 6, 3, 7] - attention_dp_queue.waiting_queue.extend(req_list) - available_active_requests = attention_dp_queue.max_num_active_requests * 4 - sum( - all_ranks_num_active_requests) - - result = attention_dp_queue._get_from_waiting_queue( - attention_dp_queue.waiting_queue, available_active_requests, True, - all_ranks_num_active_requests) - - assert len(result) == available_active_requests - - -def append_to_waiting_queue(waiting_queue, rank, attention_dp_relax): - req_id = len(waiting_queue) - waiting_queue.append( - RequestQueueItem( - req_id, - create_mock_request_with_py_schedule_params( - attention_dp_rank=rank, attention_dp_relax=attention_dp_relax))) - - -@pytest.mark.parametrize( - "max_num_active_requests,all_ranks_num_active_requests,request_configs,all_ranks_expected_req_ids", - [ - # Case: Balanced distribution of relaxed requests - ( - 3, - [0, 0, 0, 0], - [(None, True)] * 7, - { - 0: [0, 1], # First 2 requests go to rank 0 - 1: [2, 3], # Next 2 requests go to rank 1 - 2: [4, 5], # Next 2 requests go to rank 2 - 3: [6] # Last request goes to rank 3 - }), - # Case: Balanced distribution of relaxed requests with existing load - ( - 3, - [1, 2, 3, 0], - [(None, True)] * 13, - { - 0: [0, 1], # Rank 0 gets first 2 requests - 1: [2], # Rank 1 gets 1 request (already has 2) - 2: [], # Rank 2 is at capacity (3) - 3: [3, 4, 5] # Rank 3 gets 3 requests (starts with 0) - }), - # Case: Limited by max active - ( - 3, - [0, 0, 0, 0], - [(None, True)] * 13, - { - 0: [0, 1, 3], # First 3 requests (0, 1, 3) - 1: [2, 4, 6], # Next 3 requests (2, 4, 6) - 2: [5, 7, 9], # Next 3 requests (5, 7, 9) - 3: [8, 10, 11] # Last 3 requests (8, 10, 11) - }), - # Case: Empty new requests - (3, [3, 3, 3, 0], [], { - 0: [], - 1: [], - 2: [], - 3: [] - }), - # Case: Rank 0 is full and cannot schedule attention_dp rank request - ( - 3, - [3, 1, 3, 0], - [(0, False), (0, True)], - { - 0: [], # Rank 0 is full - 1: [1], # Rank 1 gets the relaxed request (req1) - 2: [], # No relaxed requests assigned here - 3: [] # No relaxed requests assigned here - }), - # Case: Only room for 1 request, need to skip req0 with attention dp rank - ( - 3, - [3, 2, 3, 3], - [(0, False), (0, True)], - { - 0: [], # Rank 0 is full - 1: [1], # Rank 1 gets the relaxed request - 2: [], # Rank 2 is at capacity - 3: [] # Rank 3 is at capacity - }), - # Case: Targeting ranks 1 and 3 that have room - ( - 3, - [2, 1, 3, 0], - [(1, False), (3, False)], - { - 0: [], # No requests assigned to rank 0 - 1: [0], # Request 0 targets rank 1 - 2: [], # No requests assigned to rank 2 - 3: [1] # Request 1 targets rank 3 - }), - # Case: Target dp rank specified, but relax is True - ( - 3, - [3, 3, 3, 1], - [(0, True), (1, True), (2, True)], - { - 0: [], # Rank 0 is at capacity - 1: [], # Rank 1 is at capacity 
- 2: [], # Rank 2 is at capacity - 3: [0, 1] # Rank 3 gets both relaxed requests - }), - # Case: Mixed targeting and relaxed - ( - 3, - [3, 3, 3, 0], - [(0, False), (1, True), (3, False)], - { - 0: [], # Rank 0 is at capacity - 1: [], # Rank 1 is at capacity - 2: [], # Rank 2 is at capacity - 3: [2, 1] # Rank 3 gets both requests (targeted + relaxed) - }), - ]) -def test_attention_dp_scheduling_cases(attention_dp_queue, - max_num_active_requests, - all_ranks_num_active_requests, - request_configs, - all_ranks_expected_req_ids): - """Test attention DP scheduling with various scenarios.""" - attention_dp_queue.max_num_active_requests = max_num_active_requests - - waiting_queue = deque() - for rank, relax in request_configs: - append_to_waiting_queue(waiting_queue, rank, relax) - - run_test_attention_dp_scheduling(attention_dp_queue, waiting_queue, - all_ranks_num_active_requests, - all_ranks_expected_req_ids) - - -def run_test_attention_dp_scheduling(attention_dp_queue, waiting_queue, - all_ranks_num_active_requests, - all_ranks_expected_req_ids): - - num_ranks = len(all_ranks_num_active_requests) - total_num_active_requests = sum(all_ranks_num_active_requests) - total_max_num_active_requests = attention_dp_queue.max_num_active_requests * num_ranks - enable_attention_dp = True - - new_requests = attention_dp_queue._get_from_waiting_queue( - waiting_queue, - total_max_num_active_requests - total_num_active_requests, - enable_attention_dp, all_ranks_num_active_requests) - - # Create mock token counts for testing - all_ranks_num_active_tokens = [10 + i * 5 for i in range(num_ranks)] - - # Schedule attention dp requests - all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests( - new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - - assert len(all_ranks_new_requests) == num_ranks - print("all_ranks_new_requests:", all_ranks_new_requests) - for rank, reqs in all_ranks_new_requests.items(): - req_ids = [req.id for req in reqs] - assert req_ids == all_ranks_expected_req_ids[rank] - - -# New tests for _balance_requests_across_ranks method -def test_balance_requests_across_ranks_empty_requests(attention_dp_queue): - """Test _balance_requests_across_ranks with empty requests list.""" - all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} - all_ranks_num_active_requests = [2, 1, 3, 0] - all_ranks_num_active_tokens = [20, 10, 30, 5] - - # Set expected_num_active_requests for testing - attention_dp_queue.expected_num_active_requests = 3 - - result = attention_dp_queue._balance_requests_across_ranks( - [], all_ranks_new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - - # Should return the original structure unchanged - assert result == all_ranks_new_requests - for rank in range(4): - assert len(result[rank]) == 0 - - -def test_balance_requests_across_ranks_single_request(attention_dp_queue): - """Test _balance_requests_across_ranks with a single request.""" - req = RequestQueueItem( - 1, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req.request.input_token_ids = [1, 2, 3, 4, 5] # 5 tokens - - all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} - all_ranks_num_active_requests = [1, 2, 0, 1] # Rank 2 has lowest count - all_ranks_num_active_tokens = [10, 20, 5, 15] - - # Set expected_num_active_requests for testing - attention_dp_queue.expected_num_active_requests = 2 - - result = attention_dp_queue._balance_requests_across_ranks( - [req], all_ranks_new_requests, all_ranks_num_active_requests, - 
all_ranks_num_active_tokens) - - # Request should be assigned to rank 2 (lowest active count) - assert len(result[0]) == 0 - assert len(result[1]) == 0 - assert len(result[2]) == 1 - assert len(result[3]) == 0 - assert result[2][0] == req - - -def test_balance_requests_across_ranks_multiple_requests(attention_dp_queue): - """Test _balance_requests_across_ranks with multiple requests.""" - # Create requests with different token counts - req1 = RequestQueueItem( - 1, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req1.request.input_token_ids = [1, 2, 3] # 3 tokens - - req2 = RequestQueueItem( - 2, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req2.request.input_token_ids = [1, 2, 3, 4, 5, 6] # 6 tokens - - req3 = RequestQueueItem( - 3, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req3.request.input_token_ids = [1, 2] # 2 tokens - - all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} - all_ranks_num_active_requests = [0, 1, 2, 1] - all_ranks_num_active_tokens = [5, 15, 25, 10] - - # Set expected_num_active_requests for testing - attention_dp_queue.expected_num_active_requests = 2 - - result = attention_dp_queue._balance_requests_across_ranks( - [req1, req2, req3], all_ranks_new_requests, - all_ranks_num_active_requests, all_ranks_num_active_tokens) - - # Requests should be distributed based on heap (lowest active count first) - # Requests are sorted by token count (descending) first, then assigned to ranks with lowest active count - # req2 (6 tokens) -> rank 0 (0 active) -> total: 1 active, 11 tokens - # req3 (2 tokens) -> rank 0 (1 active) -> total: 2 active, 13 tokens (rank 0 still has capacity) - # req1 (3 tokens) -> rank 3 (1 active) -> total: 2 active, 13 tokens - # Rank 1: 1 active, gets nothing (rank 0 took 2 requests) - # Rank 2: 2 active, gets nothing (at capacity) - - assert len(result[0]) == 2 # req2 and req3 (rank 0 has capacity for 2) - assert len(result[1]) == 0 # no requests (rank 0 took 2 requests) - assert len(result[2]) == 0 # at capacity - assert len(result[3]) == 1 # req1 - - # Verify the requests are assigned correctly - assert result[0][0] == req2 # First request (highest token count) - assert result[0][1] == req3 # Second request - assert result[3][0] == req1 - - -def test_balance_requests_across_ranks_capacity_limits(attention_dp_queue): - """Test _balance_requests_across_ranks respects capacity limits.""" - # Create multiple requests - requests = [] - for i in range(4): - req = RequestQueueItem( - i, - create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req.request.input_token_ids = [1] * (i + 1) # Variable token counts - requests.append(req) - - all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} - all_ranks_num_active_requests = [1, 1, 1, 1] # All ranks start with 1 - all_ranks_num_active_tokens = [10, 10, 10, 10] - - # Set expected_num_active_requests to limit capacity - attention_dp_queue.expected_num_active_requests = 2 - - result = attention_dp_queue._balance_requests_across_ranks( - requests, all_ranks_new_requests, all_ranks_num_active_requests, - all_ranks_num_active_tokens) - - # Each rank can only take 1 more request (1 + 1 = 2, which equals expected_num_active_requests) - total_assigned = sum( - len(rank_requests) for rank_requests in result.values()) - assert total_assigned == 4 # 4 ranks × 1 additional request each - - # Verify no rank exceeds capacity - for rank in range(4): - assert len(result[rank]) <= 1 - - -def 
test_balance_requests_across_ranks_heap_ordering(attention_dp_queue): - """Test that _balance_requests_across_ranks uses heap ordering correctly.""" - # Create requests with same token count to test heap ordering - req1 = RequestQueueItem( - 1, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req1.request.input_token_ids = [1, 2, 3] # 3 tokens - - req2 = RequestQueueItem( - 2, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req2.request.input_token_ids = [1, 2, 3] # 3 tokens - - req3 = RequestQueueItem( - 3, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req3.request.input_token_ids = [1, 2, 3] # 3 tokens - - all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} - # Rank 0 has highest active count, should get requests last - all_ranks_num_active_requests = [3, 1, 0, 2] - all_ranks_num_active_tokens = [30, 10, 5, 20] - - # Set expected_num_active_requests for testing - attention_dp_queue.expected_num_active_requests = 4 - - result = attention_dp_queue._balance_requests_across_ranks( - [req1, req2, req3], all_ranks_new_requests, - all_ranks_num_active_requests, all_ranks_num_active_tokens) - - # Requests should be assigned in order of lowest active count first - # Since all requests have same token count, they're assigned based on active count order - # Rank 2: 0 active -> gets req1 and req2 (has capacity for 2) - # Rank 1: 1 active -> gets req3 (after rank 2 takes 2) - # Rank 3: 2 active -> gets nothing (rank 1 took req3) - # Rank 0: 3 active -> gets nothing (at capacity) - - assert len(result[0]) == 0 # at capacity - assert len(result[1]) == 1 # req3 - assert len(result[2]) == 2 # req1 and req2 - assert len(result[3]) == 0 # no requests - - # Verify the requests are assigned correctly - assert result[1][0] == req3 # Third request - assert result[2][0] == req1 # First request - assert result[2][1] == req2 # Second request - - -def test_balance_requests_across_ranks_token_count_sorting(attention_dp_queue): - """Test that requests are sorted by token count before distribution.""" - # Create requests with different token counts - req1 = RequestQueueItem( - 1, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req1.request.input_token_ids = [1] # 1 token (smallest) - - req2 = RequestQueueItem( - 2, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req2.request.input_token_ids = [1, 2, 3, 4, 5] # 5 tokens (largest) - - req3 = RequestQueueItem( - 3, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) - req3.request.input_token_ids = [1, 2, 3] # 3 tokens (medium) - - all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} - all_ranks_num_active_requests = [0, 0, 0, 0] # All ranks start empty - all_ranks_num_active_tokens = [5, 5, 5, 5] - - # Set expected_num_active_requests for testing - attention_dp_queue.expected_num_active_requests = 2 - - result = attention_dp_queue._balance_requests_across_ranks( - [req1, req2, req3], all_ranks_new_requests, - all_ranks_num_active_requests, all_ranks_num_active_tokens) - - # Requests should be sorted by token count (descending) before distribution - # Then assigned to ranks with lowest active count first - # req2 (5 tokens) -> rank 0 (0 active) - # req3 (3 tokens) -> rank 1 (0 active) - # req1 (1 token) -> rank 2 (0 active) - - assert len(result[0]) == 1 # req2 (highest token count) - assert len(result[1]) == 1 # req3 - assert len(result[2]) == 1 # req1 (lowest token count) - assert len(result[3]) == 0 - - # Verify the requests are 
assigned correctly - assert result[0][0] == req2 - assert result[1][0] == req3 - assert result[2][0] == req1 diff --git a/tests/unittest/_torch/executor/test_py_executor.py b/tests/unittest/_torch/executor/test_py_executor.py new file mode 100644 index 0000000000..69fbc059e4 --- /dev/null +++ b/tests/unittest/_torch/executor/test_py_executor.py @@ -0,0 +1,182 @@ +"""Tests for PyExecutor request handling functionality. + +This module tests the request handling logic that was moved from ExecutorRequestQueue +to PyExecutor, including: +- _handle_special_queue_items method +- canceled_req_ids management +- waiting_queue management +- is_shutdown state management +- expected_num_active_requests tracking +""" + +from collections import deque +from unittest.mock import Mock + +import pytest + +from tensorrt_llm._torch.pyexecutor.executor_request_queue import ( + SHUTDOWN_REQUEST_ID, + RequestQueueItem, +) + + +class MockPyExecutor: + """A mock PyExecutor class for testing request handling logic. + + This mock contains only the attributes and methods needed to test + the _handle_special_queue_items functionality. + """ + + def __init__(self, dist): + self.dist = dist + self.canceled_req_ids = [] + self.control_requests = [] + self.request_accumulated = [] + self.is_shutdown = False + self.expected_num_active_requests = 0 + self.new_active_requests_queue_latency_ms = 0.0 + self.waiting_queue = deque() + + def _handle_special_queue_items(self, new_requests): + """Handle special signals. + + This method mirrors PyExecutor._handle_special_queue_items. + """ + accepted_new_requests = [] + for idx, req_item in enumerate(new_requests): + if req_item.is_shutdown_request: + self.is_shutdown = True + break + elif req_item.is_canceled_request: + self.canceled_req_ids.append(req_item.id) + elif req_item.is_control_request: + self.control_requests.append(req_item) + if self.dist.rank == 0: + self.request_accumulated.extend(new_requests[idx + 1 :]) + break + else: + accepted_new_requests.append(req_item) + + return accepted_new_requests + + def update_waiting_queue(self): + """Update waiting queue to remove canceled requests. + + This method mirrors PyExecutor.update_waiting_queue. 
+ """ + if self.canceled_req_ids: + canceled_set = set(self.canceled_req_ids) + self.waiting_queue = deque( + item for item in self.waiting_queue if item.id not in canceled_set + ) + + def clear_canceled_req_ids(self): + """Clear the list of canceled request IDs.""" + self.canceled_req_ids.clear() + + def get_canceled_req_ids(self): + """Get the list of canceled request IDs.""" + return self.canceled_req_ids + + def get_canceled_req_ids_size(self): + """Get the number of canceled request IDs.""" + return len(self.canceled_req_ids) + + def get_expected_num_active_requests(self): + """Get the expected number of active requests.""" + return self.expected_num_active_requests + + def get_waiting_queue_size(self): + """Get the size of the waiting queue.""" + return len(self.waiting_queue) + + def _get_new_active_requests_queue_latency(self): + """Get the queue latency for new active requests.""" + return self.new_active_requests_queue_latency_ms + + +@pytest.fixture +def mock_dist(): + """Create a mock Distributed instance for testing.""" + mock_dist = Mock() + mock_dist.rank = 0 + mock_dist.tp_size = 1 + return mock_dist + + +@pytest.fixture +def mock_executor(mock_dist): + """Create a MockPyExecutor instance for testing.""" + return MockPyExecutor(dist=mock_dist) + + +def test_handle_special_queue_items(mock_executor): + """Test special queue item handling.""" + # Create a mock request + mock_request = Mock() + if hasattr(mock_request, "sampling_config"): + delattr(mock_request, "sampling_config") + + normal_req = RequestQueueItem(1, mock_request) + cancel_req = RequestQueueItem(2, is_canceled_request=True) + shutdown_req = RequestQueueItem(SHUTDOWN_REQUEST_ID) + + requests = [normal_req, cancel_req, shutdown_req] + + valid_requests = mock_executor._handle_special_queue_items(requests) + + assert len(valid_requests) == 1 + assert valid_requests[0] == normal_req + assert mock_executor.is_shutdown + assert 2 in mock_executor.canceled_req_ids + + +def test_clear_canceled_req_ids(mock_executor): + """Test clearing canceled request IDs.""" + mock_executor.canceled_req_ids = [1, 2, 3] + assert len(mock_executor.canceled_req_ids) == 3 + + mock_executor.clear_canceled_req_ids() + + assert len(mock_executor.canceled_req_ids) == 0 + + +def test_update_waiting_queue(mock_executor): + """Test updating waiting queue to remove canceled requests.""" + items = [ + RequestQueueItem(1, Mock()), + RequestQueueItem(2, Mock()), + RequestQueueItem(3, Mock()), + ] + mock_executor.waiting_queue.extend(items) + mock_executor.canceled_req_ids = [2] + + mock_executor.update_waiting_queue() + + assert len(mock_executor.waiting_queue) == 2 + remaining_ids = [item.id for item in mock_executor.waiting_queue] + assert 1 in remaining_ids + assert 3 in remaining_ids + assert 2 not in remaining_ids + + +def test_getter_methods(mock_executor): + """Test various getter methods.""" + # Test initial values + assert mock_executor._get_new_active_requests_queue_latency() == 0 + assert mock_executor.get_expected_num_active_requests() == 0 + assert mock_executor.get_canceled_req_ids_size() == 0 + assert mock_executor.get_canceled_req_ids() == [] + assert mock_executor.get_waiting_queue_size() == 0 + + # Add some data and test + mock_executor.canceled_req_ids = [3, 4] + mock_executor.expected_num_active_requests = 5 + mock_executor.new_active_requests_queue_latency_ms = 10.5 + mock_executor.waiting_queue.append(RequestQueueItem(1, Mock())) + + assert mock_executor.get_canceled_req_ids_size() == 2 + assert 
mock_executor.get_canceled_req_ids() == [3, 4] + assert mock_executor.get_expected_num_active_requests() == 5 + assert mock_executor._get_new_active_requests_queue_latency() == 10.5 + assert mock_executor.get_waiting_queue_size() == 1 diff --git a/tests/unittest/_torch/executor/test_request_utils.py b/tests/unittest/_torch/executor/test_request_utils.py new file mode 100644 index 0000000000..ed209c6630 --- /dev/null +++ b/tests/unittest/_torch/executor/test_request_utils.py @@ -0,0 +1,1103 @@ +"""Tests for request_utils.py functions. + +This module tests: +- Request merging functions (merge_requests, merge_helix_requests) +- Attention DP scheduling functions (schedule_attention_dp_requests, balance_requests_across_ranks) +- Waiting queue functions (get_from_waiting_queue, can_process_attention_dp_request) +""" + +from collections import deque +from unittest.mock import Mock, patch + +import pytest + +from tensorrt_llm._torch.pyexecutor.executor_request_queue import RequestQueueItem +from tensorrt_llm._torch.pyexecutor.request_utils import ( + balance_requests_across_ranks, + can_process_attention_dp_request, + get_from_waiting_queue, + merge_helix_requests, + merge_requests, + schedule_attention_dp_requests, +) +from tensorrt_llm.bindings import executor as trtllm +from tensorrt_llm.mapping import CpType + + +@pytest.fixture +def attention_dp_config(): + """Create a config dict for attention DP testing.""" + return { + "tp_size": 4, + "max_num_active_requests": 8, + } + + +@pytest.fixture +def all_ranks_num_active_requests(): + return [2, 1, 3, 0] # 4 ranks + + +@pytest.fixture +def all_ranks_num_active_tokens(): + return [10, 5, 15, 8] # 4 ranks + + +def create_mock_request_with_py_schedule_params(attention_dp_rank=None, attention_dp_relax=False): + mock_request = Mock() + + if attention_dp_rank is not None: + mock_schedule_params = Mock() + mock_schedule_params.attention_dp_rank = attention_dp_rank + mock_schedule_params.attention_dp_relax = attention_dp_relax + + mock_schedule_params.configure_mock( + attention_dp_rank=attention_dp_rank, attention_dp_relax=attention_dp_relax + ) + + mock_request.py_scheduling_params = mock_schedule_params + else: + mock_request.py_scheduling_params = None + + mock_request.input_token_ids = [1, 2, 3] + + return mock_request + + +def append_to_waiting_queue(waiting_queue, rank, attention_dp_relax): + req_id = len(waiting_queue) + waiting_queue.append( + RequestQueueItem( + req_id, + create_mock_request_with_py_schedule_params( + attention_dp_rank=rank, attention_dp_relax=attention_dp_relax + ), + ) + ) + + +def test_merge_helix_requests_with_padding(): + """Test merge_helix_requests with basic valid input.""" + + tokens_per_block = 2 + + # Create request item with 13 tokens to get exactly 7 blocks for 4 CP ranks. + input_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + executor_request = trtllm.Request( + input_token_ids=input_tokens, + max_tokens=5, + streaming=False, + sampling_config=trtllm.SamplingConfig(), + output_config=trtllm.OutputConfig(), + ) + request_item = RequestQueueItem( + id=1, + request=executor_request, + ) + + for rank in [0, 1, 2, 3]: + # Test merge_helix_requests with 4 CP ranks. + result = merge_helix_requests( + [request_item], + cp_rank=rank, + cp_size=4, + tokens_per_block=tokens_per_block, + exclude_last_generation_logits=False, + ) + + # Verify the result. 
+ assert len(result) == 1 + llm_request = result[0] + from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest + + assert isinstance(llm_request, LlmRequest) + assert llm_request.request_id == 1 + if rank == 0: + assert llm_request.get_tokens(0) == [1, 2, 3, 4] + elif rank == 1: + assert llm_request.get_tokens(0) == [5, 6, 7, 8] + elif rank == 2: + assert llm_request.get_tokens(0) == [9, 10, 11, 12] + else: + assert llm_request.get_tokens(0) == [13] + + +def test_merge_helix_requests_without_padding(): + """Test merge_helix_requests with evenly divisible tokens (no padding).""" + + tokens_per_block = 4 + + # Create request item with 12 tokens to get exactly 3 blocks for 2 CP ranks. + input_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] + executor_request = trtllm.Request( + input_token_ids=input_tokens, + max_tokens=5, + streaming=False, + sampling_config=trtllm.SamplingConfig(), + output_config=trtllm.OutputConfig(), + ) + request_item = RequestQueueItem( + id=1, + request=executor_request, + ) + + for rank in [0, 1]: + # Test merge_helix_requests with 2 CP ranks. + result = merge_helix_requests( + [request_item], + cp_rank=rank, + cp_size=2, + tokens_per_block=tokens_per_block, + exclude_last_generation_logits=False, + ) + + # Verify the result. + assert len(result) == 1 + llm_request = result[0] + from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest + + assert isinstance(llm_request, LlmRequest) + assert llm_request.request_id == 1 + if rank == 0: + assert llm_request.get_tokens(0) == [1, 2, 3, 4, 5, 6, 7, 8] + else: + assert llm_request.get_tokens(0) == [9, 10, 11, 12] + + +def test_merge_helix_requests_insufficient_blocks_error(): + """Test merge_helix_requests raises error when insufficient blocks.""" + tokens_per_block = 4 + + # Create input with only 12 tokens. This creates 3 blocks which is fewer than 4 CP ranks. + executor_request = trtllm.Request( + input_token_ids=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + max_tokens=12, + streaming=False, + sampling_config=trtllm.SamplingConfig(), + output_config=trtllm.OutputConfig(), + ) + request_item = RequestQueueItem( + id=1, + request=executor_request, + ) + + # Loop over ranks 0, 1, 2, 3 and verify that all ranks throw assertion. + for rank in range(4): + with pytest.raises( + ValueError, match="There aren't enough tokens to get at least one block per CP rank" + ): + merge_helix_requests( + [request_item], + cp_rank=rank, + cp_size=4, + tokens_per_block=tokens_per_block, + exclude_last_generation_logits=False, + ) + + +@patch("tensorrt_llm._torch.pyexecutor.request_utils.executor_request_to_llm_request") +def test_merge_requests_default(mock_convert): + """Test merging requests with default configuration.""" + mock_llm_request = Mock(child_requests=[]) + mock_convert.return_value = mock_llm_request + + requests = [RequestQueueItem(1, Mock()), RequestQueueItem(2, Mock())] + result = merge_requests( + requests, cp_config={}, cp_rank=0, cp_size=1, exclude_last_generation_logits=False + ) + + assert len(result) == 2 + assert mock_convert.call_count == 2 + + +def test_merge_requests_with_helix_cp_config(): + """Test merge_requests routes to merge_helix_requests with HELIX cp_config.""" + tokens_per_block = 2 + + # Create request item with 13 tokens to get exactly 7 blocks for 4 CP ranks. 
+ input_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + executor_request = trtllm.Request( + input_token_ids=input_tokens, + max_tokens=5, + streaming=False, + sampling_config=trtllm.SamplingConfig(), + output_config=trtllm.OutputConfig(), + ) + request_item = RequestQueueItem( + id=1, + request=executor_request, + ) + + cp_config = { + "cp_type": CpType.HELIX, + "tokens_per_block": tokens_per_block, + } + + for rank in [0, 1, 2, 3]: + # Test merge_requests with HELIX cp_config and 4 CP ranks. + result = merge_requests( + [request_item], + cp_config=cp_config, + cp_rank=rank, + cp_size=4, + exclude_last_generation_logits=False, + ) + + # Verify the result. + assert len(result) == 1 + llm_request = result[0] + from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest + + assert isinstance(llm_request, LlmRequest) + assert llm_request.request_id == 1 + if rank == 0: + assert llm_request.get_tokens(0) == [1, 2, 3, 4] + elif rank == 1: + assert llm_request.get_tokens(0) == [5, 6, 7, 8] + elif rank == 2: + assert llm_request.get_tokens(0) == [9, 10, 11, 12] + else: + assert llm_request.get_tokens(0) == [13] + + +def test_get_from_waiting_queue(): + """Test getting items from waiting queue.""" + # Add items to waiting queue + waiting_queue = deque() + items = [RequestQueueItem(i, Mock()) for i in range(5)] + waiting_queue.extend(items) + + # Get 3 items + result = get_from_waiting_queue( + waiting_queue, 3, enable_attention_dp=False, max_num_active_requests=16 + ) + + assert len(result) == 3 + assert result == items[:3] + assert len(waiting_queue) == 2 + + +@pytest.mark.parametrize( + "queue_size,request_count,expected_result,expected_remaining", + [ + (0, 5, 0, 0), # Empty queue + (3, -1, 0, 3), # Negative count + (3, 0, 0, 3), # Zero count + (3, 10, 3, 0), # Request more than available + ], +) +def test_get_from_waiting_queue_edge_cases( + queue_size, request_count, expected_result, expected_remaining +): + """Test edge cases for getting items from waiting queue.""" + # Setup queue + waiting_queue = deque() + if queue_size > 0: + items = [RequestQueueItem(i, Mock()) for i in range(queue_size)] + waiting_queue.extend(items) + + result = get_from_waiting_queue( + waiting_queue, request_count, enable_attention_dp=False, max_num_active_requests=16 + ) + + assert len(result) == expected_result + assert len(waiting_queue) == expected_remaining + + +def test_get_from_waiting_queue_with_attention_dp( + attention_dp_config, all_ranks_num_active_requests +): + waiting_queue = deque() + items = [RequestQueueItem(i, Mock()) for i in range(5)] + waiting_queue.extend(items) + + result = get_from_waiting_queue( + waiting_queue, + 3, + True, + attention_dp_config["max_num_active_requests"], + all_ranks_num_active_requests, + ) + + assert len(result) == 3 + assert result == items[:3] + assert len(waiting_queue) == 2 + + +def test_get_from_waiting_queue_with_attention_dp_filtering( + attention_dp_config, all_ranks_num_active_requests +): + req1 = RequestQueueItem( + 1, + create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=False), + ) + req2 = RequestQueueItem( + 2, create_mock_request_with_py_schedule_params(attention_dp_rank=1, attention_dp_relax=True) + ) + req3 = RequestQueueItem( + 3, create_mock_request_with_py_schedule_params(attention_dp_rank=None) + ) # No scheduling params + + waiting_queue = deque([req1, req2, req3]) + + # Set rank 0 to full capacity to test filtering + all_ranks_num_active_requests[0] = 8 + + result = get_from_waiting_queue( + 
waiting_queue, + 3, + True, + attention_dp_config["max_num_active_requests"], + all_ranks_num_active_requests, + ) + + assert len(result) == 2 + assert req2 in result + assert req3 in result + assert req1 not in result + + +def test_can_process_attention_dp_request(attention_dp_config): + max_num_active_requests = attention_dp_config["max_num_active_requests"] + + req_no_params = RequestQueueItem(1, Mock()) + assert can_process_attention_dp_request(req_no_params, [0, 0, 0, 0], max_num_active_requests) + + req_relax = RequestQueueItem( + 2, create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=True) + ) + assert can_process_attention_dp_request(req_relax, [0, 0, 0, 0], max_num_active_requests) + + req_target = RequestQueueItem( + 3, + create_mock_request_with_py_schedule_params(attention_dp_rank=1, attention_dp_relax=False), + ) + all_ranks = [0, 0, 0, 0] + assert can_process_attention_dp_request(req_target, all_ranks, max_num_active_requests) + assert all_ranks[1] == 1 + + req_no_capacity = RequestQueueItem( + 4, + create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=False), + ) + all_ranks_full = [8, 0, 0, 0] # Rank 0 is at capacity + assert not can_process_attention_dp_request( + req_no_capacity, all_ranks_full, max_num_active_requests + ) + + +def test_schedule_attention_dp_requests_scheduled_requests( + attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + req1 = RequestQueueItem( + 1, + create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=False), + ) + req2 = RequestQueueItem( + 2, + create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=False), + ) + + new_requests = [req1, req2] + + all_ranks_new_requests, _ = schedule_attention_dp_requests( + new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + result = all_ranks_new_requests[0] + + assert len(result) == 2 + assert req1 in result + assert req2 in result + + assert all_ranks_num_active_requests[0] == 4 + + +def test_schedule_attention_dp_requests_scheduled_requests_other_ranks( + attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + req1 = RequestQueueItem( + 1, + create_mock_request_with_py_schedule_params(attention_dp_rank=1, attention_dp_relax=False), + ) + req2 = RequestQueueItem( + 2, + create_mock_request_with_py_schedule_params(attention_dp_rank=2, attention_dp_relax=False), + ) + + new_requests = [req1, req2] + + all_ranks_new_requests, _ = schedule_attention_dp_requests( + new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + + result = all_ranks_new_requests[0] + assert len(result) == 0 + + assert all_ranks_num_active_requests[1] == 2 + assert all_ranks_num_active_requests[2] == 4 + + +def test_schedule_attention_dp_requests_unscheduled_requests( + attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + req1 = RequestQueueItem( + 1, create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=True) + ) + req2 = RequestQueueItem( + 2, create_mock_request_with_py_schedule_params(attention_dp_rank=1, attention_dp_relax=True) + ) + + new_requests = [req1, req2] + + all_ranks_new_requests, _ = schedule_attention_dp_requests( + new_requests, + 
all_ranks_num_active_requests, + all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + result = all_ranks_new_requests[0] + + assert len(result) == 1 # Only req1 for current rank + assert req1 in result + + +def test_schedule_attention_dp_requests_unscheduled_no_capacity( + attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + all_ranks_num_active_requests[0] = 8 + + req1 = RequestQueueItem( + 1, create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=True) + ) + + new_requests = [req1] + + all_ranks_new_requests, _ = schedule_attention_dp_requests( + new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + result = all_ranks_new_requests[0] + + assert len(result) == 0 # No capacity + + +def test_schedule_attention_dp_requests_mixed_scenarios( + attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + req_scheduled_current = RequestQueueItem( + 1, + create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=False), + ) + req_scheduled_other = RequestQueueItem( + 2, + create_mock_request_with_py_schedule_params(attention_dp_rank=1, attention_dp_relax=False), + ) + req_unscheduled_current = RequestQueueItem( + 3, create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=True) + ) + req_unscheduled_other = RequestQueueItem( + 4, create_mock_request_with_py_schedule_params(attention_dp_rank=2, attention_dp_relax=True) + ) + + new_requests = [ + req_scheduled_current, + req_scheduled_other, + req_unscheduled_current, + req_unscheduled_other, + ] + + all_ranks_new_requests, _ = schedule_attention_dp_requests( + new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + result = all_ranks_new_requests[0] + + assert len(result) == 2 + assert req_scheduled_current in result + assert req_unscheduled_current in result + + +def test_schedule_attention_dp_requests_empty_lists( + attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + all_ranks_new_requests, _ = schedule_attention_dp_requests( + [], + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + result = all_ranks_new_requests[0] + + assert len(result) == 0 + + +def test_schedule_attention_dp_requests_expected_num_active_calculation( + attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + req1 = RequestQueueItem( + 1, create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=True) + ) + req2 = RequestQueueItem( + 2, create_mock_request_with_py_schedule_params(attention_dp_rank=1, attention_dp_relax=True) + ) + + new_requests = [req1, req2] + + _, expected_num_active_requests = schedule_attention_dp_requests( + new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + + # 2 + 1 + 3 + 0 = 6, 6 + 2 = 8, (8 + 3) // 4 = 2, max(2, 2, 1, 3, 0) = 3 + # expected_num_active_requests = max((6 + 2 + 3) // 4, 3) = max(2, 3) = 3 + assert expected_num_active_requests == 3 + + +def test_schedule_attention_dp_requests_balance_requests_called( + 
attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + """Test that balance_requests_across_ranks is called with correct arguments.""" + req1 = RequestQueueItem( + 1, create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=True) + ) + + new_requests = [req1] + + with patch( + "tensorrt_llm._torch.pyexecutor.request_utils.balance_requests_across_ranks" + ) as mock_balance: + mock_balance.return_value = {0: [req1], 1: [], 2: [], 3: []} + + schedule_attention_dp_requests( + new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + + # Check that balance_requests_across_ranks was called + mock_balance.assert_called_once() + call_args = mock_balance.call_args[0] + assert isinstance(call_args[0], list) + assert isinstance(call_args[1], dict) + assert call_args[2] == all_ranks_num_active_requests # Third arg + assert call_args[3] == all_ranks_num_active_tokens # Fourth arg + + +def test_schedule_attention_dp_requests_no_scheduling_when_capacity_exceeded( + attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + all_ranks_num_active_requests[0] = 8 + + req1 = RequestQueueItem( + 1, + create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=False), + ) + + new_requests = [req1] + + all_ranks_new_requests, _ = schedule_attention_dp_requests( + new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + result = all_ranks_new_requests[0] + + assert len(result) == 0 # No requests scheduled + assert all_ranks_num_active_requests[0] == 8 # Capacity unchanged + + +def test_filter_and_schedule_integration( + attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + req_schedulable = RequestQueueItem( + 1, + create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=False), + ) + req_schedulable.request.input_token_ids = [1, 2, 3, 4] + req_relax = RequestQueueItem( + 2, create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=True) + ) + req_relax.request.input_token_ids = [1, 2] + + req_no_params = RequestQueueItem( + 3, create_mock_request_with_py_schedule_params(attention_dp_rank=None) + ) + + new_requests = [req_schedulable, req_relax, req_no_params] + + all_ranks_new_requests, _ = schedule_attention_dp_requests( + new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + result = all_ranks_new_requests[0] + + assert len(result) == 2 + assert req_schedulable in result + assert req_relax in result + + +def test_filter_and_schedule_with_capacity_limits( + attention_dp_config, all_ranks_num_active_requests, all_ranks_num_active_tokens +): + all_ranks_num_active_requests[0] = 7 + + req1 = RequestQueueItem( + 1, + create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=False), + ) + req1.request.input_token_ids = [1, 2, 3, 4] + req2 = RequestQueueItem( + 2, + create_mock_request_with_py_schedule_params(attention_dp_rank=0, attention_dp_relax=False), + ) + req2.request.input_token_ids = [1, 2, 3] + + new_requests = [req1, req2] + + all_ranks_new_requests, _ = schedule_attention_dp_requests( + new_requests, + all_ranks_num_active_requests, + 
all_ranks_num_active_tokens, + attention_dp_config["tp_size"], + attention_dp_config["max_num_active_requests"], + ) + result = all_ranks_new_requests[0] + + assert len(result) == 1 + assert req1 in result + + +def test_achieve_max_num_active_requests(attention_dp_config): + max_num_active_requests = attention_dp_config["max_num_active_requests"] + req_list = [] + req_id = 0 + for rank in range(4): + for _ in range(5): + req_list.append( + RequestQueueItem( + req_id, + create_mock_request_with_py_schedule_params( + attention_dp_rank=rank, attention_dp_relax=False + ), + ) + ) + req_id += 1 + req_list.append( + RequestQueueItem( + req_id, + create_mock_request_with_py_schedule_params( + attention_dp_rank=rank, attention_dp_relax=True + ), + ) + ) + req_id += 1 + + all_ranks_num_active_requests = [5, 6, 3, 7] + waiting_queue = deque(req_list) + available_active_requests = max_num_active_requests * 4 - sum(all_ranks_num_active_requests) + + result = get_from_waiting_queue( + waiting_queue, + available_active_requests, + True, + max_num_active_requests, + all_ranks_num_active_requests, + ) + + assert len(result) == available_active_requests + + +@pytest.mark.parametrize( + "max_num_active_requests,all_ranks_num_active_requests,request_configs,all_ranks_expected_req_ids", + [ + # Case: Balanced distribution of relaxed requests + ( + 3, + [0, 0, 0, 0], + [(None, True)] * 7, + { + 0: [0, 1], # First 2 requests go to rank 0 + 1: [2, 3], # Next 2 requests go to rank 1 + 2: [4, 5], # Next 2 requests go to rank 2 + 3: [6], # Last request goes to rank 3 + }, + ), + # Case: Balanced distribution of relaxed requests with existing load + ( + 3, + [1, 2, 3, 0], + [(None, True)] * 13, + { + 0: [0, 1], # Rank 0 gets first 2 requests + 1: [2], # Rank 1 gets 1 request (already has 2) + 2: [], # Rank 2 is at capacity (3) + 3: [3, 4, 5], # Rank 3 gets 3 requests (starts with 0) + }, + ), + # Case: Limited by max active + ( + 3, + [0, 0, 0, 0], + [(None, True)] * 13, + { + 0: [0, 1, 3], # First 3 requests (0, 1, 3) + 1: [2, 4, 6], # Next 3 requests (2, 4, 6) + 2: [5, 7, 9], # Next 3 requests (5, 7, 9) + 3: [8, 10, 11], # Last 3 requests (8, 10, 11) + }, + ), + # Case: Empty new requests + (3, [3, 3, 3, 0], [], {0: [], 1: [], 2: [], 3: []}), + # Case: Rank 0 is full and cannot schedule attention_dp rank request + ( + 3, + [3, 1, 3, 0], + [(0, False), (0, True)], + { + 0: [], # Rank 0 is full + 1: [1], # Rank 1 gets the relaxed request (req1) + 2: [], # No relaxed requests assigned here + 3: [], # No relaxed requests assigned here + }, + ), + # Case: Only room for 1 request, need to skip req0 with attention dp rank + ( + 3, + [3, 2, 3, 3], + [(0, False), (0, True)], + { + 0: [], # Rank 0 is full + 1: [1], # Rank 1 gets the relaxed request + 2: [], # Rank 2 is at capacity + 3: [], # Rank 3 is at capacity + }, + ), + # Case: Targeting ranks 1 and 3 that have room + ( + 3, + [2, 1, 3, 0], + [(1, False), (3, False)], + { + 0: [], # No requests assigned to rank 0 + 1: [0], # Request 0 targets rank 1 + 2: [], # No requests assigned to rank 2 + 3: [1], # Request 1 targets rank 3 + }, + ), + # Case: Target dp rank specified, but relax is True + ( + 3, + [3, 3, 3, 1], + [(0, True), (1, True), (2, True)], + { + 0: [], # Rank 0 is at capacity + 1: [], # Rank 1 is at capacity + 2: [], # Rank 2 is at capacity + 3: [0, 1], # Rank 3 gets both relaxed requests + }, + ), + # Case: Mixed targeting and relaxed + ( + 3, + [3, 3, 3, 0], + [(0, False), (1, True), (3, False)], + { + 0: [], # Rank 0 is at capacity + 1: [], # Rank 1 is at 
capacity + 2: [], # Rank 2 is at capacity + 3: [2, 1], # Rank 3 gets both requests (targeted + relaxed) + }, + ), + ], +) +def test_attention_dp_scheduling_cases( + max_num_active_requests, + all_ranks_num_active_requests, + request_configs, + all_ranks_expected_req_ids, +): + """Test attention DP scheduling with various scenarios.""" + waiting_queue = deque() + for rank, relax in request_configs: + append_to_waiting_queue(waiting_queue, rank, relax) + + run_test_attention_dp_scheduling( + max_num_active_requests, + waiting_queue, + all_ranks_num_active_requests, + all_ranks_expected_req_ids, + ) + + +def run_test_attention_dp_scheduling( + max_num_active_requests, + waiting_queue, + all_ranks_num_active_requests, + all_ranks_expected_req_ids, +): + num_ranks = len(all_ranks_num_active_requests) + total_num_active_requests = sum(all_ranks_num_active_requests) + total_max_num_active_requests = max_num_active_requests * num_ranks + enable_attention_dp = True + + new_requests = get_from_waiting_queue( + waiting_queue, + total_max_num_active_requests - total_num_active_requests, + enable_attention_dp, + max_num_active_requests, + all_ranks_num_active_requests, + ) + + # Create mock token counts for testing + all_ranks_num_active_tokens = [10 + i * 5 for i in range(num_ranks)] + + # Schedule attention dp requests + all_ranks_new_requests, _ = schedule_attention_dp_requests( + new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + num_ranks, + max_num_active_requests, + ) + + assert len(all_ranks_new_requests) == num_ranks + print("all_ranks_new_requests:", all_ranks_new_requests) + for rank, reqs in all_ranks_new_requests.items(): + req_ids = [req.id for req in reqs] + assert req_ids == all_ranks_expected_req_ids[rank] + + +def test_balance_requests_across_ranks_empty_requests(): + """Test balance_requests_across_ranks with empty requests list.""" + all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} + all_ranks_num_active_requests = [2, 1, 3, 0] + all_ranks_num_active_tokens = [20, 10, 30, 5] + expected_num_active_requests = 3 + + result = balance_requests_across_ranks( + [], + all_ranks_new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + expected_num_active_requests, + ) + + # Should return the original structure unchanged + assert result == all_ranks_new_requests + for rank in range(4): + assert len(result[rank]) == 0 + + +def test_balance_requests_across_ranks_single_request(): + """Test balance_requests_across_ranks with a single request.""" + req = RequestQueueItem(1, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) + req.request.input_token_ids = [1, 2, 3, 4, 5] # 5 tokens + + all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} + all_ranks_num_active_requests = [1, 2, 0, 1] # Rank 2 has lowest count + all_ranks_num_active_tokens = [10, 20, 5, 15] + expected_num_active_requests = 2 + + result = balance_requests_across_ranks( + [req], + all_ranks_new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + expected_num_active_requests, + ) + + # Request should be assigned to rank 2 (lowest active count) + assert len(result[0]) == 0 + assert len(result[1]) == 0 + assert len(result[2]) == 1 + assert len(result[3]) == 0 + assert result[2][0] == req + + +def test_balance_requests_across_ranks_multiple_requests(): + """Test balance_requests_across_ranks with multiple requests.""" + # Create requests with different token counts + req1 = RequestQueueItem(1, 
create_mock_request_with_py_schedule_params(attention_dp_rank=None)) + req1.request.input_token_ids = [1, 2, 3] # 3 tokens + + req2 = RequestQueueItem(2, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) + req2.request.input_token_ids = [1, 2, 3, 4, 5, 6] # 6 tokens + + req3 = RequestQueueItem(3, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) + req3.request.input_token_ids = [1, 2] # 2 tokens + + all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} + all_ranks_num_active_requests = [0, 1, 2, 1] + all_ranks_num_active_tokens = [5, 15, 25, 10] + expected_num_active_requests = 2 + + result = balance_requests_across_ranks( + [req1, req2, req3], + all_ranks_new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + expected_num_active_requests, + ) + + # Requests should be distributed based on heap (lowest active count first) + # Requests are sorted by token count (descending) first, then assigned to ranks with lowest active count + # req2 (6 tokens) -> rank 0 (0 active) -> total: 1 active, 11 tokens + # req3 (2 tokens) -> rank 0 (1 active) -> total: 2 active, 13 tokens (rank 0 still has capacity) + # req1 (3 tokens) -> rank 3 (1 active) -> total: 2 active, 13 tokens + # Rank 1: 1 active, gets nothing (rank 0 took 2 requests) + # Rank 2: 2 active, gets nothing (at capacity) + + assert len(result[0]) == 2 # req2 and req3 (rank 0 has capacity for 2) + assert len(result[1]) == 0 # no requests (rank 0 took 2 requests) + assert len(result[2]) == 0 # at capacity + assert len(result[3]) == 1 # req1 + + # Verify the requests are assigned correctly + assert result[0][0] == req2 # First request (highest token count) + assert result[0][1] == req3 # Second request + assert result[3][0] == req1 + + +def test_balance_requests_across_ranks_capacity_limits(): + """Test balance_requests_across_ranks respects capacity limits.""" + # Create multiple requests + requests = [] + for i in range(4): + req = RequestQueueItem( + i, create_mock_request_with_py_schedule_params(attention_dp_rank=None) + ) + req.request.input_token_ids = [1] * (i + 1) # Variable token counts + requests.append(req) + + all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} + all_ranks_num_active_requests = [1, 1, 1, 1] # All ranks start with 1 + all_ranks_num_active_tokens = [10, 10, 10, 10] + expected_num_active_requests = 2 + + result = balance_requests_across_ranks( + requests, + all_ranks_new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + expected_num_active_requests, + ) + + # Each rank can only take 1 more request (1 + 1 = 2, which equals expected_num_active_requests) + total_assigned = sum(len(rank_requests) for rank_requests in result.values()) + assert total_assigned == 4 # 4 ranks with 1 additional request each + + # Verify no rank exceeds capacity + for rank in range(4): + assert len(result[rank]) <= 1 + + +def test_balance_requests_across_ranks_heap_ordering(): + """Test that balance_requests_across_ranks uses heap ordering correctly.""" + # Create requests with same token count to test heap ordering + req1 = RequestQueueItem(1, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) + req1.request.input_token_ids = [1, 2, 3] # 3 tokens + + req2 = RequestQueueItem(2, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) + req2.request.input_token_ids = [1, 2, 3] # 3 tokens + + req3 = RequestQueueItem(3, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) + req3.request.input_token_ids = [1, 2, 3] 
# 3 tokens + + all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} + # Rank 0 has highest active count, should get requests last + all_ranks_num_active_requests = [3, 1, 0, 2] + all_ranks_num_active_tokens = [30, 10, 5, 20] + expected_num_active_requests = 4 + + result = balance_requests_across_ranks( + [req1, req2, req3], + all_ranks_new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + expected_num_active_requests, + ) + + # Requests should be assigned in order of lowest active count first + # Since all requests have same token count, they're assigned based on active count order + # Rank 2: 0 active -> gets req1 and req2 (has capacity for 2) + # Rank 1: 1 active -> gets req3 (after rank 2 takes 2) + # Rank 3: 2 active -> gets nothing (rank 1 took req3) + # Rank 0: 3 active -> gets nothing (at capacity) + + assert len(result[0]) == 0 # at capacity + assert len(result[1]) == 1 # req3 + assert len(result[2]) == 2 # req1 and req2 + assert len(result[3]) == 0 # no requests + + # Verify the requests are assigned correctly + assert result[1][0] == req3 # Third request + assert result[2][0] == req1 # First request + assert result[2][1] == req2 # Second request + + +def test_balance_requests_across_ranks_token_count_sorting(): + """Test that requests are sorted by token count before distribution.""" + # Create requests with different token counts + req1 = RequestQueueItem(1, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) + req1.request.input_token_ids = [1] # 1 token (smallest) + + req2 = RequestQueueItem(2, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) + req2.request.input_token_ids = [1, 2, 3, 4, 5] # 5 tokens (largest) + + req3 = RequestQueueItem(3, create_mock_request_with_py_schedule_params(attention_dp_rank=None)) + req3.request.input_token_ids = [1, 2, 3] # 3 tokens (medium) + + all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []} + all_ranks_num_active_requests = [0, 0, 0, 0] # All ranks start empty + all_ranks_num_active_tokens = [5, 5, 5, 5] + expected_num_active_requests = 2 + + result = balance_requests_across_ranks( + [req1, req2, req3], + all_ranks_new_requests, + all_ranks_num_active_requests, + all_ranks_num_active_tokens, + expected_num_active_requests, + ) + + # Requests should be sorted by token count (descending) before distribution + # Then assigned to ranks with lowest active count first + # req2 (5 tokens) -> rank 0 (0 active) + # req3 (3 tokens) -> rank 1 (0 active) + # req1 (1 token) -> rank 2 (0 active) + + assert len(result[0]) == 1 # req2 (highest token count) + assert len(result[1]) == 1 # req3 + assert len(result[2]) == 1 # req1 (lowest token count) + assert len(result[3]) == 0 + + # Verify the requests are assigned correctly + assert result[0][0] == req2 + assert result[1][0] == req3 + assert result[2][0] == req1
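
Note: the balance_requests_across_ranks assertions above all follow from one heuristic: sort the incoming requests by prompt token count (descending), keep the ranks in a min-heap keyed by (active request count, active token count), and give each request to the least-loaded rank that is still below expected_num_active_requests. The sketch below is illustrative only and reproduces the assignments asserted in these tests; the real implementation lives in tensorrt_llm/_torch/pyexecutor/request_utils.py, and sketch_balance (its name, its argument order mirroring the test call sites, and its internals) is an assumption, not the production code.

import heapq

def sketch_balance(requests, all_ranks_new_requests,
                   all_ranks_num_active_requests,
                   all_ranks_num_active_tokens,
                   expected_num_active_requests):
    # Min-heap of ranks keyed by (active request count, active token count, rank id).
    heap = [(all_ranks_num_active_requests[r], all_ranks_num_active_tokens[r], r)
            for r in range(len(all_ranks_num_active_requests))]
    heapq.heapify(heap)
    # Heaviest prompts first, so large requests land on the least-loaded ranks.
    pending = sorted(requests,
                     key=lambda item: len(item.request.input_token_ids),
                     reverse=True)
    for item in pending:
        num_tokens = len(item.request.input_token_ids)
        while heap:
            num_active, active_tokens, rank = heapq.heappop(heap)
            if num_active >= expected_num_active_requests:
                continue  # Rank already at its target load; drop it from the heap.
            all_ranks_new_requests[rank].append(item)
            heapq.heappush(heap,
                           (num_active + 1, active_tokens + num_tokens, rank))
            break
    return all_ranks_new_requests

In this sketch, sorted() is stable, so requests with equal token counts keep their original order; that is the property test_balance_requests_across_ranks_heap_ordering relies on when it expects rank 2 to take req1 and req2 and rank 1 to take req3.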