diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index d5a3e04054..c2ab5fcc81 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -11,12 +11,13 @@ from tensorrt_llm._torch.models.modeling_utils import \ from tensorrt_llm._utils import (confidential_compute_enabled, str_dtype_to_binding, torch_dtype_to_str) from tensorrt_llm.bindings.executor import DecodingMode -from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig, - EagleDecodingConfig, KvCacheConfig, - MTPDecodingConfig, PeftCacheConfig, - SamplerType, SchedulerConfig, - SparseAttentionConfig, - SpeculativeConfig, TorchLlmArgs) + +# isort: off +from tensorrt_llm.llmapi.llm_args import ( + CacheTransceiverConfig, EagleDecodingConfig, KvCacheConfig, + MTPDecodingConfig, PeftCacheConfig, SamplerType, SchedulerConfig, + SparseAttentionConfig, SpeculativeConfig, TorchLlmArgs, WaitingQueuePolicy) +# isort: on from tensorrt_llm.logger import logger from tensorrt_llm.lora_helper import (LoraConfig, get_default_trtllm_modules_to_hf_modules) @@ -1006,6 +1007,9 @@ def create_py_executor_instance( kv_cache_transceiver = create_kv_cache_transceiver( mapping, dist, kv_cache_manager, attention_type, cache_transceiver_config) + waiting_queue_policy = (scheduler_config.waiting_queue_policy + if scheduler_config is not None else + WaitingQueuePolicy.FCFS) return PyExecutor( resource_manager, scheduler, @@ -1029,7 +1033,8 @@ def create_py_executor_instance( max_seq_len=max_seq_len, peft_cache_config=peft_cache_config, virtual_memory_pools=virtual_memory_pools, - execution_stream=execution_stream) + execution_stream=execution_stream, + waiting_queue_policy=waiting_queue_policy) def create_torch_sampler_args( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 00bfeb3304..e8128fed95 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -5,7 +5,6 @@ import os import threading import time import traceback -from collections import deque from contextlib import contextmanager from enum import IntEnum from queue import Queue @@ -32,7 +31,7 @@ from tensorrt_llm.bindings.executor import (DisServingRequestStats, StaticBatchingStats) from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType, ReqIdsSet) -from tensorrt_llm.llmapi.llm_args import PeftCacheConfig +from tensorrt_llm.llmapi.llm_args import PeftCacheConfig, WaitingQueuePolicy from tensorrt_llm.logger import logger from tensorrt_llm.mapping import CpType from tensorrt_llm.runtime.generation import CUASSERT @@ -63,7 +62,8 @@ from .resource_manager import (ResourceManager, ResourceManagerType, from .sampler import (AsyncWorkerMixin, Sampler, SamplerEvent, SampleState, SampleStateTensors, TRTLLMSampler) from .scheduler import (RequestScheduler, ScheduledRequests, - SerializableSchedulerOutput) + SerializableSchedulerOutput, WaitingQueue, + create_waiting_queue) # Environment variable to specify iteration ranges for profiling start/stop. # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." @@ -253,30 +253,32 @@ class PyExecutor: # 1024 in-flight micro batches can avoid synchronization in most cases and keep host memory usage low. 
MIN_ASYNC_MICRO_BATCH_NUM = 1024 - def __init__(self, - resource_manager, - scheduler: RequestScheduler, - model_engine: ModelEngine, - sampler: Sampler, - dist: Distributed, - max_num_sequences: int, - drafter: Optional[Drafter] = None, - disable_overlap_scheduler: bool = False, - max_input_len: int = 0x7fffffff, - max_batch_size: int = 8, - max_beam_width: int = 1, - max_draft_len: int = 0, - max_total_draft_tokens: int = 0, - kv_cache_transceiver: Optional[KvCacheTransceiver] = None, - guided_decoder: Optional[GuidedDecoder] = None, - garbage_collection_gen0_threshold: Optional[int] = None, - start_worker: bool = True, - kv_connector_manager: Optional[KvCacheConnectorManager] = None, - max_seq_len: Optional[int] = None, - peft_cache_config: Optional[PeftCacheConfig] = None, - virtual_memory_pools: Optional[dict] = None, - hang_detection_timeout: Optional[int] = None, - execution_stream: Optional[torch.cuda.Stream] = None): + def __init__( + self, + resource_manager, + scheduler: RequestScheduler, + model_engine: ModelEngine, + sampler: Sampler, + dist: Distributed, + max_num_sequences: int, + drafter: Optional[Drafter] = None, + disable_overlap_scheduler: bool = False, + max_input_len: int = 0x7fffffff, + max_batch_size: int = 8, + max_beam_width: int = 1, + max_draft_len: int = 0, + max_total_draft_tokens: int = 0, + kv_cache_transceiver: Optional[KvCacheTransceiver] = None, + guided_decoder: Optional[GuidedDecoder] = None, + garbage_collection_gen0_threshold: Optional[int] = None, + start_worker: bool = True, + kv_connector_manager: Optional[KvCacheConnectorManager] = None, + max_seq_len: Optional[int] = None, + peft_cache_config: Optional[PeftCacheConfig] = None, + virtual_memory_pools: Optional[dict] = None, + hang_detection_timeout: Optional[int] = None, + execution_stream: Optional[torch.cuda.Stream] = None, + waiting_queue_policy: WaitingQueuePolicy = WaitingQueuePolicy.FCFS): super(PyExecutor, self).__init__() self.device_id = torch.cuda.current_device() self.global_rank = dist.rank @@ -474,7 +476,8 @@ class PyExecutor: self.hang_detector) # Waiting queue for requests that have been fetched but not yet scheduled - self.waiting_queue: deque[RequestQueueItem] = deque() + self.waiting_queue: WaitingQueue = create_waiting_queue( + waiting_queue_policy) self.control_request_barrier = threading.Event() self.control_action_done = threading.Event() @@ -2233,8 +2236,7 @@ class PyExecutor: self.model_engine.model.lm_head.num_embeddings): raise ValueError("Token ID out of range") - def _fetch_and_enqueue_requests(self, - waiting_queue: deque[RequestQueueItem], + def _fetch_and_enqueue_requests(self, waiting_queue: WaitingQueue, total_num_active_requests: int) -> None: """Fetch requests from request_queue and enqueue to waiting_queue.""" # Block new requests while control requests are pending @@ -2277,11 +2279,11 @@ class PyExecutor: > 1) and self.dist.rank > 0: attach_py_objects_to_requests(new_requests, py_request_objects) - waiting_queue.extend(new_requests) + waiting_queue.add_requests(new_requests) def _pop_from_waiting_queue( self, - waiting_queue: deque[RequestQueueItem], + waiting_queue: WaitingQueue, total_num_active_requests: int, all_ranks_num_active_requests: Optional[List[int]] = None ) -> List[RequestQueueItem]: @@ -2302,7 +2304,7 @@ class PyExecutor: @nvtx_range("_fetch_new_requests") def _fetch_new_requests( - self, waiting_queue: deque[RequestQueueItem], + self, waiting_queue: WaitingQueue, activate_requests: List[LlmRequest]) -> List[LlmRequest]: """Fetch new requests and 
return LlmRequests ready for execution.""" # 1. Gather info and calculate total_num_active_requests @@ -3039,8 +3041,7 @@ class PyExecutor: canceled_req_ids_set = set(self.canceled_req_ids) # Remove canceled requests from the waiting queue - self.waiting_queue = deque(req for req in self.waiting_queue - if req.id not in canceled_req_ids_set) + self.waiting_queue.remove_by_ids(canceled_req_ids_set) still_pending_canceled_ids = [] for request in self.active_requests: diff --git a/tensorrt_llm/_torch/pyexecutor/request_utils.py b/tensorrt_llm/_torch/pyexecutor/request_utils.py index 42018b982a..d06658e93f 100644 --- a/tensorrt_llm/_torch/pyexecutor/request_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/request_utils.py @@ -2,8 +2,11 @@ import heapq import os -from collections import deque, namedtuple -from typing import Any, Dict, List, Optional, Tuple +from collections import namedtuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +if TYPE_CHECKING: + from .scheduler import WaitingQueue import torch @@ -246,7 +249,7 @@ def can_process_attention_dp_request( def get_from_waiting_queue( - waiting_queue: deque, + waiting_queue: "WaitingQueue", max_req_count: int, enable_attention_dp: bool, max_num_active_requests: int, @@ -277,11 +280,11 @@ def get_from_waiting_queue( ) while req_count < max_req_count and waiting_queue: - req_item = waiting_queue[0] + req_item = waiting_queue.peek_request() num_children = len(req_item.child_req_ids) if req_item.child_req_ids else 0 if (req_count + 1 + num_children) > max_req_count: break - req_item = waiting_queue.popleft() + req_item = waiting_queue.pop_request() can_process = ( can_process_attention_dp_request( @@ -299,7 +302,7 @@ def get_from_waiting_queue( # Put the pending requests back to the waiting queue # All ranks should have the same waiting queue - waiting_queue.extendleft(reversed(pending_requests)) + waiting_queue.prepend_requests(reversed(pending_requests)) return items diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py b/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py new file mode 100644 index 0000000000..a3d63bc24c --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Scheduler module for TensorRT-LLM PyExecutor. 
+ +This module contains: +- Request schedulers (capacity, micro-batch, unified) +- Waiting queues (FCFS) +""" + +# Re-export from scheduler.py +from .scheduler import ( + BindCapacityScheduler, + BindMicroBatchScheduler, + CapacityScheduler, + KVCacheV2DummyScheduler, + MicroBatchScheduler, + PyCapacityScheduler, + PyMicroBatchScheduler, + RequestList, + RequestScheduler, + ScheduledRequests, + SchedulerOutput, + SerializableSchedulerOutput, + SimpleScheduler, + SimpleUnifiedScheduler, +) + +# Re-export from waiting_queue.py +from .waiting_queue import FCFSWaitingQueue, WaitingQueue, create_waiting_queue + +__all__ = [ + # Schedulers + "BindCapacityScheduler", + "BindMicroBatchScheduler", + "CapacityScheduler", + "KVCacheV2DummyScheduler", + "MicroBatchScheduler", + "PyCapacityScheduler", + "PyMicroBatchScheduler", + "RequestList", + "RequestScheduler", + "ScheduledRequests", + "SchedulerOutput", + "SerializableSchedulerOutput", + "SimpleScheduler", + "SimpleUnifiedScheduler", + # Waiting queues + "FCFSWaitingQueue", + "WaitingQueue", + "create_waiting_queue", +] diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py similarity index 77% rename from tensorrt_llm/_torch/pyexecutor/scheduler.py rename to tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py index 6631057251..7ca7168aa0 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import namedtuple from dataclasses import dataclass from enum import Enum -from typing import Optional +from typing import Optional, Set from strenum import StrEnum @@ -12,14 +12,20 @@ from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy from tensorrt_llm.logger import logger # Assuming these imports exist in your environment -from .llm_request import LlmRequest, LlmRequestState +from ..llm_request import LlmRequest, LlmRequestState RequestList = list[LlmRequest] -SchedulerOutput = namedtuple("SchedulerOutput", [ - "context_requests", "generation_requests", "paused_requests", - "fitting_disagg_gen_init_requests", "num_fitting_requests" -]) +SchedulerOutput = namedtuple( + "SchedulerOutput", + [ + "context_requests", + "generation_requests", + "paused_requests", + "fitting_disagg_gen_init_requests", + "num_fitting_requests", + ], +) class ScheduledRequests: @@ -31,12 +37,13 @@ class ScheduledRequests: @property def is_generation_only(self) -> bool: - return (not self.context_requests and all( - len(req.draft_tokens) == 0 for req in self.generation_requests)) + return not self.context_requests and all( + len(req.draft_tokens) == 0 for req in self.generation_requests + ) @property def can_run_cuda_graph(self) -> bool: - return (not self.context_requests) + return not self.context_requests @property def batch_size(self) -> int: @@ -47,10 +54,10 @@ class ScheduledRequests: class RequestScheduler(ABC): - @abstractmethod - def schedule_request(self, active_requests: RequestList, - inflight_request_ids: set[int]) -> SchedulerOutput: + def schedule_request( + self, active_requests: RequestList, inflight_request_ids: set[int] + ) -> SchedulerOutput: """ :param active_requests: list of active requests, up to maximum number of sequences :param inflight_request_ids: set of request ids that are inflight (of all micro batches) @@ -72,36 +79,34 @@ class RequestScheduler(ABC): @dataclass class SerializableSchedulerOutput: """ - Serializable version of SchedulerOutput, used 
for sending schedule result to other ranks. Need this class because LlmRequest is not serializable by pickle. + Serializable version of SchedulerOutput, used for sending schedule result to other ranks. + Need this class because LlmRequest is not serializable by pickle. """ + context_requests: list[int] # request ids of context requests generation_requests: list[int] # request ids of generation requests paused_requests: list[int] # request ids of paused requests fitting_disagg_gen_init_requests: list[ - int] # request ids of fitting disaggregated generation initialization requests + int + ] # request ids of fitting disaggregated generation initialization requests num_fitting_requests: int # number of fitting requests @classmethod def from_scheduler_result( - cls, scheduled_requests: ScheduledRequests, - fitting_disagg_gen_init_requests: RequestList, - num_fitting_requests: int) -> "SerializableSchedulerOutput": - return cls(context_requests=[ - req.request_id for req in scheduled_requests.context_requests - ], - generation_requests=[ - req.request_id - for req in scheduled_requests.generation_requests - ], - paused_requests=[ - req.request_id - for req in scheduled_requests.paused_requests - ], - fitting_disagg_gen_init_requests=[ - req.request_id - for req in fitting_disagg_gen_init_requests - ], - num_fitting_requests=num_fitting_requests) + cls, + scheduled_requests: ScheduledRequests, + fitting_disagg_gen_init_requests: RequestList, + num_fitting_requests: int, + ) -> "SerializableSchedulerOutput": + return cls( + context_requests=[req.request_id for req in scheduled_requests.context_requests], + generation_requests=[req.request_id for req in scheduled_requests.generation_requests], + paused_requests=[req.request_id for req in scheduled_requests.paused_requests], + fitting_disagg_gen_init_requests=[ + req.request_id for req in fitting_disagg_gen_init_requests + ], + num_fitting_requests=num_fitting_requests, + ) def to_scheduler_result( self, active_requests: RequestList @@ -118,14 +123,12 @@ class SerializableSchedulerOutput: id_to_request[req_id] for req_id in self.paused_requests ] fitting_disagg_gen_init_requests = [ - id_to_request[req_id] - for req_id in self.fitting_disagg_gen_init_requests + id_to_request[req_id] for req_id in self.fitting_disagg_gen_init_requests ] return scheduled_requests, fitting_disagg_gen_init_requests, self.num_fitting_requests class CapacityScheduler(ABC): - @abstractmethod def schedule_request( self, active_requests: RequestList @@ -139,14 +142,12 @@ class CapacityScheduler(ABC): class BindCapacityScheduler(CapacityScheduler): - def __init__( self, max_num_requests: int, kv_cache_manager, peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None, - scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. 
- GUARANTEED_NO_EVICT, + scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, two_step_lookahead: bool = False, ): super(BindCapacityScheduler, self).__init__() @@ -159,13 +160,13 @@ class BindCapacityScheduler(CapacityScheduler): has_kv_cache_manager=kv_cache_manager is not None, two_step_lookahead=two_step_lookahead, no_schedule_until_state=LlmRequestState.CONTEXT_INIT, - no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE) + no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE, + ) def schedule_request( self, active_requests: RequestList ) -> tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]: - return self.impl(active_requests, self.kv_cache_manager, - self.peft_cache_manager) + return self.impl(active_requests, self.kv_cache_manager, self.peft_cache_manager) class KVCacheV2DummyScheduler(CapacityScheduler): @@ -190,21 +191,22 @@ class KVCacheV2DummyScheduler(CapacityScheduler): req_state = request.state # if request cannot be scheduled yet or request should no longer be scheduled, skip if not req_state == LlmRequestState.DISAGG_GENERATION_INIT and ( - req_state.value < self.no_schedule_until_state.value - or req_state.value >= self.no_schedule_after_state.value): + req_state.value < self.no_schedule_until_state.value + or req_state.value >= self.no_schedule_after_state.value + ): continue - if len(scheduled_requests - ) >= self.max_num_requests or reserved_blocks >= max_blocks: + if len(scheduled_requests) >= self.max_num_requests or reserved_blocks >= max_blocks: break - elif req_state == LlmRequestState.GENERATION_IN_PROGRESS or req_state == LlmRequestState.GENERATION_TO_COMPLETE: + elif ( + req_state == LlmRequestState.GENERATION_IN_PROGRESS + or req_state == LlmRequestState.GENERATION_TO_COMPLETE + ): scheduled_requests.append(request) - reserved_blocks += self.kv_cache_manager.get_needed_resource_to_completion( - request) + reserved_blocks += self.kv_cache_manager.get_needed_resource_to_completion(request) elif req_state == LlmRequestState.DISAGG_GENERATION_INIT: scheduled_disagg_gen_init_requests.append(request) - reserved_blocks += self.kv_cache_manager.get_needed_resource_to_completion( - request) + reserved_blocks += self.kv_cache_manager.get_needed_resource_to_completion(request) else: pending_requests.append(request) @@ -214,8 +216,7 @@ class KVCacheV2DummyScheduler(CapacityScheduler): if len(scheduled_requests) >= self.max_num_requests: break elif req_state == LlmRequestState.CONTEXT_INIT: - needed_blocks = self.kv_cache_manager.get_needed_resource_to_completion( - request) + needed_blocks = self.kv_cache_manager.get_needed_resource_to_completion(request) if needed_blocks <= avaiable_blocks: scheduled_requests.append(request) avaiable_blocks -= needed_blocks @@ -223,15 +224,14 @@ class KVCacheV2DummyScheduler(CapacityScheduler): # If one requests fails to be scheduled, break break - assert len(scheduled_requests) + len( - scheduled_disagg_gen_init_requests) > 0, ( - "no pending request can get enough resource to complete, " - "please increase KV cache pool size.") + assert len(scheduled_requests) + len(scheduled_disagg_gen_init_requests) > 0, ( + "no pending request can get enough resource to complete, " + "please increase KV cache pool size." 
+ ) return scheduled_requests, scheduled_disagg_gen_init_requests, [] class MicroBatchScheduler(ABC): - @abstractmethod def schedule( self, active_requests: RequestList, inflight_request_ids: set[int] @@ -241,12 +241,12 @@ class MicroBatchScheduler(ABC): :param inflight_request_ids: set of request ids that are inflight (of all micro batches) :return: (contextRequests, generationRequests) """ - # to be aligned with MicroBatchScheduler::scheduleRequests in cpp/tensorrt_llm/batch_manager/microBatchScheduler.h + # to be aligned with MicroBatchScheduler::scheduleRequests + # in cpp/tensorrt_llm/batch_manager/microBatchScheduler.h raise NotImplementedError class BindMicroBatchScheduler(MicroBatchScheduler): - def __init__( self, max_batch_size: int, @@ -260,43 +260,49 @@ class BindMicroBatchScheduler(MicroBatchScheduler): ctx_chunk_config_cpp = None if ctx_chunk_config is not None: ctx_chunk_config_cpp = tb_internal.batch_manager.ContextChunkingConfig( - ctx_chunk_config[0]._to_pybind(), ctx_chunk_config[1]) + ctx_chunk_config[0]._to_pybind(), ctx_chunk_config[1] + ) - self.impl = tb_internal.algorithms.MicroBatchScheduler( - ctx_chunk_config_cpp, max_num_tokens) + self.impl = tb_internal.algorithms.MicroBatchScheduler(ctx_chunk_config_cpp, max_num_tokens) def schedule( self, active_requests: RequestList, inflight_request_ids: set[int] ) -> tuple[list[LlmRequest], list[LlmRequest]]: - return self.impl(active_requests, inflight_request_ids, - self.max_batch_size, self.max_num_tokens) + return self.impl( + active_requests, inflight_request_ids, self.max_batch_size, self.max_num_tokens + ) class SimpleScheduler(RequestScheduler): - - def __init__(self, capacity_scheduler: CapacityScheduler, - micro_batch_scheduler: MicroBatchScheduler): + def __init__( + self, capacity_scheduler: CapacityScheduler, micro_batch_scheduler: MicroBatchScheduler + ): super(SimpleScheduler, self).__init__() self.capacity_scheduler = capacity_scheduler self.micro_batch_scheduler = micro_batch_scheduler - def schedule_request(self, active_requests: RequestList, - inflight_request_ids: set[int]) -> SchedulerOutput: - fitting_requests, fitting_disagg_gen_init_requests, paused_requests = self.capacity_scheduler.schedule_request( - active_requests) + def schedule_request( + self, active_requests: RequestList, inflight_request_ids: set[int] + ) -> SchedulerOutput: + fitting_requests, fitting_disagg_gen_init_requests, paused_requests = ( + self.capacity_scheduler.schedule_request(active_requests) + ) context_requests, generation_requests = self.micro_batch_scheduler.schedule( - fitting_requests, inflight_request_ids) + fitting_requests, inflight_request_ids + ) # Convert from binding type RequestVector to list[LlmRequest], # so Python fields on LlmRequest won't be stripped away - return SchedulerOutput(list(context_requests), - list(generation_requests), list(paused_requests), - list(fitting_disagg_gen_init_requests), - len(fitting_requests)) + return SchedulerOutput( + list(context_requests), + list(generation_requests), + list(paused_requests), + list(fitting_disagg_gen_init_requests), + len(fitting_requests), + ) def can_schedule(self, requests: RequestList) -> bool: - fitting_requests, _, _ = self.capacity_scheduler.schedule_request( - requests) + fitting_requests, _, _ = self.capacity_scheduler.schedule_request(requests) return len(fitting_requests) == len(requests) @@ -316,15 +322,13 @@ class MicroBatchScheduler: class PyMicroBatchScheduler(MicroBatchScheduler): - def __init__( self, max_batch_size: int, max_num_tokens: 
Optional[int] = None, ctx_chunk_config: Optional[ContextChunkingConfig] = None, no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, - no_schedule_after_state: LlmRequestState = LlmRequestState. - GENERATION_TO_COMPLETE, + no_schedule_after_state: LlmRequestState = LlmRequestState.GENERATION_TO_COMPLETE, ): super().__init__() self.max_batch_size = max_batch_size @@ -349,13 +353,14 @@ class PyMicroBatchScheduler(MicroBatchScheduler): # Use state_value property (returns int directly, avoids enum object creation) state_value = req.state_value # Inline comparison: must have reached until_state but not after_state - return (state_value >= self._no_schedule_until_state_value - and state_value < self._no_schedule_after_state_value) + return ( + state_value >= self._no_schedule_until_state_value + and state_value < self._no_schedule_after_state_value + ) def schedule( - self, active_requests: RequestList, - inflight_request_ids: set[int]) -> tuple[RequestList, RequestList]: - + self, active_requests: RequestList, inflight_request_ids: set[int] + ) -> tuple[RequestList, RequestList]: context_requests: RequestList = [] generation_requests: RequestList = [] @@ -382,9 +387,12 @@ class PyMicroBatchScheduler(MicroBatchScheduler): if req.request_id in inflight_request_ids: continue - # Skip if request cannot be scheduled yet or should no longer be scheduled, manually inline the condition to reuse req.state_value - if not (req_state_value >= self._no_schedule_until_state_value - and req_state_value < self._no_schedule_after_state_value): + # Skip if request cannot be scheduled yet or should no longer be scheduled, + # manually inline the condition to reuse req.state_value + if not ( + req_state_value >= self._no_schedule_until_state_value + and req_state_value < self._no_schedule_after_state_value + ): continue req_num_tokens = 0 @@ -393,11 +401,13 @@ class PyMicroBatchScheduler(MicroBatchScheduler): if req_state_value == self._encoder_init_state_value: req_num_tokens = req.encoder_output_len - assert max_context_length is None or req_num_tokens <= max_context_length, \ + assert max_context_length is None or req_num_tokens <= max_context_length, ( f"The number of encoder tokens ({req_num_tokens}) exceeds the limit value ({max_context_length})" + ) if max_num_tokens is not None and ( - batch_num_tokens + req_num_tokens > max_num_tokens): + batch_num_tokens + req_num_tokens > max_num_tokens + ): break logger.debug(f"encoder request scheduled: ID {req.request_id}") @@ -413,24 +423,27 @@ class PyMicroBatchScheduler(MicroBatchScheduler): draft_tokens = req.num_draft_tokens if req.has_draft_tokens else 0 req_num_tokens = base_tokens + draft_tokens - assert max_context_length is None or req_num_tokens <= max_context_length, \ - f"The number of context tokens ({req_num_tokens}) exceeds the limit value ({max_context_length})" + assert max_context_length is None or req_num_tokens <= max_context_length, ( + f"Context tokens ({req_num_tokens}) exceeds limit ({max_context_length})" + ) if max_num_tokens is not None and ( - batch_num_tokens + req_num_tokens > max_num_tokens): + batch_num_tokens + req_num_tokens > max_num_tokens + ): break - logger.debug( - f"context request scheduled: ID {req.request_id}") + logger.debug(f"context request scheduled: ID {req.request_id}") context_requests.append(req) batch_num_tokens += req_num_tokens else: # Chunking Enabled: Tentative schedule req.context_chunk_size = req.context_remaining_length - draft_tokens = req.num_draft_tokens if ( - 
req.is_last_context_chunk - and req.has_draft_tokens) else 0 + draft_tokens = ( + req.num_draft_tokens + if (req.is_last_context_chunk and req.has_draft_tokens) + else 0 + ) req_num_tokens = req.context_chunk_size + draft_tokens if max_context_length is not None: @@ -438,9 +451,7 @@ class PyMicroBatchScheduler(MicroBatchScheduler): req_num_tokens = max_context_length all_context_requests_fit = False - logger.debug( - f"contexts-to-be-chunked request scheduled: ID {req.request_id}" - ) + logger.debug(f"contexts-to-be-chunked request scheduled: ID {req.request_id}") contexts_to_be_chunked.append(req) num_chunked_tokens += req_num_tokens @@ -448,12 +459,12 @@ class PyMicroBatchScheduler(MicroBatchScheduler): else: # C++ uses getBeamWidthByIter() which returns dynamic beam width # during beam search (1->2->3->...->beamWidth) - beam_width = req.get_beam_width_by_iter( - for_next_iteration=False) + beam_width = req.get_beam_width_by_iter(for_next_iteration=False) req_num_tokens = beam_width + req.num_draft_tokens if max_num_tokens is not None and ( - batch_num_tokens + req_num_tokens > max_num_tokens): + batch_num_tokens + req_num_tokens > max_num_tokens + ): break # Beam Width Consistency Check @@ -463,7 +474,8 @@ class PyMicroBatchScheduler(MicroBatchScheduler): logger.debug( f"generation request skipped: ID {req.request_id} since its " f"beam width ({beam_width}) is different from scheduled ones " - f"({scheduled_beam_width})") + f"({scheduled_beam_width})" + ) continue generation_requests.append(req) batch_num_tokens += req_num_tokens @@ -474,45 +486,46 @@ class PyMicroBatchScheduler(MicroBatchScheduler): break # 2. Verify Chunking Fits - if max_num_tokens is not None and num_chunked_tokens > ( - max_num_tokens - batch_num_tokens): + if max_num_tokens is not None and num_chunked_tokens > (max_num_tokens - batch_num_tokens): all_context_requests_fit = False # 3. Apply Chunking Strategy if needed if not all_context_requests_fit and contexts_to_be_chunked: - assert ctx_chunk_config is not None, \ + assert ctx_chunk_config is not None, ( "If chunking is not enabled, context scheduling should be completed." + ) remaining_capacity = ( - max_num_tokens - - batch_num_tokens) if max_num_tokens is not None else None + (max_num_tokens - batch_num_tokens) if max_num_tokens is not None else None + ) - self._set_ctx_requests_chunk_size(contexts_to_be_chunked, - remaining_capacity) + self._set_ctx_requests_chunk_size(contexts_to_be_chunked, remaining_capacity) # 4. 
Finalize Chunked Requests for req in contexts_to_be_chunked: if req.context_chunk_size > 0: context_requests.append(req) batch_num_tokens += req.context_chunk_size - logger.debug(f"context request scheduled: ID {req.request_id}, " - f"chunk size {req.context_chunk_size}") + logger.debug( + f"context request scheduled: ID {req.request_id}, " + f"chunk size {req.context_chunk_size}" + ) # Sort requests for consistency with C++ # C++ reference: utils::sortRequests in inflightBatchingUtils.cpp - self._sort_requests(context_requests, generation_requests, - not all_context_requests_fit) + self._sort_requests(context_requests, generation_requests, not all_context_requests_fit) # Summary logs - logger.debug(f"batchSize (num ctx/enc requests + num gen requests): " - f"{len(context_requests) + len(generation_requests)}") - logger.debug(f"batchNumTokens / maxNumTokens: {batch_num_tokens} / " - f"{max_num_tokens or 0}") + logger.debug( + f"batchSize (num ctx/enc requests + num gen requests): " + f"{len(context_requests) + len(generation_requests)}" + ) + logger.debug(f"batchNumTokens / maxNumTokens: {batch_num_tokens} / {max_num_tokens or 0}") return context_requests, generation_requests - def _sort_requests(self, context_requests: RequestList, - generation_requests: RequestList, - chunks_present: bool) -> None: + def _sort_requests( + self, context_requests: RequestList, generation_requests: RequestList, chunks_present: bool + ) -> None: """ Sort requests for consistency with C++. C++ reference: utils::sortRequests in inflightBatchingUtils.cpp @@ -525,19 +538,15 @@ class PyMicroBatchScheduler(MicroBatchScheduler): def get_lora_task_id(req: LlmRequest): # C++ uses std::optional comparison where nullopt < any_value # So requests without LoRA (nullopt) should come first - lora_id = getattr(req, 'lora_task_id', None) + lora_id = getattr(req, "lora_task_id", None) if lora_id is None: return (0, 0) # (has_value=False, value=0) - comes first return (1, lora_id) # (has_value=True, value) - sorted by value if chunks_present: # Partition: non-last-chunk first, last-chunk at end - not_last_chunk = [ - r for r in context_requests if not r.is_last_context_chunk - ] - last_chunk = [ - r for r in context_requests if r.is_last_context_chunk - ] + not_last_chunk = [r for r in context_requests if not r.is_last_context_chunk] + last_chunk = [r for r in context_requests if r.is_last_context_chunk] # Sort each group by lora_task_id not_last_chunk.sort(key=get_lora_task_id) last_chunk.sort(key=get_lora_task_id) @@ -550,8 +559,7 @@ class PyMicroBatchScheduler(MicroBatchScheduler): generation_requests.sort(key=get_lora_task_id) - def _set_ctx_requests_chunk_size(self, requests: RequestList, - capacity: Optional[int]): + def _set_ctx_requests_chunk_size(self, requests: RequestList, capacity: Optional[int]): # C++: Resets all chunk sizes to 0 at start for req in requests: req.context_chunk_size = 0 @@ -568,14 +576,12 @@ class PyMicroBatchScheduler(MicroBatchScheduler): self._fit_draft_tokens(requests, capacity, unit_size) - def _chunk_equal_progress(self, requests: RequestList, - capacity: Optional[int], unit_size: int): + def _chunk_equal_progress(self, requests: RequestList, capacity: Optional[int], unit_size: int): num_ctx_tokens = 0 num_tokens_single_loop = 1 # C++ Loop: while ((!capacity || numCtxTokens < capacity) && numTokensSingleLoop) - while (capacity is None - or num_ctx_tokens < capacity) and num_tokens_single_loop > 0: + while (capacity is None or num_ctx_tokens < capacity) and num_tokens_single_loop > 0: 
num_tokens_single_loop = 0 for req in requests: past_size = req.context_chunk_size @@ -594,8 +600,7 @@ class PyMicroBatchScheduler(MicroBatchScheduler): # Check Constraints # 1. Capacity - if capacity is not None and (num_ctx_tokens + actual_increment - > capacity): + if capacity is not None and (num_ctx_tokens + actual_increment > capacity): req.context_chunk_size = past_size # Revert continue @@ -607,9 +612,8 @@ class PyMicroBatchScheduler(MicroBatchScheduler): num_ctx_tokens += actual_increment num_tokens_single_loop += actual_increment - def _chunk_fcfs(self, requests: RequestList, capacity: Optional[int], - unit_size: int): - current_capacity = capacity if capacity is not None else float('inf') + def _chunk_fcfs(self, requests: RequestList, capacity: Optional[int], unit_size: int): + current_capacity = capacity if capacity is not None else float("inf") for req in requests: suggested_size = req.context_remaining_length @@ -631,8 +635,7 @@ class PyMicroBatchScheduler(MicroBatchScheduler): if capacity is not None: current_capacity -= req.context_chunk_size - def _fit_draft_tokens(self, requests: RequestList, capacity: Optional[int], - unit_size: int): + def _fit_draft_tokens(self, requests: RequestList, capacity: Optional[int], unit_size: int): # Calculate tokens already taken by the batch so far num_ctx_tokens = sum(req.context_chunk_size for req in requests) @@ -643,12 +646,10 @@ class PyMicroBatchScheduler(MicroBatchScheduler): if self.max_context_length is not None: remaining_context_len = self.max_context_length - req.context_chunk_size - remaining_space = min(remaining_space, - remaining_context_len) + remaining_space = min(remaining_space, remaining_context_len) if capacity is not None: - remaining_space = min(remaining_space, - capacity - num_ctx_tokens) + remaining_space = min(remaining_space, capacity - num_ctx_tokens) num_ctx_tokens += remaining_space draft_discard = req.num_draft_tokens - remaining_space @@ -666,8 +667,8 @@ class SchedulerPolicyBase(ABC): @abstractmethod def schedule( - self, scheduler: 'PyCapacityScheduler', - active_requests: RequestList) -> tuple[RequestList, RequestList]: + self, scheduler: "PyCapacityScheduler", active_requests: RequestList + ) -> tuple[RequestList, RequestList]: """ Schedule requests according to the policy. 
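The `_chunk_equal_progress` loop above mirrors the C++ equal-progress strategy: every not-yet-finished context request gains one chunk unit per pass, so long prompts advance at the same rate as short ones until the shared token budget runs out. A minimal standalone sketch of that loop under simplified assumptions (plain integer prompt lengths, no draft-token reservation; all names below are illustrative, not the production API):

# Illustrative sketch only: the same equal-progress idea as _chunk_equal_progress,
# but on plain integers instead of LlmRequest objects.
def chunk_equal_progress(remaining_lengths, capacity, unit_size):
    chunk_sizes = [0] * len(remaining_lengths)
    used = 0
    progressed = True
    # Keep sweeping over the requests, granting each one more chunk unit per pass,
    # until the budget is exhausted or no request could advance in a full pass.
    while used < capacity and progressed:
        progressed = False
        for i, remaining in enumerate(remaining_lengths):
            if chunk_sizes[i] >= remaining:
                continue  # this prompt is already fully covered
            increment = min(unit_size, remaining - chunk_sizes[i])
            if used + increment > capacity:
                continue  # would exceed the token budget; try a smaller request
            chunk_sizes[i] += increment
            used += increment
            progressed = True
    return chunk_sizes

# Example: chunk_equal_progress([300, 120, 50], capacity=256, unit_size=64) -> [128, 64, 50]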
@@ -688,8 +689,8 @@ class MaxRequestsPolicy(SchedulerPolicyBase): """ def schedule( - self, scheduler: 'PyCapacityScheduler', - active_requests: RequestList) -> tuple[RequestList, RequestList]: + self, scheduler: "PyCapacityScheduler", active_requests: RequestList + ) -> tuple[RequestList, RequestList]: scheduled_requests: RequestList = [] for req in active_requests: @@ -699,8 +700,11 @@ class MaxRequestsPolicy(SchedulerPolicyBase): if len(scheduled_requests) >= scheduler.max_num_requests: break - if (req.is_encoder_init_state or req.is_context_init_state - or req.is_generation_in_progress_state): + if ( + req.is_encoder_init_state + or req.is_context_init_state + or req.is_generation_in_progress_state + ): scheduled_requests.append(req) return scheduled_requests, [] @@ -716,8 +720,8 @@ class GuaranteedNoEvictPolicy(SchedulerPolicyBase): self.static_batch = static_batch def schedule( - self, scheduler: 'PyCapacityScheduler', - active_requests: RequestList) -> tuple[RequestList, RequestList]: + self, scheduler: "PyCapacityScheduler", active_requests: RequestList + ) -> tuple[RequestList, RequestList]: scheduled_requests: RequestList = [] has_peft = scheduler.peft_cache_manager is not None @@ -726,20 +730,18 @@ class GuaranteedNoEvictPolicy(SchedulerPolicyBase): newly_contributed_context_blocks: Set = set() newly_contributed_cross_context_blocks: Set = set() if not self.static_batch and skipping_is_relevant: - newly_contributed_context_blocks, newly_contributed_cross_context_blocks = \ + newly_contributed_context_blocks, newly_contributed_cross_context_blocks = ( scheduler._prefill_contributed_blocks(active_requests) + ) - reserved_blocks = NoEvictScheduledBlocksManager( - scheduler.kv_cache_manager) + reserved_blocks = NoEvictScheduledBlocksManager(scheduler.kv_cache_manager) reserved_cross_blocks: Optional[NoEvictScheduledBlocksManager] = None if scheduler.cross_kv_cache_manager is not None: - reserved_cross_blocks = NoEvictScheduledBlocksManager( - scheduler.cross_kv_cache_manager) + reserved_cross_blocks = NoEvictScheduledBlocksManager(scheduler.cross_kv_cache_manager) # PEFT state - only used when has_peft claimed_peft_pages = 0 - available_peft_pages = scheduler._get_max_peft_pages( - ) if has_peft else 0 + available_peft_pages = scheduler._get_max_peft_pages() if has_peft else 0 uniq_task_ids: set[int] = set() if has_peft else None pending_requests: RequestList = [] @@ -761,7 +763,8 @@ class GuaranteedNoEvictPolicy(SchedulerPolicyBase): if has_peft: lora_task_id, is_new_task, peft_pages = scheduler._get_peft_task_info( - req, uniq_task_ids) + req, uniq_task_ids + ) if is_new_task: claimed_peft_pages += peft_pages uniq_task_ids.add(lora_task_id) @@ -778,31 +781,35 @@ class GuaranteedNoEvictPolicy(SchedulerPolicyBase): for requests in [pending_dis_gen_init_requests, pending_requests]: for req in requests: - if (not self.static_batch and skipping_is_relevant - and not req.is_disagg_generation_init_state - and scheduler._beneficial_to_skip( - req, newly_contributed_context_blocks, - newly_contributed_cross_context_blocks)): + if ( + not self.static_batch + and skipping_is_relevant + and not req.is_disagg_generation_init_state + and scheduler._beneficial_to_skip( + req, + newly_contributed_context_blocks, + newly_contributed_cross_context_blocks, + ) + ): continue if len(scheduled_requests) >= scheduler.max_num_requests: break if req.is_context_init_state or req.is_disagg_generation_init_state: - enough_blocks = reserved_blocks.enough_available_blocks( - req) + enough_blocks = 
reserved_blocks.enough_available_blocks(req) enough_cross_blocks = True if reserved_cross_blocks is not None: - enough_cross_blocks = reserved_cross_blocks.enough_available_blocks( - req) + enough_cross_blocks = reserved_cross_blocks.enough_available_blocks(req) if not enough_blocks or not enough_cross_blocks: break # PEFT check only when needed if has_peft: - lora_task_id, is_new_task, needed_peft_pages = scheduler._get_peft_task_info( - req, uniq_task_ids) + lora_task_id, is_new_task, needed_peft_pages = ( + scheduler._get_peft_task_info(req, uniq_task_ids) + ) if needed_peft_pages > available_peft_pages: continue available_peft_pages -= needed_peft_pages @@ -824,27 +831,27 @@ class MaxUtilizationPolicy(SchedulerPolicyBase): """ def schedule( - self, scheduler: 'PyCapacityScheduler', - active_requests: RequestList) -> tuple[RequestList, RequestList]: + self, scheduler: "PyCapacityScheduler", active_requests: RequestList + ) -> tuple[RequestList, RequestList]: scheduler.kv_cache_manager.start_scheduling() skipping_is_relevant = scheduler._is_skipping_relevant() scheduled_blocks_manager = MaxUtilizationScheduledBlocksManager( - scheduler.kv_cache_manager, scheduler.two_step_lookahead) + scheduler.kv_cache_manager, scheduler.two_step_lookahead + ) num_scheduled_peft_pages = 0 seen_task_ids: set[int] = set() - newly_contributed_context_blocks, _ = scheduler._prefill_contributed_blocks( - active_requests) + newly_contributed_context_blocks, _ = scheduler._prefill_contributed_blocks(active_requests) def is_started_request(req: LlmRequest) -> bool: if not scheduler._can_be_scheduled(req): return False - return ((req.is_context_init_state - and not req.is_first_context_chunk) - or req.is_generation_in_progress_state) + return ( + req.is_context_init_state and not req.is_first_context_chunk + ) or req.is_generation_in_progress_state scheduled_requests: RequestList = [] paused_requests: RequestList = [] @@ -855,30 +862,33 @@ class MaxUtilizationPolicy(SchedulerPolicyBase): while req_it < req_it_end: req = requests_list[req_it] - logger.debug( - f"MaxUtilizationScheduler: scheduling request ID {req.request_id}" - ) + logger.debug(f"MaxUtilizationScheduler: scheduling request ID {req.request_id}") if not scheduler._can_be_scheduled_with_disagg_exception(req): logger.debug( f"MaxUtilizationScheduler: request ID {req.request_id} " - "cannot / should not be scheduled") + "cannot / should not be scheduled" + ) req_it += 1 continue - if (skipping_is_relevant and scheduler._beneficial_to_skip( - req, newly_contributed_context_blocks, set())): + if skipping_is_relevant and scheduler._beneficial_to_skip( + req, newly_contributed_context_blocks, set() + ): req_it += 1 continue was_scheduled = self._try_scheduling_request( - scheduler, req, scheduled_requests, scheduled_blocks_manager, - num_scheduled_peft_pages, seen_task_ids) + scheduler, + req, + scheduled_requests, + scheduled_blocks_manager, + num_scheduled_peft_pages, + seen_task_ids, + ) if was_scheduled: - logger.debug( - f"MaxUtilizationScheduler: request ID {req.request_id} -> start" - ) + logger.debug(f"MaxUtilizationScheduler: request ID {req.request_id} -> start") req_it += 1 else: last_started_idx = None @@ -889,8 +899,7 @@ class MaxUtilizationPolicy(SchedulerPolicyBase): if last_started_idx is not None: paused_req = requests_list[last_started_idx] - scheduler.kv_cache_manager.scheduling_remove_sequence( - paused_req.py_request_id) + scheduler.kv_cache_manager.scheduling_remove_sequence(paused_req.py_request_id) 
paused_requests.append(paused_req) logger.debug( f"MaxUtilizationScheduler: request ID {paused_req.request_id} -> pause" @@ -902,25 +911,30 @@ class MaxUtilizationPolicy(SchedulerPolicyBase): return scheduled_requests, paused_requests def _try_scheduling_request( - self, scheduler: 'PyCapacityScheduler', req: LlmRequest, - scheduled_requests: RequestList, - scheduled_blocks_manager: 'MaxUtilizationScheduledBlocksManager', - num_scheduled_peft_pages: int, seen_task_ids: set[int]) -> bool: + self, + scheduler: "PyCapacityScheduler", + req: LlmRequest, + scheduled_requests: RequestList, + scheduled_blocks_manager: "MaxUtilizationScheduledBlocksManager", + num_scheduled_peft_pages: int, + seen_task_ids: set[int], + ) -> bool: if len(scheduled_requests) >= scheduler.max_num_requests: return False - blocks_if_scheduled = scheduled_blocks_manager.prepare_blocks_if_schedulable( - req) + blocks_if_scheduled = scheduled_blocks_manager.prepare_blocks_if_schedulable(req) if blocks_if_scheduled is None: return False # PEFT check only when needed if scheduler.peft_cache_manager is not None: lora_task_id, is_new_task, num_required_peft_pages = scheduler._get_peft_task_info( - req, seen_task_ids) + req, seen_task_ids + ) logger.debug( f"MaxUtilizationScheduler: request ID {req.request_id} " - f"required peft pages: {num_required_peft_pages}") + f"required peft pages: {num_required_peft_pages}" + ) max_peft_pages = scheduler._get_max_peft_pages() if num_required_peft_pages + num_scheduled_peft_pages > max_peft_pages: return False @@ -950,8 +964,7 @@ class NoEvictScheduledBlocksManager: """ self.kv_cache_manager = kv_cache_manager stats = kv_cache_manager.get_kv_cache_stats() - self.available_blocks: dict[int, int] = dict( - stats.num_free_blocks_per_window_size) + self.available_blocks: dict[int, int] = dict(stats.num_free_blocks_per_window_size) def decrement_reserved_blocks(self, req: LlmRequest) -> None: """ @@ -959,8 +972,7 @@ class NoEvictScheduledBlocksManager: C++ reference: scheduledBlocksManager.h:40-46 """ for window_size in self.available_blocks: - needed = self.kv_cache_manager.get_remaining_blocks_to_completion( - req, window_size) + needed = self.kv_cache_manager.get_remaining_blocks_to_completion(req, window_size) self.available_blocks[window_size] -= needed def enough_available_blocks(self, req: LlmRequest) -> bool: @@ -969,8 +981,9 @@ class NoEvictScheduledBlocksManager: C++ reference: scheduledBlocksManager.h:48-57 """ return all( - self.kv_cache_manager.get_remaining_blocks_to_completion(req, ws) <= - avail for ws, avail in self.available_blocks.items()) + self.kv_cache_manager.get_remaining_blocks_to_completion(req, ws) <= avail + for ws, avail in self.available_blocks.items() + ) class MaxUtilizationScheduledBlocksManager: @@ -989,13 +1002,9 @@ class MaxUtilizationScheduledBlocksManager: self.kv_cache_manager = kv_cache_manager self.two_steps_look_ahead = two_steps_look_ahead window_sizes = set(kv_cache_manager.max_attention_window_vec) - self.num_scheduled_blocks: dict[int, int] = { - ws: 0 - for ws in window_sizes - } + self.num_scheduled_blocks: dict[int, int] = {ws: 0 for ws in window_sizes} - def prepare_blocks_if_schedulable( - self, req: LlmRequest) -> Optional[dict[int, int]]: + def prepare_blocks_if_schedulable(self, req: LlmRequest) -> Optional[dict[int, int]]: """ Check if request can be scheduled and return new block counts if so. Returns None if request cannot fit. 
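The two block managers above implement the capacity check in different regimes: `NoEvictScheduledBlocksManager` reserves, per attention-window size, the blocks a request needs to run to completion, while `MaxUtilizationScheduledBlocksManager` only budgets one step (optionally two) ahead. A simplified standalone sketch of the guaranteed-no-evict accounting, with a plain dict and a callback standing in for the KV cache manager (all names below are illustrative):

# Illustrative sketch, not the production API: admit requests only while every
# attention-window size still has enough free KV blocks to finish them, then
# reserve those blocks so later requests cannot claim them (no eviction needed).
def admit_guaranteed_no_evict(requests, free_blocks_per_window, blocks_to_completion):
    # free_blocks_per_window: dict window_size -> currently free blocks
    # blocks_to_completion(req, window_size): blocks the request still needs there
    available = dict(free_blocks_per_window)
    admitted = []
    for req in requests:
        if not all(blocks_to_completion(req, ws) <= avail
                   for ws, avail in available.items()):
            break  # no-evict policy stops at the first request that cannot finish
        for ws in available:
            available[ws] -= blocks_to_completion(req, ws)
        admitted.append(req)
    return admitted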
@@ -1004,13 +1013,16 @@ class MaxUtilizationScheduledBlocksManager: blocks_if_scheduled = {} for window_size, num_scheduled in self.num_scheduled_blocks.items(): required = self.kv_cache_manager.get_needed_blocks_one_step( - req, self.two_steps_look_ahead, window_size) + req, self.two_steps_look_ahead, window_size + ) logger.debug( f"MaxUtilizationScheduler: request ID {req.request_id} " - f"required blocks {required} for {window_size} window size") + f"required blocks {required} for {window_size} window size" + ) scheduled_total = num_scheduled + required has_free = self.kv_cache_manager.scheduling_has_free_blocks( - scheduled_total, window_size) + scheduled_total, window_size + ) if not has_free: return None blocks_if_scheduled[window_size] = scheduled_total @@ -1021,12 +1033,14 @@ class MaxUtilizationScheduledBlocksManager: Update the scheduled blocks after successfully scheduling a request. C++ reference: scheduledBlocksManager.h:102-110 """ - assert len(blocks) == len(self.num_scheduled_blocks), \ + assert len(blocks) == len(self.num_scheduled_blocks), ( f"Block count mismatch: {len(blocks)} vs {len(self.num_scheduled_blocks)}" + ) for window_size, blocks_if_scheduled in blocks.items(): logger.debug( f"MaxUtilizationScheduler: scheduled blocks {blocks_if_scheduled} " - f"for window size {window_size}") + f"for window size {window_size}" + ) self.num_scheduled_blocks[window_size] = blocks_if_scheduled @@ -1050,13 +1064,11 @@ class PyCapacityScheduler: max_num_requests: int, kv_cache_manager=None, peft_cache_manager=None, - scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. - GUARANTEED_NO_EVICT, + scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, cross_kv_cache_manager=None, two_step_lookahead: bool = False, no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, - no_schedule_after_state: LlmRequestState = LlmRequestState. - GENERATION_COMPLETE, + no_schedule_after_state: LlmRequestState = LlmRequestState.GENERATION_COMPLETE, ): """ Initialize the capacity scheduler. 
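For reference, this is roughly how the pure-Python capacity scheduler introduced above is wired up and queried. The cache manager and the active-request list are placeholders here; in the executor they are supplied by the resource manager, so this is a construction sketch rather than a runnable end-to-end example:

from tensorrt_llm._torch.pyexecutor.scheduler import PyCapacityScheduler
from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy

kv_cache_manager = ...   # provided by the executor's resource manager
active_requests = []     # list[LlmRequest] tracked by PyExecutor

capacity_scheduler = PyCapacityScheduler(
    max_num_requests=8,
    kv_cache_manager=kv_cache_manager,
    peft_cache_manager=None,   # no LoRA adapters in this sketch
    scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
)
# schedule_request returns (fitting, fitting_disagg_gen_init, paused) request lists.
fitting, disagg_gen_init, paused = capacity_scheduler.schedule_request(active_requests)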
@@ -1097,8 +1109,7 @@ class PyCapacityScheduler: elif self.scheduler_policy == CapacitySchedulerPolicy.STATIC_BATCH: return GuaranteedNoEvictPolicy(static_batch=True) else: - raise ValueError( - f"Unsupported scheduler policy: {self.scheduler_policy}") + raise ValueError(f"Unsupported scheduler policy: {self.scheduler_policy}") def _can_be_scheduled(self, req: LlmRequest) -> bool: """ @@ -1110,8 +1121,10 @@ class PyCapacityScheduler: # Use state_value property (returns int directly, avoids enum object creation) state_value = req.state_value # Inline comparison: must have reached until_state but not after_state - return (state_value >= self._no_schedule_until_state_value - and state_value < self._no_schedule_after_state_value) + return ( + state_value >= self._no_schedule_until_state_value + and state_value < self._no_schedule_after_state_value + ) def _is_skipping_relevant(self) -> bool: """ @@ -1123,13 +1136,14 @@ class PyCapacityScheduler: return False if self.kv_cache_manager.is_variable_window: return False - if (self.cross_kv_cache_manager is not None - and self.cross_kv_cache_manager.is_variable_window): + if ( + self.cross_kv_cache_manager is not None + and self.cross_kv_cache_manager.is_variable_window + ): return False return True - def _prefill_contributed_blocks( - self, active_requests: RequestList) -> tuple[set, set]: + def _prefill_contributed_blocks(self, active_requests: RequestList) -> tuple[set, set]: """ Collect blocks contributed by chunked context requests already executing. These blocks can be reused by later requests. @@ -1143,8 +1157,10 @@ class PyCapacityScheduler: return newly_contributed_context_blocks, newly_contributed_cross_context_blocks enable_block_reuse = self.kv_cache_manager.enable_block_reuse - cross_enable_reuse = (self.cross_kv_cache_manager is not None and - self.cross_kv_cache_manager.enable_block_reuse) + cross_enable_reuse = ( + self.cross_kv_cache_manager is not None + and self.cross_kv_cache_manager.enable_block_reuse + ) for req in active_requests: # Check: isContextInitState() && !isFirstContextChunk() @@ -1152,8 +1168,7 @@ class PyCapacityScheduler: # Chunked context request already executing if enable_block_reuse: unique_tokens = req.get_unique_tokens(0) - block_key = self.kv_cache_manager.find_new_context_block( - unique_tokens, req) + block_key = self.kv_cache_manager.find_new_context_block(unique_tokens, req) if block_key is not None: newly_contributed_context_blocks.add(block_key) @@ -1161,22 +1176,21 @@ class PyCapacityScheduler: encoder_unique_tokens = req.get_encoder_unique_tokens() if encoder_unique_tokens is not None: block_key = self.cross_kv_cache_manager.find_new_context_block( - encoder_unique_tokens, req) + encoder_unique_tokens, req + ) if block_key is not None: - newly_contributed_cross_context_blocks.add( - block_key) + newly_contributed_cross_context_blocks.add(block_key) return newly_contributed_context_blocks, newly_contributed_cross_context_blocks - def _one_manager_beneficial_to_skip(self, kv_cache_manager, unique_tokens, - req: LlmRequest, - newly_contributed_blocks: set) -> bool: + def _one_manager_beneficial_to_skip( + self, kv_cache_manager, unique_tokens, req: LlmRequest, newly_contributed_blocks: set + ) -> bool: """ Check if skipping is beneficial for one KV cache manager. 
C++ reference: capacityScheduler.cpp:70-92 (oneManagerBeneficialToSkip) """ - new_context_block = kv_cache_manager.find_new_context_block( - unique_tokens, req) + new_context_block = kv_cache_manager.find_new_context_block(unique_tokens, req) if new_context_block is not None: if new_context_block in newly_contributed_blocks: return True @@ -1184,8 +1198,11 @@ class PyCapacityScheduler: return False def _beneficial_to_skip( - self, req: LlmRequest, newly_contributed_context_blocks: set, - newly_contributed_cross_context_blocks: set) -> bool: + self, + req: LlmRequest, + newly_contributed_context_blocks: set, + newly_contributed_cross_context_blocks: set, + ) -> bool: """ Check if it's beneficial to skip this request. A request should be skipped if it can reuse blocks contributed by @@ -1196,21 +1213,25 @@ class PyCapacityScheduler: if not (req.is_context_init_state and req.is_first_context_chunk): return False - if (self.kv_cache_manager is not None - and self.kv_cache_manager.enable_block_reuse): + if self.kv_cache_manager is not None and self.kv_cache_manager.enable_block_reuse: unique_tokens = req.get_unique_tokens(0) if self._one_manager_beneficial_to_skip( - self.kv_cache_manager, unique_tokens, req, - newly_contributed_context_blocks): + self.kv_cache_manager, unique_tokens, req, newly_contributed_context_blocks + ): return True - if (self.cross_kv_cache_manager is not None - and self.cross_kv_cache_manager.enable_block_reuse): + if ( + self.cross_kv_cache_manager is not None + and self.cross_kv_cache_manager.enable_block_reuse + ): encoder_unique_tokens = req.get_encoder_unique_tokens() if encoder_unique_tokens is not None: if self._one_manager_beneficial_to_skip( - self.cross_kv_cache_manager, encoder_unique_tokens, req, - newly_contributed_cross_context_blocks): + self.cross_kv_cache_manager, + encoder_unique_tokens, + req, + newly_contributed_cross_context_blocks, + ): return True return False @@ -1228,16 +1249,15 @@ class PyCapacityScheduler: return self.peft_cache_manager.determine_num_pages(req) def _get_peft_task_info( - self, req: LlmRequest, - seen_task_ids: set[int]) -> tuple[Optional[int], bool, int]: + self, req: LlmRequest, seen_task_ids: set[int] + ) -> tuple[Optional[int], bool, int]: """ Get PEFT task information for a request. Returns (lora_task_id, is_new_task, required_pages). 
""" - lora_task_id = getattr(req, 'lora_task_id', None) + lora_task_id = getattr(req, "lora_task_id", None) is_new_task = lora_task_id is not None and lora_task_id not in seen_task_ids - required_pages = self._get_peft_pages_for_request( - req) if is_new_task else 0 + required_pages = self._get_peft_pages_for_request(req) if is_new_task else 0 return lora_task_id, is_new_task, required_pages def _can_be_scheduled_with_disagg_exception(self, req: LlmRequest) -> bool: @@ -1265,18 +1285,16 @@ class PyCapacityScheduler: """ scheduled, paused = self._policy.schedule(self, active_requests) - fitting_requests, fitting_disagg_gen_init_requests = self._classify_output( - scheduled) + fitting_requests, fitting_disagg_gen_init_requests = self._classify_output(scheduled) logger.debug( f"[Summary] Capacity scheduler allows {len(fitting_requests)} requests, " - f"pauses {len(paused)} requests") + f"pauses {len(paused)} requests" + ) return fitting_requests, fitting_disagg_gen_init_requests, paused - def _classify_output( - self, - scheduled_requests: RequestList) -> tuple[RequestList, RequestList]: + def _classify_output(self, scheduled_requests: RequestList) -> tuple[RequestList, RequestList]: """ Separate scheduled requests into normal requests and disagg gen init requests. C++ reference: capacityScheduler.cpp:522-534 @@ -1292,7 +1310,6 @@ class PyCapacityScheduler: class SimpleUnifiedScheduler(RequestScheduler): - def __init__( self, max_batch_size: int, @@ -1317,7 +1334,8 @@ class SimpleUnifiedScheduler(RequestScheduler): peft_cache_manager=peft_cache_manager, scheduler_policy=scheduler_policy, cross_kv_cache_manager=cross_kv_cache_manager, - two_step_lookahead=two_step_lookahead) + two_step_lookahead=two_step_lookahead, + ) # 2. Initialize Python MicroBatch Scheduler py_chunk_config = None @@ -1332,30 +1350,34 @@ class SimpleUnifiedScheduler(RequestScheduler): # Default to FCFS for FIRST_COME_FIRST_SERVED or others policy_enum = ChunkingPolicy.FIRST_COME_FIRST_SERVED - py_chunk_config = ContextChunkingConfig(policy_enum, - ctx_chunk_config[1]) + py_chunk_config = ContextChunkingConfig(policy_enum, ctx_chunk_config[1]) self.micro_batch_scheduler = PyMicroBatchScheduler( max_batch_size=max_batch_size, max_num_tokens=max_num_tokens, - ctx_chunk_config=py_chunk_config) + ctx_chunk_config=py_chunk_config, + ) - def schedule_request(self, active_requests: RequestList, - inflight_request_ids: set[int]) -> SchedulerOutput: + def schedule_request( + self, active_requests: RequestList, inflight_request_ids: set[int] + ) -> SchedulerOutput: # Step 1: Capacity Check (Who fits in memory?) - fitting_requests, fitting_disagg_gen_init, paused_requests = \ + fitting_requests, fitting_disagg_gen_init, paused_requests = ( self.capacity_scheduler.schedule_request(active_requests) + ) # Step 2: MicroBatch Check (Who fits in token budget? 
+ Chunking) - context_requests, generation_requests = \ - self.micro_batch_scheduler.schedule(fitting_requests, inflight_request_ids) + context_requests, generation_requests = self.micro_batch_scheduler.schedule( + fitting_requests, inflight_request_ids + ) return SchedulerOutput( context_requests=context_requests, generation_requests=generation_requests, paused_requests=paused_requests, fitting_disagg_gen_init_requests=fitting_disagg_gen_init, - num_fitting_requests=len(fitting_requests)) + num_fitting_requests=len(fitting_requests), + ) def can_schedule(self, requests: RequestList) -> bool: # Dry run capacity check diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py b/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py new file mode 100644 index 0000000000..f35db02471 --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/scheduler/waiting_queue.py @@ -0,0 +1,134 @@ +from abc import ABC, abstractmethod +from collections import deque +from collections.abc import Iterable, Iterator +from typing import Callable, Optional + +from tensorrt_llm.llmapi.llm_args import WaitingQueuePolicy + +from ..executor_request_queue import RequestQueueItem + + +class WaitingQueue(ABC): + """Abstract base class for waiting queues.""" + + @abstractmethod + def add_request(self, request: RequestQueueItem) -> None: + """Add a request to the queue according to the policy.""" + pass + + @abstractmethod + def add_requests(self, requests: Iterable[RequestQueueItem]) -> None: + """Add multiple requests to the queue according to the policy.""" + pass + + @abstractmethod + def pop_request(self) -> RequestQueueItem: + """Pop a request from the queue according to the policy.""" + pass + + @abstractmethod + def peek_request(self) -> RequestQueueItem: + """Peek at the request at the front of the queue without removing it.""" + pass + + @abstractmethod + def prepend_request(self, request: RequestQueueItem) -> None: + """Prepend a request to the front of the queue.""" + pass + + @abstractmethod + def prepend_requests(self, requests: Iterable[RequestQueueItem]) -> None: + """Prepend all requests from another iterable to the front of this queue.""" + pass + + @abstractmethod + def remove_by_ids(self, request_ids: set[int]) -> None: + """Remove requests with the given IDs.""" + pass + + @abstractmethod + def __bool__(self) -> bool: + """Check if queue has any requests.""" + pass + + @abstractmethod + def __len__(self) -> int: + """Get number of requests in queue.""" + pass + + @abstractmethod + def __iter__(self) -> Iterator[RequestQueueItem]: + """Iterate over the queue according to the policy.""" + pass + + +class FCFSWaitingQueue(deque, WaitingQueue): + """A first-come-first-served queue that supports deque operations.""" + + def add_request(self, request: RequestQueueItem) -> None: + """Add a request to the queue according to FCFS policy.""" + self.append(request) + + def add_requests(self, requests: Iterable[RequestQueueItem]) -> None: + """Add multiple requests to the queue according to FCFS policy.""" + self.extend(requests) + + def pop_request(self) -> RequestQueueItem: + """Pop a request from the queue according to FCFS policy.""" + return self.popleft() + + def peek_request(self) -> RequestQueueItem: + """Peek at the next request in the queue without removing it.""" + if not self: + raise IndexError("peek from an empty queue") + return self[0] + + def prepend_request(self, request: RequestQueueItem) -> None: + """Prepend a request to the front of the queue.""" + self.appendleft(request) + + def 
+    def prepend_requests(self, requests: Iterable[RequestQueueItem]) -> None:
+        """Prepend all requests from another iterable to the front of this queue.
+
+        Note: The requests will be prepended in reverse order of their
+        appearance in the `requests` iterable.
+        """
+        self.extendleft(requests)
+
+    def remove_by_ids(self, request_ids: set[int]) -> None:
+        """Remove requests with the given IDs."""
+        filtered_requests = [req for req in self if req.id not in request_ids]
+        self.clear()
+        self.extend(filtered_requests)
+
+    def __bool__(self) -> bool:
+        """Check if queue has any requests."""
+        return len(self) > 0
+
+    def __len__(self) -> int:
+        """Get number of requests in queue."""
+        return super().__len__()
+
+    def __iter__(self) -> Iterator[RequestQueueItem]:
+        """Iterate over the queue according to FCFS policy."""
+        return super().__iter__()
+
+
+def create_waiting_queue(
+    policy: WaitingQueuePolicy = WaitingQueuePolicy.FCFS,
+    priority_fn: Optional[Callable[[RequestQueueItem], float]] = None,
+) -> WaitingQueue:
+    """Create a waiting queue based on the specified policy.
+
+    Args:
+        policy: The scheduling policy to use. Currently only FCFS is supported.
+        priority_fn: Reserved for future use.
+
+    Returns:
+        A WaitingQueue instance.
+    """
+    # Currently only FCFS is implemented
+    if policy == WaitingQueuePolicy.FCFS:
+        return FCFSWaitingQueue()
+    else:
+        raise ValueError(f"Unsupported waiting queue policy: {policy}")
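A minimal usage sketch of the queue API introduced above (not part of the diff): it drives the queue through create_waiting_queue the same way the new unit tests further down do, with a Mock payload standing in for a real executor request.

    # Sketch only -- uses the re-exports the new unit tests rely on
    # (tensorrt_llm._torch.pyexecutor.scheduler).
    from unittest.mock import Mock

    from tensorrt_llm._torch.pyexecutor.executor_request_queue import RequestQueueItem
    from tensorrt_llm._torch.pyexecutor.scheduler import create_waiting_queue
    from tensorrt_llm.llmapi.llm_args import WaitingQueuePolicy

    queue = create_waiting_queue(WaitingQueuePolicy.FCFS)
    queue.add_requests([RequestQueueItem(i, Mock()) for i in range(3)])

    first = queue.pop_request()   # FCFS: id == 0 comes out first
    queue.prepend_request(first)  # put it back at the head of the queue
    queue.remove_by_ids({1})      # drop a canceled request by id
    assert [item.id for item in queue] == [0, 2]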
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 7fa2d335dc..f5c16b9d9c 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -1487,6 +1487,12 @@ class ContextChunkingPolicy(StrEnum, metaclass=PybindMirrorEnumMeta):
         return getattr(_ContextChunkingPolicy, self.value)


+class WaitingQueuePolicy(StrEnum):
+    """Waiting queue scheduling policy for managing pending requests."""
+
+    FCFS = "fcfs"  # First-Come-First-Served
+
+
 @PybindMirror.mirror_pybind_fields(_DynamicBatchConfig)
 class DynamicBatchConfig(StrictBaseModel, PybindMirror):
     """Dynamic batch configuration.
@@ -1525,6 +1531,10 @@ class SchedulerConfig(StrictBaseModel, PybindMirror):
     dynamic_batch_config: Optional[DynamicBatchConfig] = Field(
         default=None, description="The dynamic batch config to use")

+    waiting_queue_policy: WaitingQueuePolicy = Field(
+        default=WaitingQueuePolicy.FCFS,
+        description="The waiting queue scheduling policy")
+
     def _to_pybind(self):
         return _SchedulerConfig(
             capacity_scheduler_policy=self.capacity_scheduler_policy._to_pybind(
diff --git a/tests/unittest/_torch/executor/test_py_executor.py b/tests/unittest/_torch/executor/test_py_executor.py
index 69fbc059e4..9fb5930598 100644
--- a/tests/unittest/_torch/executor/test_py_executor.py
+++ b/tests/unittest/_torch/executor/test_py_executor.py
@@ -9,7 +9,6 @@ to PyExecutor, including:
 - expected_num_active_requests tracking
 """

-from collections import deque
 from unittest.mock import Mock

 import pytest
@@ -18,6 +17,7 @@ from tensorrt_llm._torch.pyexecutor.executor_request_queue import (
     SHUTDOWN_REQUEST_ID,
     RequestQueueItem,
 )
+from tensorrt_llm._torch.pyexecutor.scheduler import FCFSWaitingQueue


 class MockPyExecutor:
@@ -35,7 +35,7 @@
         self.is_shutdown = False
         self.expected_num_active_requests = 0
         self.new_active_requests_queue_latency_ms = 0.0
-        self.waiting_queue = deque()
+        self.waiting_queue = FCFSWaitingQueue()

     def _handle_special_queue_items(self, new_requests):
         """Handle special signals.
@@ -62,13 +62,11 @@ class MockPyExecutor:
     def update_waiting_queue(self):
         """Update waiting queue to remove canceled requests.

-        This method mirrors PyExecutor.update_waiting_queue.
+        This method mirrors PyExecutor._handle_canceled_requests.
         """
         if self.canceled_req_ids:
             canceled_set = set(self.canceled_req_ids)
-            self.waiting_queue = deque(
-                item for item in self.waiting_queue if item.id not in canceled_set
-            )
+            self.waiting_queue.remove_by_ids(canceled_set)

     def clear_canceled_req_ids(self):
         """Clear the list of canceled request IDs."""
diff --git a/tests/unittest/_torch/executor/test_request_utils.py b/tests/unittest/_torch/executor/test_request_utils.py
index ed209c6630..10c1e04786 100644
--- a/tests/unittest/_torch/executor/test_request_utils.py
+++ b/tests/unittest/_torch/executor/test_request_utils.py
@@ -6,7 +6,6 @@ This module tests:
 - Waiting queue functions (get_from_waiting_queue, can_process_attention_dp_request)
 """

-from collections import deque
 from unittest.mock import Mock, patch

 import pytest
@@ -20,6 +19,7 @@ from tensorrt_llm._torch.pyexecutor.request_utils import (
     merge_requests,
     schedule_attention_dp_requests,
 )
+from tensorrt_llm._torch.pyexecutor.scheduler import FCFSWaitingQueue
 from tensorrt_llm.bindings import executor as trtllm
 from tensorrt_llm.mapping import CpType

@@ -263,7 +263,7 @@ def test_merge_requests_with_helix_cp_config():
 def test_get_from_waiting_queue():
     """Test getting items from waiting queue."""
     # Add items to waiting queue
-    waiting_queue = deque()
+    waiting_queue = FCFSWaitingQueue()
     items = [RequestQueueItem(i, Mock()) for i in range(5)]
     waiting_queue.extend(items)

@@ -291,7 +291,7 @@ def test_get_from_waiting_queue_edge_cases(
 ):
     """Test edge cases for getting items from waiting queue."""
     # Setup queue
-    waiting_queue = deque()
+    waiting_queue = FCFSWaitingQueue()
     if queue_size > 0:
         items = [RequestQueueItem(i, Mock()) for i in range(queue_size)]
         waiting_queue.extend(items)
@@ -307,7 +307,7 @@ def test_get_from_waiting_queue_with_attention_dp(
     attention_dp_config, all_ranks_num_active_requests
 ):
-    waiting_queue = deque()
+    waiting_queue = FCFSWaitingQueue()
     items = [RequestQueueItem(i, Mock()) for i in range(5)]
     waiting_queue.extend(items)

@@ -338,7 +338,8 @@ def test_get_from_waiting_queue_with_attention_dp_filtering(
         3, create_mock_request_with_py_schedule_params(attention_dp_rank=None)
     )  # No scheduling params

-    waiting_queue = deque([req1, req2, req3])
+    waiting_queue = FCFSWaitingQueue()
+    waiting_queue.extend([req1, req2, req3])

     # Set rank 0 to full capacity to test filtering
     all_ranks_num_active_requests[0] = 8
@@ -719,7 +720,8 @@ def test_achieve_max_num_active_requests(attention_dp_config):
         req_id += 1

     all_ranks_num_active_requests = [5, 6, 3, 7]
-    waiting_queue = deque(req_list)
+    waiting_queue = FCFSWaitingQueue()
+    waiting_queue.extend(req_list)
     available_active_requests = max_num_active_requests * 4 - sum(all_ranks_num_active_requests)

     result = get_from_waiting_queue(
@@ -843,7 +845,7 @@ def test_attention_dp_scheduling_cases(
     all_ranks_expected_req_ids,
 ):
     """Test attention DP scheduling with various scenarios."""
-    waiting_queue = deque()
+    waiting_queue = FCFSWaitingQueue()

     for rank, relax in request_configs:
         append_to_waiting_queue(waiting_queue, rank, relax)
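For context on the llm_args.py hunk above (not part of the diff): a sketch of how the new SchedulerConfig field would be set. The LLM(...) mention is illustrative only; the field simply rides along wherever scheduler_config is already plumbed through.

    from tensorrt_llm.llmapi.llm_args import SchedulerConfig, WaitingQueuePolicy

    # FCFS is currently the only policy; the StrEnum should also accept the plain string "fcfs".
    scheduler_config = SchedulerConfig(
        waiting_queue_policy=WaitingQueuePolicy.FCFS,
    )
    # Illustrative: pass it wherever scheduler_config is already accepted,
    # e.g. LLM(model=..., scheduler_config=scheduler_config) on the PyTorch backend.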
diff --git a/tests/unittest/_torch/executor/test_waiting_queue.py b/tests/unittest/_torch/executor/test_waiting_queue.py
new file mode 100644
index 0000000000..d0dab00c94
--- /dev/null
+++ b/tests/unittest/_torch/executor/test_waiting_queue.py
@@ -0,0 +1,189 @@
+"""Tests for WaitingQueue implementations.
+
+This module tests the waiting queue functionality including:
+- FCFSWaitingQueue operations
+- WaitingQueue abstract interface
+- create_waiting_queue factory function
+"""
+
+from unittest.mock import Mock
+
+import pytest
+
+from tensorrt_llm._torch.pyexecutor.executor_request_queue import RequestQueueItem
+from tensorrt_llm._torch.pyexecutor.scheduler import (
+    FCFSWaitingQueue,
+    WaitingQueue,
+    create_waiting_queue,
+)
+from tensorrt_llm.llmapi.llm_args import WaitingQueuePolicy
+
+
+def create_mock_request_item(request_id: int) -> RequestQueueItem:
+    """Create a mock RequestQueueItem for testing."""
+    mock_request = Mock()
+    return RequestQueueItem(request_id, mock_request)
+
+
+class TestFCFSWaitingQueue:
+    """Tests for FCFSWaitingQueue."""
+
+    def test_add_request(self):
+        """Test adding a single request."""
+        queue = FCFSWaitingQueue()
+        item = create_mock_request_item(1)
+
+        queue.add_request(item)
+
+        assert len(queue) == 1
+        assert queue.peek_request() == item
+
+    def test_add_requests(self):
+        """Test adding multiple requests."""
+        queue = FCFSWaitingQueue()
+        items = [create_mock_request_item(i) for i in range(3)]
+
+        queue.add_requests(items)
+
+        assert len(queue) == 3
+
+    def test_pop_request_fcfs_order(self):
+        """Test that pop_request returns requests in FCFS order."""
+        queue = FCFSWaitingQueue()
+        items = [create_mock_request_item(i) for i in range(3)]
+        queue.add_requests(items)
+
+        # Should pop in order: 0, 1, 2
+        assert queue.pop_request().id == 0
+        assert queue.pop_request().id == 1
+        assert queue.pop_request().id == 2
+
+    def test_pop_from_empty_queue(self):
+        """Test that pop_request raises IndexError on empty queue."""
+        queue = FCFSWaitingQueue()
+
+        with pytest.raises(IndexError):
+            queue.pop_request()
+
+    def test_peek_request(self):
+        """Test peeking at the front of the queue."""
+        queue = FCFSWaitingQueue()
+        items = [create_mock_request_item(i) for i in range(3)]
+        queue.add_requests(items)
+
+        # Peek should return first item without removing it
+        assert queue.peek_request().id == 0
+        assert len(queue) == 3  # Size unchanged
+
+    def test_peek_from_empty_queue(self):
+        """Test that peek_request raises IndexError on empty queue."""
+        queue = FCFSWaitingQueue()
+
+        with pytest.raises(IndexError):
+            queue.peek_request()
+
+    def test_prepend_request(self):
+        """Test prepending a request to the front."""
+        queue = FCFSWaitingQueue()
+        queue.add_request(create_mock_request_item(1))
+        queue.add_request(create_mock_request_item(2))
+
+        # Prepend item 0 to front
+        queue.prepend_request(create_mock_request_item(0))
+
+        # Should pop in order: 0, 1, 2
+        assert queue.pop_request().id == 0
+        assert queue.pop_request().id == 1
+        assert queue.pop_request().id == 2
+
+    def test_prepend_requests(self):
+        """Test prepending multiple requests."""
+        queue = FCFSWaitingQueue()
+        queue.add_request(create_mock_request_item(3))
+
+        # Prepend items [1, 2] - note: extendleft reverses order
+        queue.prepend_requests([create_mock_request_item(i) for i in [1, 2]])
+
+        # After extendleft([1, 2]), order is: 2, 1, 3
+        assert queue.pop_request().id == 2
+        assert queue.pop_request().id == 1
+        assert queue.pop_request().id == 3
+
+    def test_remove_by_ids(self):
+        """Test removing requests by their IDs."""
+        queue = FCFSWaitingQueue()
+        items = [create_mock_request_item(i) for i in range(5)]
+        queue.add_requests(items)
+
+        # Remove items 1 and 3
+        queue.remove_by_ids({1, 3})
+
+        assert len(queue) == 3
+        remaining_ids = [item.id for item in queue]
+        assert remaining_ids == [0, 2, 4]
+
+    def test_remove_nonexistent_ids(self):
+        """Test removing IDs that don't exist (should not raise)."""
+        queue = FCFSWaitingQueue()
+        items = [create_mock_request_item(i) for i in range(3)]
+        queue.add_requests(items)
+
+        # Remove IDs that don't exist
+        queue.remove_by_ids({10, 20})
+
+        assert len(queue) == 3
+
+    def test_bool_empty_queue(self):
+        """Test bool conversion for empty queue."""
+        queue = FCFSWaitingQueue()
+        assert not queue
+        assert bool(queue) is False
+
+    def test_bool_nonempty_queue(self):
+        """Test bool conversion for non-empty queue."""
+        queue = FCFSWaitingQueue()
+        queue.add_request(create_mock_request_item(1))
+        assert queue
+        assert bool(queue) is True
+
+    def test_len(self):
+        """Test length of queue."""
+        queue = FCFSWaitingQueue()
+        assert len(queue) == 0
+
+        queue.add_request(create_mock_request_item(1))
+        assert len(queue) == 1
+
+        queue.add_requests([create_mock_request_item(i) for i in range(2, 5)])
+        assert len(queue) == 4
+
+    def test_iter(self):
+        """Test iteration over queue."""
+        queue = FCFSWaitingQueue()
+        items = [create_mock_request_item(i) for i in range(3)]
+        queue.add_requests(items)
+
+        iterated_ids = [item.id for item in queue]
+        assert iterated_ids == [0, 1, 2]
+
+        # Iteration should not consume items
+        assert len(queue) == 3
+
+    def test_is_waiting_queue_subclass(self):
+        """Test that FCFSWaitingQueue is a WaitingQueue."""
+        queue = FCFSWaitingQueue()
+        assert isinstance(queue, WaitingQueue)
+
+
+class TestCreateWaitingQueue:
+    """Tests for create_waiting_queue factory function."""
+
+    def test_create_fcfs_queue(self):
+        """Test creating FCFS queue."""
+        queue = create_waiting_queue(WaitingQueuePolicy.FCFS)
+        assert isinstance(queue, FCFSWaitingQueue)
+
+    def test_create_default_queue(self):
+        """Test creating queue with default policy."""
+        queue = create_waiting_queue()
+        assert isinstance(queue, FCFSWaitingQueue)
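One note on the extension seam: create_waiting_queue reserves a priority_fn argument that nothing consumes yet. A purely hypothetical sketch (neither PriorityWaitingQueue nor a matching WaitingQueuePolicy value exists in this PR) of what a non-FCFS policy could look like behind the same WaitingQueue interface:

    # Hypothetical sketch only -- illustrates the interface, not code in this PR.
    import heapq
    import itertools
    from collections.abc import Iterable, Iterator
    from typing import Callable

    from tensorrt_llm._torch.pyexecutor.executor_request_queue import RequestQueueItem
    from tensorrt_llm._torch.pyexecutor.scheduler import WaitingQueue


    class PriorityWaitingQueue(WaitingQueue):
        """Hypothetical: pops the request with the lowest priority_fn value first."""

        def __init__(self, priority_fn: Callable[[RequestQueueItem], float]):
            self._priority_fn = priority_fn
            self._counter = itertools.count()  # tie-breaker preserves arrival order
            self._heap: list[tuple[float, int, RequestQueueItem]] = []

        def add_request(self, request: RequestQueueItem) -> None:
            heapq.heappush(
                self._heap, (self._priority_fn(request), next(self._counter), request))

        def add_requests(self, requests: Iterable[RequestQueueItem]) -> None:
            for request in requests:
                self.add_request(request)

        def pop_request(self) -> RequestQueueItem:
            return heapq.heappop(self._heap)[2]

        def peek_request(self) -> RequestQueueItem:
            return self._heap[0][2]

        def prepend_request(self, request: RequestQueueItem) -> None:
            # "Front" has no positional meaning in a heap; re-inserting is enough.
            self.add_request(request)

        def prepend_requests(self, requests: Iterable[RequestQueueItem]) -> None:
            self.add_requests(requests)

        def remove_by_ids(self, request_ids: set[int]) -> None:
            self._heap = [e for e in self._heap if e[2].id not in request_ids]
            heapq.heapify(self._heap)

        def __bool__(self) -> bool:
            return bool(self._heap)

        def __len__(self) -> int:
            return len(self._heap)

        def __iter__(self) -> Iterator[RequestQueueItem]:
            return (entry[2] for entry in sorted(self._heap))

The counter tie-breaker keeps insertion order among equal priorities, so FCFS falls out as the degenerate case of a constant priority_fn; this is only meant to show where priority_fn would plug into the factory.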