[https://nvbugs/5677746][fix] Use first PP rank's schedule result in other PP ranks to fix PP hang (#9659)

Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
Jiagan Cheng 2025-12-09 10:43:52 +08:00 committed by GitHub
parent d6f961d3fe
commit 4a3a66b124
5 changed files with 248 additions and 23 deletions
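
Below is a minimal, self-contained sketch (not TensorRT-LLM code; the queue-based `run_step` helper is a hypothetical toy model) of one way divergent per-rank schedules can hang a pipeline-parallel executor, and why replaying the first PP rank's schedule on every rank, as this commit does, avoids it: if a downstream rank schedules a request its predecessor did not, it waits for data that is never sent.

from collections import deque

def run_step(schedules):
    """schedules[r] is the list of request ids rank r decides to run this step.
    Rank r forwards one message per scheduled request to rank r + 1, and
    rank r + 1 blocks until it has one message for every request it scheduled."""
    links = [deque() for _ in range(len(schedules) - 1)]
    for r, batch in enumerate(schedules[:-1]):
        for req_id in batch:
            links[r].append(req_id)  # non-blocking sends from rank r
    for r, batch in enumerate(schedules[1:], start=1):
        for req_id in batch:
            if not links[r - 1]:
                return f"hang: rank {r} waits forever for request {req_id}"
            links[r - 1].popleft()
    return "ok: all ranks finish the step"

# Divergent schedules (rank 1 happens to fit request 3 that rank 0 dropped): hang.
print(run_step([[1, 2], [1, 2, 3]]))
# With this fix, every rank replays the first PP rank's schedule, so sends and receives match.
print(run_step([[1, 2], [1, 2]]))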

cpp/tensorrt_llm/batch_manager/capacityScheduler.cpp

@@ -247,7 +247,8 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
{
break;
}
else if (req->isGenerationInProgressState())
if (req->isGenerationInProgressState())
{
scheduledRequests.emplace_back(req);
reservedBlocks.decrementReservedBlocks(*req);
@@ -296,7 +297,8 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
{
break;
}
else if (req->isContextInitState() || req->isDisaggGenerationInitState())
if (req->isContextInitState() || req->isDisaggGenerationInitState())
{
bool enoughBlocks = reservedBlocks.enoughAvailableBlocks(*req);
bool enoughCrossBlocks

tensorrt_llm/_torch/pyexecutor/py_executor.py

@@ -53,7 +53,8 @@ from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
from .model_engine import ModelEngine
from .resource_manager import ResourceManager
from .sampler import Sampler, SampleState, SampleStateTensors
from .scheduler import RequestScheduler, ScheduledRequests
from .scheduler import (RequestScheduler, ScheduledRequests,
SerializableSchedulerOutput)
# Environment variable to specify iteration ranges for profiling start/stop.
# Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
@@ -65,6 +66,8 @@ PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"
# Unique tag base to avoid collisions with token/logits comms
TERMINATION_COMM_TAG_BASE = 20000
PP_COMM_TAG_SCHEDULE_RESULT = 21000
PP_COMM_TAG_SAMPLE_STATE_BASE = 21001
@functools.cache
@@ -232,6 +235,10 @@ class PyExecutor:
self.micro_batches: List[BatchStatePP
| None] = [None] * self.num_micro_batches
self.send_handles = [None] * self.num_micro_batches
# Send handle used to propagate the first PP rank's schedule result to the next PP rank
self.send_schedule_handler = None
self.pp_scheduler_max_retry_count = int(
os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
# Set of request IDs that are currently in flight across all micro batches.
# The scheduler will avoid scheduling requests that are already in flight.
@@ -786,6 +793,77 @@ class PyExecutor:
self.response_cv.notify_all()
self.shutdown_event.set()
def _pp_schedule_and_propagate(self):
"""The first PP rank schedules the requests and propagates the result to all other PP ranks."""
# The first PP rank runs the scheduler; the other ranks receive the schedule result from their previous PP rank.
if self.dist.is_first_pp_rank:
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
serializable_schedule = SerializableSchedulerOutput.from_scheduler_result(
scheduled_batch, fitting_disagg_gen_init_requests,
num_fitting_reqs)
else:
with nvtx_range("recv_schedule_from_prev_pp"):
serializable_schedule = self.dist.recv_object(
self.dist.prev_pp_rank, PP_COMM_TAG_SCHEDULE_RESULT)
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result(
self.active_requests)
# Propagate the schedule result to the next PP rank (the last PP rank has no successor).
if not self.dist.is_last_pp_rank:
if self.send_schedule_handler is not None:
with nvtx_range("wait_send_schedule_handler"):
self.send_schedule_handler.wait()
with nvtx_range("send_schedule_to_next_pp"):
self.send_schedule_handler = self.dist.isend_object(
serializable_schedule, self.dist.next_pp_rank,
PP_COMM_TAG_SCHEDULE_RESULT)
return scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs
def _pp_retry_until_can_schedule(self, scheduled_batch):
"""
If the current rank cannot run the scheduled batch, it retries the following steps until it has enough KV cache resources or reaches the maximum retry count:
1. Wait for cache transceiver to finish at least one cache transmission.
2. Terminate requests that have finished context cache transmission.
3. Check if current rank has enough KV cache resources to run the scheduled batch.
"""
scheduled_batch_requests = scheduled_batch.all_requests()
if self.scheduler.can_schedule(scheduled_batch_requests):
return
logger.warning(
"Cannot run first PP's schedule result due to limited KV cache resources. This may cause bubbles in the PP pipeline. Please consider increasing the KV cache size by setting `free_gpu_memory_fraction` to a larger value."
)
if self.kv_cache_transceiver is None:
raise RuntimeError(
"KV cache transceiver is not enabled, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected."
)
if not self.ctx_in_transmission_requests:
raise RuntimeError(
"No context cache transmission is in progress, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected."
)
if self.block_reuse_enabled and self._disagg_pp_termination_handler is not None:
raise RuntimeError(
"Cannot terminate requests in cache transmission and release their KV cache resources when block reuse is enabled. Please consider increasing the KV cache size."
)
for retry_count in range(self.pp_scheduler_max_retry_count):
if self.scheduler.can_schedule(scheduled_batch_requests):
break
logger.debug(
f"Retrying to run first PP's schedule result ({retry_count + 1}/{self.pp_scheduler_max_retry_count})"
)
# Let cache transceiver finish at least one cache transmission and release requests' KV cache resources
self._check_disagg_ctx_cache_transfer_status(1)
self._check_kv_transfer_timeout()
self._terminate_disagg_ctx_finished_requests()
else:
raise RuntimeError(
f"Reach maximum PP retry count ({self.pp_scheduler_max_retry_count}) but still cannot run first PP's schedule result. Please consider increasing the KV cache size by setting `free_gpu_memory_fraction` to a larger value. Or you can set `TLLM_PP_SCHEDULER_MAX_RETRY_COUNT` to a larger value to allow more retries."
)
def _executor_loop_pp(self):
logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
torch.cuda.set_device(self.device_id)
@@ -799,6 +877,8 @@ class PyExecutor:
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
# Fetch new requests from request queue
new_requests = self._fetch_and_activate_new_requests()
if self.should_stop_processing:
break
@@ -816,11 +896,18 @@ class PyExecutor:
self._pad_attention_dp_dummy_request()
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
# Stage 0: first PP rank schedules requests and propagates the result to all other PP ranks.
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._pp_schedule_and_propagate(
)
if not self.dist.is_first_pp_rank:
# Retry until current rank can run first PP's schedule result.
self._pp_retry_until_can_schedule(scheduled_batch)
# Still run the scheduler locally, because it may change the LLM requests' state.
self.scheduler.schedule_request(self.active_requests,
self.inflight_req_ids)
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
if self.kv_cache_transceiver:
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
self._prepare_disagg_gen_init(
fitting_disagg_gen_init_requests)
@@ -840,7 +927,6 @@ class PyExecutor:
)
can_queue = self._can_queue(scheduled_batch)
if not can_queue:
logger.debug(
f"microbatch {microbatch_id} cannot be queued, skipping"
@@ -928,6 +1014,7 @@ class PyExecutor:
prev_microbatch_id = (microbatch_id +
offset) % self.num_micro_batches
previous_batch = self.micro_batches[prev_microbatch_id]
tag = PP_COMM_TAG_SAMPLE_STATE_BASE + prev_microbatch_id
if previous_batch is not None:
sample_state = previous_batch.sample_state
if not self.dist.is_last_pp_rank:
@@ -937,7 +1024,7 @@ class PyExecutor:
with nvtx_range("recv_sample_state"):
sample_state.host = recv_object_funct(
src=self.dist.prev_pp_rank,
tag=prev_microbatch_id,
tag=tag,
)
# Send tokens to next pp rank (w.r.t model forward direction)
@@ -949,7 +1036,7 @@ class PyExecutor:
prev_microbatch_id] = self.dist.isend_object(
sample_state.host,
dest=self.dist.next_pp_rank,
tag=prev_microbatch_id)
tag=tag)
# Stage 3: Finalize previous batch that finished sample state communication
# In last pp rank, stage 2 and 3 process different previous batches
@@ -1746,24 +1833,26 @@ class PyExecutor:
def _waiting_requests(self, context_requests: list[LlmRequest],
generation_requests: list[LlmRequest]):
if not self.enable_batch_waiting:
return context_requests
"""
Return an empty list if the scheduled requests fulfill the waiting conditions; otherwise return the original context requests.
Waiting conditions:
- The number of scheduled tokens (both context and generation) is smaller than `self.batch_wait_max_tokens_ratio * self.max_num_tokens`
- The number of waiting iterations is smaller than `self.batch_wait_timeout_iters`.
"""
waited_context_requests = []
stop_waiting = False
num_scheduled_ctx_tokens = sum(
len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
for gen_req in generation_requests)
num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
stop_waiting = self.batch_wait_iters_count >= self.batch_wait_timeout_iters or num_scheduled_tokens >= self.batch_wait_max_tokens_ratio * self.max_num_tokens
if stop_waiting:
waited_context_requests = context_requests
self.batch_wait_iters_count = 0
else:
should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < self.batch_wait_max_tokens_ratio * self.max_num_tokens
if should_waiting:
self.batch_wait_iters_count += 1
return waited_context_requests
return []
self.batch_wait_iters_count = 0
return context_requests
@nvtx_range("_schedule")
def _schedule(self):
@@ -1775,10 +1864,11 @@ class PyExecutor:
scheduler_output.context_requests,
scheduler_output.generation_requests)
# if no generation requests, no need to wait, to avoid dead waiting
if not self.enable_attention_dp and self.enable_batch_waiting and len(
scheduler_output.context_requests) > 0 and len(
scheduler_output.generation_requests) > 0:
# If no generation requests, no need to wait, to avoid dead waiting
should_check_waiting = not self.enable_attention_dp and self.enable_batch_waiting and len(
scheduler_output.context_requests) > 0 and len(
scheduler_output.generation_requests) > 0
if should_check_waiting:
scheduled_context_requests = self._waiting_requests(
scheduler_output.context_requests,
scheduler_output.generation_requests)
@@ -2408,7 +2498,10 @@ class PyExecutor:
@nvtx_range("_terminate_disagg_ctx_finished_requests")
def _terminate_disagg_ctx_finished_requests(self):
for request_id in list(self.ctx_in_transmission_requests.keys()):
# make a copy of the keys, since we are modifying the dictionary in the loop
in_transmission_requests_id = list(
self.ctx_in_transmission_requests.keys())
for request_id in in_transmission_requests_id:
request, block_id, counter = self.ctx_in_transmission_requests[
request_id]

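The new retry path above exposes two tuning knobs: the `TLLM_PP_SCHEDULER_MAX_RETRY_COUNT` environment variable added in this commit (default 10) and the KV cache size. A hedged usage sketch, assuming the LLM API's `KvCacheConfig(free_gpu_memory_fraction=...)` option and `pipeline_parallel_size` argument; the model path is a placeholder:

import os

# Allow more retries before giving up on the first PP rank's schedule result;
# set it before the executor is created, since it is read at construction time.
os.environ["TLLM_PP_SCHEDULER_MAX_RETRY_COUNT"] = "32"

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    model="/path/to/model",  # placeholder checkpoint path
    pipeline_parallel_size=2,
    # Give the KV cache more headroom, as the new error messages suggest.
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.9),
)
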
tensorrt_llm/_torch/pyexecutor/scheduler.py

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from collections import namedtuple
from dataclasses import dataclass
from typing import Optional, Tuple
from strenum import StrEnum
@@ -54,6 +55,70 @@ class RequestScheduler(ABC):
# to be aligned with RequestScheduler::scheduleRequests in cpp/tensorrt_llm/batch_manager/requestScheduler.h
raise NotImplementedError
@abstractmethod
def can_schedule(self, requests: RequestList) -> bool:
"""
Check if current rank can schedule the requests.
:param requests: list of requests to be scheduled
:return: True if current rank can schedule the requests, False otherwise
"""
raise NotImplementedError
@dataclass
class SerializableSchedulerOutput:
"""
Serializable version of SchedulerOutput, used for sending the schedule result to other ranks. This class is needed because LlmRequest cannot be pickled.
"""
context_requests: list[int] # request ids of context requests
generation_requests: list[int] # request ids of generation requests
paused_requests: list[int] # request ids of paused requests
fitting_disagg_gen_init_requests: list[
int] # request ids of fitting disaggregated generation initialization requests
num_fitting_requests: int # number of fitting requests
@classmethod
def from_scheduler_result(
cls, scheduled_requests: ScheduledRequests,
fitting_disagg_gen_init_requests: RequestList,
num_fitting_requests: int) -> "SerializableSchedulerOutput":
return cls(context_requests=[
req.request_id for req in scheduled_requests.context_requests
],
generation_requests=[
req.request_id
for req in scheduled_requests.generation_requests
],
paused_requests=[
req.request_id
for req in scheduled_requests.paused_requests
],
fitting_disagg_gen_init_requests=[
req.request_id
for req in fitting_disagg_gen_init_requests
],
num_fitting_requests=num_fitting_requests)
def to_scheduler_result(
self, active_requests: RequestList
) -> Tuple[ScheduledRequests, RequestList, int]:
id_to_request = {req.request_id: req for req in active_requests}
scheduled_requests = ScheduledRequests()
scheduled_requests.context_requests = [
id_to_request[req_id] for req_id in self.context_requests
]
scheduled_requests.generation_requests = [
id_to_request[req_id] for req_id in self.generation_requests
]
scheduled_requests.paused_requests = [
id_to_request[req_id] for req_id in self.paused_requests
]
fitting_disagg_gen_init_requests = [
id_to_request[req_id]
for req_id in self.fitting_disagg_gen_init_requests
]
return scheduled_requests, fitting_disagg_gen_init_requests, self.num_fitting_requests
class CapacityScheduler(ABC):
@@ -216,3 +281,8 @@ class SimpleScheduler(RequestScheduler):
list(generation_requests), list(paused_requests),
list(fitting_disagg_gen_init_requests),
len(fitting_requests))
def can_schedule(self, requests: RequestList) -> bool:
fitting_requests, _, _ = self.capacity_scheduler.schedule_request(
requests)
return len(fitting_requests) == len(requests)
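
Because `can_schedule` is declared as an abstractmethod on RequestScheduler, any scheduler subclass outside this file now has to implement it as well. A minimal conforming stub, assuming only the interface shown in this diff (the class name, parameter names, and trivial bodies are hypothetical; parameter names follow the call sites in py_executor.py):

from tensorrt_llm._torch.pyexecutor.scheduler import RequestScheduler

class PassthroughScheduler(RequestScheduler):  # hypothetical subclass
    def schedule_request(self, active_requests, inflight_request_ids):
        # Real implementations return a SchedulerOutput-like result here.
        raise NotImplementedError

    def can_schedule(self, requests):
        # Conservative stub: claim this rank can always run the given batch.
        return True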

tests/integration/test_lists/test-db/l0_a10.yml

@@ -20,6 +20,7 @@ l0_a10:
- unittest/_torch/modeling/test_modeling_mistral.py
- unittest/_torch/modeling/test_modeling_pixtral.py
- unittest/_torch/sampler/test_trtllm_sampler.py
- unittest/_torch/executor/test_scheduler_serializable_output.py
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
# test list either).
- unittest/_torch/models/checkpoints/hf/test_weight_loader.py

tests/unittest/_torch/executor/test_scheduler_serializable_output.py

@@ -0,0 +1,59 @@
import pickle
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, SamplingConfig
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests, SerializableSchedulerOutput
def _make_request(request_id: int) -> LlmRequest:
return LlmRequest(
request_id=request_id,
max_new_tokens=5,
input_tokens=[request_id],
sampling_config=SamplingConfig(),
is_streaming=False,
)
def _request_ids(requests):
return [req.request_id for req in requests]
def test_serializable_scheduler_output_round_trip():
# Create all requests and put them in a pool
request_pool = {idx: _make_request(idx) for idx in range(1, 8)}
# Create scheduler result: scheduled_requests, fitting_disagg_gen_init_requests, num_fitting_requests
scheduled_requests = ScheduledRequests()
scheduled_requests.context_requests = [request_pool[1], request_pool[2]]
scheduled_requests.generation_requests = [request_pool[3]]
scheduled_requests.paused_requests = [request_pool[4]]
fitting_disagg_gen_init_requests = [request_pool[5], request_pool[6]]
num_fitting_requests = 3
# Create serializable scheduler output from scheduler result
serializable_output = SerializableSchedulerOutput.from_scheduler_result(
scheduled_requests, fitting_disagg_gen_init_requests, num_fitting_requests
)
# Serialize and deserialize the serializable scheduler output
serialized_bytes = pickle.dumps(serializable_output)
restored_output: SerializableSchedulerOutput = pickle.loads(serialized_bytes)
# Restore the scheduler result from the deserialized serializable scheduler output
active_requests = list(request_pool.values())
restored_schedule, restored_fitting, restored_num_fitting = restored_output.to_scheduler_result(
active_requests
)
# Verify the restored scheduler result is correct
assert restored_num_fitting == num_fitting_requests
assert _request_ids(restored_schedule.context_requests) == _request_ids(
scheduled_requests.context_requests
)
assert _request_ids(restored_schedule.generation_requests) == _request_ids(
scheduled_requests.generation_requests
)
assert _request_ids(restored_schedule.paused_requests) == _request_ids(
scheduled_requests.paused_requests
)
assert _request_ids(restored_fitting) == _request_ids(fitting_disagg_gen_init_requests)
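
The round-trip test above is also registered in the l0_a10 list; presumably it can be run standalone with pytest, for example as below (the repository-relative path is an assumption built from the test-list entry plus the conventional tests/ prefix):

import pytest

# Run only the new round-trip test for SerializableSchedulerOutput.
raise SystemExit(
    pytest.main(["-q", "tests/unittest/_torch/executor/test_scheduler_serializable_output.py"]))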