From 8d9baa462365b390aeb95bebbfe36878daea4c8d Mon Sep 17 00:00:00 2001 From: Jiagan Cheng Date: Tue, 9 Dec 2025 10:43:52 +0800 Subject: [PATCH] [https://nvbugs/5677746][fix] Use first PP rank's schedule result in other PP ranks to fix PP hang (#9659) Signed-off-by: Jiagan Cheng --- .../batch_manager/capacityScheduler.cpp | 6 +- tensorrt_llm/_torch/pyexecutor/py_executor.py | 135 +++++++++++++++--- tensorrt_llm/_torch/pyexecutor/scheduler.py | 70 +++++++++ .../integration/test_lists/test-db/l0_a10.yml | 1 + .../test_scheduler_serializable_output.py | 59 ++++++++ 5 files changed, 248 insertions(+), 23 deletions(-) create mode 100644 tests/unittest/_torch/executor/test_scheduler_serializable_output.py diff --git a/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp b/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp index 9c9c56ba9d..d765bcf317 100644 --- a/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp +++ b/cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp @@ -247,7 +247,8 @@ std::tuple GuaranteedNoEvictScheduler::impl( { break; } - else if (req->isGenerationInProgressState()) + + if (req->isGenerationInProgressState()) { scheduledRequests.emplace_back(req); reservedBlocks.decrementReservedBlocks(*req); @@ -296,7 +297,8 @@ std::tuple GuaranteedNoEvictScheduler::impl( { break; } - else if (req->isContextInitState() || req->isDisaggGenerationInitState()) + + if (req->isContextInitState() || req->isDisaggGenerationInitState()) { bool enoughBlocks = reservedBlocks.enoughAvailableBlocks(*req); bool enoughCrossBlocks diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index f6cf6d4cb5..0178f766dc 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -53,7 +53,8 @@ from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, from .model_engine import ModelEngine from .resource_manager import ResourceManager from .sampler import Sampler, SampleState, SampleStateTensors -from .scheduler import RequestScheduler, ScheduledRequests +from .scheduler import (RequestScheduler, ScheduledRequests, + SerializableSchedulerOutput) # Environment variable to specify iteration ranges for profiling start/stop. # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." @@ -65,6 +66,8 @@ PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE" # Unique tag base to avoid collisions with token/logits comms TERMINATION_COMM_TAG_BASE = 20000 +PP_COMM_TAG_SCHEDULE_RESULT = 21000 +PP_COMM_TAG_SAMPLE_STATE_BASE = 21001 @functools.cache @@ -232,6 +235,10 @@ class PyExecutor: self.micro_batches: List[BatchStatePP | None] = [None] * self.num_micro_batches self.send_handles = [None] * self.num_micro_batches + # schedule handle for PP to propagate the first PP rank's schedule result + self.send_schedule_handler = None + self.pp_scheduler_max_retry_count = int( + os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10)) # Set of request IDs that are currently in flight across all micro batches. # The scheduler will avoid scheduling requests that are already in flight. @@ -786,6 +793,77 @@ class PyExecutor: self.response_cv.notify_all() self.shutdown_event.set() + def _pp_schedule_and_propagate(self): + """The first PP rank schedules the requests and propagates the result to all other PP ranks.""" + + # The first PP rank schedules the requests, other ranks receive the schedule result from the previous PP rank. 
+ if self.dist.is_first_pp_rank: + scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( + ) + serializable_schedule = SerializableSchedulerOutput.from_scheduler_result( + scheduled_batch, fitting_disagg_gen_init_requests, + num_fitting_reqs) + else: + with nvtx_range("recv_schedule_from_prev_pp"): + serializable_schedule = self.dist.recv_object( + self.dist.prev_pp_rank, PP_COMM_TAG_SCHEDULE_RESULT) + scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result( + self.active_requests) + + # Propagate the schedule result to the next PP rank except the last PP rank. + if not self.dist.is_last_pp_rank: + if self.send_schedule_handler is not None: + with nvtx_range("wait_send_schedule_handler"): + self.send_schedule_handler.wait() + with nvtx_range("send_schedule_to_next_pp"): + self.send_schedule_handler = self.dist.isend_object( + serializable_schedule, self.dist.next_pp_rank, + PP_COMM_TAG_SCHEDULE_RESULT) + return scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs + + def _pp_retry_until_can_schedule(self, scheduled_batch): + """ + If current rank cannot run the scheduled batch, it will retry following steps until it has enough KV cache resources or reach maximum retry count: + 1. Wait for cache transceiver to finish at least one cache transmission. + 2. Terminate requests that have finished context cache transmission. + 3. Check if current rank has enough KV cache resources to run the scheduled batch. + """ + scheduled_batch_requests = scheduled_batch.all_requests() + if self.scheduler.can_schedule(scheduled_batch_requests): + return + + logger.warning( + "Cannot run first PP's schedule result due to limited KV cache resources. This may cause bubbles in the PP pipeline. Please consider increasing the KV cache size by setting `free_gpu_memory_fraction` to a larger value." + ) + if self.kv_cache_transceiver is None: + raise RuntimeError( + "KV cache transceiver is not enabled, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected." + ) + if not self.ctx_in_transmission_requests: + raise RuntimeError( + "No context cache transmission is in progress, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected." + ) + if self.block_reuse_enabled and self._disagg_pp_termination_handler is not None: + raise RuntimeError( + "Cannot terminate requests in cache transmission and release their KV cache resources when block reuse is enabled. Please consider increasing the KV cache size." + ) + + for retry_count in range(self.pp_scheduler_max_retry_count): + if self.scheduler.can_schedule(scheduled_batch_requests): + break + logger.debug( + f"Retrying to run first PP's schedule result ({retry_count + 1}/{self.pp_scheduler_max_retry_count})" + ) + + # Let cache transceiver finish at least one cache transmission and release requests' KV cache resources + self._check_disagg_ctx_cache_transfer_status(1) + self._check_kv_transfer_timeout() + self._terminate_disagg_ctx_finished_requests() + else: + raise RuntimeError( + f"Reach maximum PP retry count ({self.pp_scheduler_max_retry_count}) but still cannot run first PP's schedule result. Please consider increasing the KV cache size by setting `free_gpu_memory_fraction` to a larger value. Or you can set `TLLM_PP_SCHEDULER_MAX_RETRY_COUNT` to a larger value to allow more retries." 
+ ) + def _executor_loop_pp(self): logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}") torch.cuda.set_device(self.device_id) @@ -799,6 +877,8 @@ class PyExecutor: profile_step() if self.enable_iter_perf_stats: iter_start_time = time.time() + + # Fetch new requests from request queue new_requests = self._fetch_and_activate_new_requests() if self.should_stop_processing: break @@ -816,11 +896,18 @@ class PyExecutor: self._pad_attention_dp_dummy_request() - scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule( + # Stage 0: first PP rank schedules requests and propagates the result to all other PP ranks. + scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._pp_schedule_and_propagate( ) + if not self.dist.is_first_pp_rank: + # Retry until current rank can run first PP's schedule result. + self._pp_retry_until_can_schedule(scheduled_batch) + # Run scheduler locally because scheduler may change llm requests' state. + self.scheduler.schedule_request(self.active_requests, + self.inflight_req_ids) + # For requests that are fitting disagg gen init, also prepare resources for KV cache manager if self.kv_cache_transceiver: - # For requests that are fitting disagg gen init, also prepare resources for KV cache manager self._prepare_disagg_gen_init( fitting_disagg_gen_init_requests) @@ -840,7 +927,6 @@ class PyExecutor: ) can_queue = self._can_queue(scheduled_batch) - if not can_queue: logger.debug( f"microbatch {microbatch_id} cannot be queued, skipping" @@ -928,6 +1014,7 @@ class PyExecutor: prev_microbatch_id = (microbatch_id + offset) % self.num_micro_batches previous_batch = self.micro_batches[prev_microbatch_id] + tag = PP_COMM_TAG_SAMPLE_STATE_BASE + prev_microbatch_id if previous_batch is not None: sample_state = previous_batch.sample_state if not self.dist.is_last_pp_rank: @@ -937,7 +1024,7 @@ class PyExecutor: with nvtx_range("recv_sample_state"): sample_state.host = recv_object_funct( src=self.dist.prev_pp_rank, - tag=prev_microbatch_id, + tag=tag, ) # Send tokens to next pp rank (w.r.t model forward direction) @@ -949,7 +1036,7 @@ class PyExecutor: prev_microbatch_id] = self.dist.isend_object( sample_state.host, dest=self.dist.next_pp_rank, - tag=prev_microbatch_id) + tag=tag) # Stage 3: Finalize previous batch that finished sample state communication # In last pp rank, stage 2 and 3 process different previous batches @@ -1746,24 +1833,26 @@ class PyExecutor: def _waiting_requests(self, context_requests: list[LlmRequest], generation_requests: list[LlmRequest]): - if not self.enable_batch_waiting: - return context_requests + """ + Return an empty list if scheduled requests fulfill the waiting conditions, otherwise return the original context requests. + Waiting conditions: + - The number of scheduled tokens (both context and generation) is smaller than `self.batch_wait_max_tokens_ratio * self.max_num_tokens` + - The number of waiting iterations is smaller than `self.batch_wait_timeout_iters`. 
+ """ - waited_context_requests = [] - stop_waiting = False num_scheduled_ctx_tokens = sum( len(ctx_req.get_tokens(0)) for ctx_req in context_requests) num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens for gen_req in generation_requests) num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens - stop_waiting = self.batch_wait_iters_count >= self.batch_wait_timeout_iters or num_scheduled_tokens >= self.batch_wait_max_tokens_ratio * self.max_num_tokens - if stop_waiting: - waited_context_requests = context_requests - self.batch_wait_iters_count = 0 - else: + should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < self.batch_wait_max_tokens_ratio * self.max_num_tokens + if should_waiting: self.batch_wait_iters_count += 1 - return waited_context_requests + return [] + + self.batch_wait_iters_count = 0 + return context_requests @nvtx_range("_schedule") def _schedule(self): @@ -1775,10 +1864,11 @@ class PyExecutor: scheduler_output.context_requests, scheduler_output.generation_requests) - # if no generation requests, no need to wait, to avoid dead waiting - if not self.enable_attention_dp and self.enable_batch_waiting and len( - scheduler_output.context_requests) > 0 and len( - scheduler_output.generation_requests) > 0: + # If no generation requests, no need to wait, to avoid dead waiting + should_check_waiting = not self.enable_attention_dp and self.enable_batch_waiting and len( + scheduler_output.context_requests) > 0 and len( + scheduler_output.generation_requests) > 0 + if should_check_waiting: scheduled_context_requests = self._waiting_requests( scheduler_output.context_requests, scheduler_output.generation_requests) @@ -2403,7 +2493,10 @@ class PyExecutor: @nvtx_range("_terminate_disagg_ctx_finished_requests") def _terminate_disagg_ctx_finished_requests(self): - for request_id in list(self.ctx_in_transmission_requests.keys()): + # make a copy of the keys, since we are modifying the dictionary in the loop + in_transmission_requests_id = list( + self.ctx_in_transmission_requests.keys()) + for request_id in in_transmission_requests_id: request, block_id, counter = self.ctx_in_transmission_requests[ request_id] diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index c71c4596ed..2c1d8f916f 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from collections import namedtuple +from dataclasses import dataclass from typing import Optional, Tuple from strenum import StrEnum @@ -54,6 +55,70 @@ class RequestScheduler(ABC): # to be aligned with RequestScheduler::scheduleRequests in cpp/tensorrt_llm/batch_manager/requestScheduler.h raise NotImplementedError + @abstractmethod + def can_schedule(self, requests: RequestList) -> bool: + """ + Check if current rank can schedule the requests. + :param requests: list of requests to be scheduled + :return: True if current rank can schedule the requests, False otherwise + """ + raise NotImplementedError + + +@dataclass +class SerializableSchedulerOutput: + """ + Serializable version of SchedulerOutput, used for sending schedule result to other ranks. Need this class because LlmRequest is not serializable by pickle. 
+ """ + context_requests: list[int] # request ids of context requests + generation_requests: list[int] # request ids of generation requests + paused_requests: list[int] # request ids of paused requests + fitting_disagg_gen_init_requests: list[ + int] # request ids of fitting disaggregated generation initialization requests + num_fitting_requests: int # number of fitting requests + + @classmethod + def from_scheduler_result( + cls, scheduled_requests: ScheduledRequests, + fitting_disagg_gen_init_requests: RequestList, + num_fitting_requests: int) -> "SerializableSchedulerOutput": + return cls(context_requests=[ + req.request_id for req in scheduled_requests.context_requests + ], + generation_requests=[ + req.request_id + for req in scheduled_requests.generation_requests + ], + paused_requests=[ + req.request_id + for req in scheduled_requests.paused_requests + ], + fitting_disagg_gen_init_requests=[ + req.request_id + for req in fitting_disagg_gen_init_requests + ], + num_fitting_requests=num_fitting_requests) + + def to_scheduler_result( + self, active_requests: RequestList + ) -> Tuple[ScheduledRequests, RequestList, int]: + id_to_request = {req.request_id: req for req in active_requests} + scheduled_requests = ScheduledRequests() + scheduled_requests.context_requests = [ + id_to_request[req_id] for req_id in self.context_requests + ] + scheduled_requests.generation_requests = [ + id_to_request[req_id] for req_id in self.generation_requests + ] + scheduled_requests.paused_requests = [ + id_to_request[req_id] for req_id in self.paused_requests + ] + fitting_disagg_gen_init_requests = [ + id_to_request[req_id] + for req_id in self.fitting_disagg_gen_init_requests + ] + return scheduled_requests, fitting_disagg_gen_init_requests, self.num_fitting_requests + class CapacityScheduler(ABC): @@ -216,3 +281,8 @@ class SimpleScheduler(RequestScheduler): list(generation_requests), list(paused_requests), list(fitting_disagg_gen_init_requests), len(fitting_requests)) + + def can_schedule(self, requests: RequestList) -> bool: + fitting_requests, _, _ = self.capacity_scheduler.schedule_request( + requests) + return len(fitting_requests) == len(requests) diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 7eb00943f6..36a5bc32e5 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -20,6 +20,7 @@ l0_a10: - unittest/_torch/modeling/test_modeling_mistral.py - unittest/_torch/modeling/test_modeling_pixtral.py - unittest/_torch/sampler/test_trtllm_sampler.py + - unittest/_torch/executor/test_scheduler_serializable_output.py # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no # test list either). 
- unittest/_torch/models/checkpoints/hf/test_weight_loader.py diff --git a/tests/unittest/_torch/executor/test_scheduler_serializable_output.py b/tests/unittest/_torch/executor/test_scheduler_serializable_output.py new file mode 100644 index 0000000000..94fba12d7d --- /dev/null +++ b/tests/unittest/_torch/executor/test_scheduler_serializable_output.py @@ -0,0 +1,59 @@ +import pickle + +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, SamplingConfig +from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests, SerializableSchedulerOutput + + +def _make_request(request_id: int) -> LlmRequest: + return LlmRequest( + request_id=request_id, + max_new_tokens=5, + input_tokens=[request_id], + sampling_config=SamplingConfig(), + is_streaming=False, + ) + + +def _request_ids(requests): + return [req.request_id for req in requests] + + +def test_serializable_scheduler_output_round_trip(): + # Create all requests and put them in a pool + request_pool = {idx: _make_request(idx) for idx in range(1, 8)} + + # Create scheduler result: scheduled_requests, fitting_disagg_gen_init_requests, num_fitting_requests + scheduled_requests = ScheduledRequests() + scheduled_requests.context_requests = [request_pool[1], request_pool[2]] + scheduled_requests.generation_requests = [request_pool[3]] + scheduled_requests.paused_requests = [request_pool[4]] + fitting_disagg_gen_init_requests = [request_pool[5], request_pool[6]] + num_fitting_requests = 3 + + # Create serializable scheduler output from scheduler result + serializable_output = SerializableSchedulerOutput.from_scheduler_result( + scheduled_requests, fitting_disagg_gen_init_requests, num_fitting_requests + ) + + # Serialize and deserialize the serializable scheduler output + serialized_bytes = pickle.dumps(serializable_output) + restored_output: SerializableSchedulerOutput = pickle.loads(serialized_bytes) + + # Restore the scheduler result from the deserialized serializable scheduler output + active_requests = list(request_pool.values()) + restored_schedule, restored_fitting, restored_num_fitting = restored_output.to_scheduler_result( + active_requests + ) + + # Verify the restored scheduler result is correct + assert restored_num_fitting == num_fitting_requests + assert _request_ids(restored_schedule.context_requests) == _request_ids( + scheduled_requests.context_requests + ) + assert _request_ids(restored_schedule.generation_requests) == _request_ids( + scheduled_requests.generation_requests + ) + assert _request_ids(restored_schedule.paused_requests) == _request_ids( + scheduled_requests.paused_requests + ) + assert _request_ids(restored_fitting) == _request_ids(fitting_disagg_gen_init_requests)
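
For readers of this patch, here is a minimal, framework-free sketch of the relay pattern it introduces: the first PP rank runs the scheduler, only request ids travel down the pipeline, and each later rank rebuilds the identical batch from its own active requests, so all PP ranks execute the same batch and cannot diverge and hang. The Request, ScheduleMessage, schedule_on_first_rank, and replay_on_later_rank names below are illustrative stand-ins, not part of the patch or of TensorRT-LLM; the real code uses SerializableSchedulerOutput together with dist.isend_object / dist.recv_object and the PP_COMM_TAG_SCHEDULE_RESULT tag.

# Sketch only: hypothetical stand-ins for LlmRequest, self._schedule(), and the
# point-to-point send/recv between adjacent PP ranks. Only the idea matches the patch.
from dataclasses import dataclass, field


@dataclass
class Request:
    request_id: int
    is_context: bool  # True while the request is still in its context phase


@dataclass
class ScheduleMessage:
    # Only request ids cross the pipeline, mirroring SerializableSchedulerOutput:
    # LlmRequest objects are not picklable, and every rank already holds the same
    # active requests, so ids are enough to rebuild the batch locally.
    context_ids: list[int] = field(default_factory=list)
    generation_ids: list[int] = field(default_factory=list)


def schedule_on_first_rank(active: list[Request]) -> ScheduleMessage:
    """Stand-in for the real scheduler running only on PP rank 0."""
    msg = ScheduleMessage()
    for req in active:
        (msg.context_ids if req.is_context else msg.generation_ids).append(req.request_id)
    return msg


def replay_on_later_rank(msg: ScheduleMessage, active: list[Request]):
    """Rebuild the same batch from local requests, as to_scheduler_result does."""
    by_id = {req.request_id: req for req in active}
    return ([by_id[i] for i in msg.context_ids],
            [by_id[i] for i in msg.generation_ids])


if __name__ == "__main__":
    # Every rank holds the same active requests (ids 1..4).
    active = [Request(1, True), Request(2, True), Request(3, False), Request(4, False)]

    # Rank 0 schedules once; the id-only message is forwarded rank by rank.
    msg = schedule_on_first_rank(active)
    for rank in range(1, 4):  # ranks 1..3 each receive from the previous rank
        ctx, gen = replay_on_later_rank(msg, active)
        assert [r.request_id for r in ctx] == [1, 2]
        assert [r.request_id for r in gen] == [3, 4]
    print("all PP ranks rebuilt the same batch:", msg)

Note that in the actual patch the later ranks still call self.scheduler.schedule_request locally (because scheduling may change request state), and if the replayed batch does not fit their KV cache they fall back to _pp_retry_until_can_schedule, draining finished disaggregated context transmissions between attempts, up to TLLM_PP_SCHEDULER_MAX_RETRY_COUNT retries.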