mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5677746][fix] Use first PP rank's schedule result in other PP ranks to fix PP hang (#9659)
Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
parent 84f7a7fd3c
commit 8d9baa4623
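The problem and the fix in miniature. The sketch below is illustrative only (plain Python, no TensorRT-LLM or MPI APIs; the toy scheduler, request names and block counts are all invented): when every pipeline-parallel rank runs the scheduler on its own local view, a small difference in local state (for example, free KV-cache blocks) can make the ranks disagree about which requests belong to a microbatch, so the matching sends and receives between neighboring ranks never line up and the pipeline hangs. The commit makes the first PP rank schedule once and ship the result downstream, so all ranks execute the same microbatch.

```python
# Toy illustration (no MPI, no TensorRT-LLM APIs): why all PP ranks must agree on the schedule.
def local_schedule(pending, kv_blocks_free):
    """Greedy toy scheduler: admit requests while free KV blocks remain (1 block per request)."""
    batch = []
    for req in pending:
        if len(batch) < kv_blocks_free:
            batch.append(req)
    return batch


pending = ["req0", "req1", "req2"]

# Independent scheduling: rank 1 happens to have fewer free blocks, so its batch is smaller.
rank0_batch = local_schedule(pending, kv_blocks_free=3)
rank1_batch = local_schedule(pending, kv_blocks_free=2)
print(rank0_batch == rank1_batch)  # False -> mismatched microbatches -> mismatched send/recv -> hang

# With the fix: rank 0 decides once and every other rank adopts its result
# (the real code ships request ids via SerializableSchedulerOutput, then each rank
#  verifies locally with scheduler.can_schedule and retries until the batch fits).
rank1_batch = list(rank0_batch)
print(rank0_batch == rank1_batch)  # True -> every rank runs the same microbatch
```

The real propagation path is `_pp_schedule_and_propagate` together with `SerializableSchedulerOutput` in the diff below.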
@@ -247,7 +247,8 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
 {
     break;
 }
-else if (req->isGenerationInProgressState())
+
+if (req->isGenerationInProgressState())
 {
     scheduledRequests.emplace_back(req);
     reservedBlocks.decrementReservedBlocks(*req);
@@ -296,7 +297,8 @@ std::tuple<RequestVector, RequestVector> GuaranteedNoEvictScheduler::impl(
 {
     break;
 }
-else if (req->isContextInitState() || req->isDisaggGenerationInitState())
+
+if (req->isContextInitState() || req->isDisaggGenerationInitState())
 {
     bool enoughBlocks = reservedBlocks.enoughAvailableBlocks(*req);
     bool enoughCrossBlocks
@@ -53,7 +53,8 @@ from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
 from .model_engine import ModelEngine
 from .resource_manager import ResourceManager
 from .sampler import Sampler, SampleState, SampleStateTensors
-from .scheduler import RequestScheduler, ScheduledRequests
+from .scheduler import (RequestScheduler, ScheduledRequests,
+                        SerializableSchedulerOutput)
 
 # Environment variable to specify iteration ranges for profiling start/stop.
 # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
@@ -65,6 +66,8 @@ PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"
 
 # Unique tag base to avoid collisions with token/logits comms
 TERMINATION_COMM_TAG_BASE = 20000
+PP_COMM_TAG_SCHEDULE_RESULT = 21000
+PP_COMM_TAG_SAMPLE_STATE_BASE = 21001
 
 
 @functools.cache
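A minimal sketch of why the schedule result gets its own communication tag (the `Channel` class below is a toy in-process stand-in, not the real `self.dist` layer; only the two tag constants come from the diff above): the receiver matches messages by (source, tag), so a schedule-result object and a per-microbatch sample-state object exchanged between the same pair of ranks cannot be delivered to the wrong receive call.

```python
import queue
from collections import defaultdict


# Hypothetical in-process stand-in for tagged point-to-point messaging.
class Channel:
    def __init__(self):
        self._queues = defaultdict(queue.Queue)  # keyed by (src, dst, tag)

    def send(self, obj, src, dst, tag):
        self._queues[(src, dst, tag)].put(obj)

    def recv(self, src, dst, tag):
        return self._queues[(src, dst, tag)].get()


PP_COMM_TAG_SCHEDULE_RESULT = 21000
PP_COMM_TAG_SAMPLE_STATE_BASE = 21001

ch = Channel()
# Rank 0 sends two different kinds of messages to rank 1.
ch.send({"context_request_ids": [1, 2]}, src=0, dst=1, tag=PP_COMM_TAG_SCHEDULE_RESULT)
ch.send({"new_tokens": [7]}, src=0, dst=1, tag=PP_COMM_TAG_SAMPLE_STATE_BASE + 0)

# Rank 1 can receive them in either order because the tags keep them apart.
sample = ch.recv(src=0, dst=1, tag=PP_COMM_TAG_SAMPLE_STATE_BASE + 0)
schedule = ch.recv(src=0, dst=1, tag=PP_COMM_TAG_SCHEDULE_RESULT)
print(schedule, sample)
```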
@@ -232,6 +235,10 @@ class PyExecutor:
 self.micro_batches: List[BatchStatePP
                          | None] = [None] * self.num_micro_batches
 self.send_handles = [None] * self.num_micro_batches
+# schedule handle for PP to propagate the first PP rank's schedule result
+self.send_schedule_handler = None
+self.pp_scheduler_max_retry_count = int(
+    os.environ.get("TLLM_PP_SCHEDULER_MAX_RETRY_COUNT", 10))
 
 # Set of request IDs that are currently in flight across all micro batches.
 # The scheduler will avoid scheduling requests that are already in flight.
@@ -786,6 +793,77 @@ class PyExecutor:
 self.response_cv.notify_all()
 self.shutdown_event.set()
 
+def _pp_schedule_and_propagate(self):
+    """The first PP rank schedules the requests and propagates the result to all other PP ranks."""
+
+    # The first PP rank schedules the requests, other ranks receive the schedule result from the previous PP rank.
+    if self.dist.is_first_pp_rank:
+        scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
+        )
+        serializable_schedule = SerializableSchedulerOutput.from_scheduler_result(
+            scheduled_batch, fitting_disagg_gen_init_requests,
+            num_fitting_reqs)
+    else:
+        with nvtx_range("recv_schedule_from_prev_pp"):
+            serializable_schedule = self.dist.recv_object(
+                self.dist.prev_pp_rank, PP_COMM_TAG_SCHEDULE_RESULT)
+        scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = serializable_schedule.to_scheduler_result(
+            self.active_requests)
+
+    # Propagate the schedule result to the next PP rank except the last PP rank.
+    if not self.dist.is_last_pp_rank:
+        if self.send_schedule_handler is not None:
+            with nvtx_range("wait_send_schedule_handler"):
+                self.send_schedule_handler.wait()
+        with nvtx_range("send_schedule_to_next_pp"):
+            self.send_schedule_handler = self.dist.isend_object(
+                serializable_schedule, self.dist.next_pp_rank,
+                PP_COMM_TAG_SCHEDULE_RESULT)
+    return scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs
+
+def _pp_retry_until_can_schedule(self, scheduled_batch):
+    """
+    If current rank cannot run the scheduled batch, it will retry following steps until it has enough KV cache resources or reach maximum retry count:
+    1. Wait for cache transceiver to finish at least one cache transmission.
+    2. Terminate requests that have finished context cache transmission.
+    3. Check if current rank has enough KV cache resources to run the scheduled batch.
+    """
+    scheduled_batch_requests = scheduled_batch.all_requests()
+    if self.scheduler.can_schedule(scheduled_batch_requests):
+        return
+
+    logger.warning(
+        "Cannot run first PP's schedule result due to limited KV cache resources. This may cause bubbles in the PP pipeline. Please consider increasing the KV cache size by setting `free_gpu_memory_fraction` to a larger value."
+    )
+    if self.kv_cache_transceiver is None:
+        raise RuntimeError(
+            "KV cache transceiver is not enabled, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected."
+        )
+    if not self.ctx_in_transmission_requests:
+        raise RuntimeError(
+            "No context cache transmission is in progress, but current rank cannot run first PP's schedule result due to limited KV cache resources. This is not expected."
+        )
+    if self.block_reuse_enabled and self._disagg_pp_termination_handler is not None:
+        raise RuntimeError(
+            "Cannot terminate requests in cache transmission and release their KV cache resources when block reuse is enabled. Please consider increasing the KV cache size."
+        )
+
+    for retry_count in range(self.pp_scheduler_max_retry_count):
+        if self.scheduler.can_schedule(scheduled_batch_requests):
+            break
+        logger.debug(
+            f"Retrying to run first PP's schedule result ({retry_count + 1}/{self.pp_scheduler_max_retry_count})"
+        )
+
+        # Let cache transceiver finish at least one cache transmission and release requests' KV cache resources
+        self._check_disagg_ctx_cache_transfer_status(1)
+        self._check_kv_transfer_timeout()
+        self._terminate_disagg_ctx_finished_requests()
+    else:
+        raise RuntimeError(
+            f"Reach maximum PP retry count ({self.pp_scheduler_max_retry_count}) but still cannot run first PP's schedule result. Please consider increasing the KV cache size by setting `free_gpu_memory_fraction` to a larger value. Or you can set `TLLM_PP_SCHEDULER_MAX_RETRY_COUNT` to a larger value to allow more retries."
+        )
+
 def _executor_loop_pp(self):
     logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
     torch.cuda.set_device(self.device_id)
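One construct in `_pp_retry_until_can_schedule` worth calling out is Python's `for ... else`: the `else` branch runs only when the loop completes without hitting `break`, i.e. only when every retry failed. A standalone sketch of the same retry shape (the `attempt_outcomes` list and the print messages are invented for the example; in the real method the check is `self.scheduler.can_schedule` after freeing KV cache via the cache transceiver):

```python
# Minimal sketch of the retry-with-for/else pattern used by _pp_retry_until_can_schedule.
# attempt_outcomes[i] stands in for "does the batch fit after the i-th round of freeing resources?".
def retry_until_can_schedule(attempt_outcomes, max_retry_count=10):
    for retry_count in range(max_retry_count):
        if attempt_outcomes[retry_count]:
            break  # enough KV cache now; the else branch below is skipped
        print(f"retry {retry_count + 1}/{max_retry_count}: not enough KV cache yet")
        # (the real code frees resources here by finishing cache transmissions)
    else:
        # Reached only when the loop ran to completion without break.
        raise RuntimeError("reached maximum retry count, still cannot schedule")


retry_until_can_schedule([False, False, True] + [False] * 7)  # succeeds on the third try
try:
    retry_until_can_schedule([False] * 10)  # never succeeds
except RuntimeError as err:
    print(err)
```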
@@ -799,6 +877,8 @@ class PyExecutor:
 profile_step()
 if self.enable_iter_perf_stats:
     iter_start_time = time.time()
+
+# Fetch new requests from request queue
 new_requests = self._fetch_and_activate_new_requests()
 if self.should_stop_processing:
     break
@@ -816,11 +896,18 @@ class PyExecutor:
 
 self._pad_attention_dp_dummy_request()
 
-scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
+# Stage 0: first PP rank schedules requests and propagates the result to all other PP ranks.
+scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._pp_schedule_and_propagate(
 )
+if not self.dist.is_first_pp_rank:
+    # Retry until current rank can run first PP's schedule result.
+    self._pp_retry_until_can_schedule(scheduled_batch)
+    # Run scheduler locally because scheduler may change llm requests' state.
+    self.scheduler.schedule_request(self.active_requests,
+                                    self.inflight_req_ids)
 
-# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
 if self.kv_cache_transceiver:
+    # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
     self._prepare_disagg_gen_init(
         fitting_disagg_gen_init_requests)
 
@@ -840,7 +927,6 @@ class PyExecutor:
 )
 
 can_queue = self._can_queue(scheduled_batch)
-
 if not can_queue:
     logger.debug(
         f"microbatch {microbatch_id} cannot be queued, skipping"
@@ -928,6 +1014,7 @@ class PyExecutor:
 prev_microbatch_id = (microbatch_id +
                       offset) % self.num_micro_batches
 previous_batch = self.micro_batches[prev_microbatch_id]
+tag = PP_COMM_TAG_SAMPLE_STATE_BASE + prev_microbatch_id
 if previous_batch is not None:
     sample_state = previous_batch.sample_state
     if not self.dist.is_last_pp_rank:
@@ -937,7 +1024,7 @@ class PyExecutor:
 with nvtx_range("recv_sample_state"):
     sample_state.host = recv_object_funct(
         src=self.dist.prev_pp_rank,
-        tag=prev_microbatch_id,
+        tag=tag,
     )
 
 # Send tokens to next pp rank (w.r.t model forward direction)
@@ -949,7 +1036,7 @@ class PyExecutor:
 prev_microbatch_id] = self.dist.isend_object(
     sample_state.host,
     dest=self.dist.next_pp_rank,
-    tag=prev_microbatch_id)
+    tag=tag)
 
 # Stage 3: Finalize previous batch that finished sample state communication
 # In last pp rank, stage 2 and 3 process different previous batches
@@ -1746,24 +1833,26 @@ class PyExecutor:
 
 def _waiting_requests(self, context_requests: list[LlmRequest],
                       generation_requests: list[LlmRequest]):
-    if not self.enable_batch_waiting:
-        return context_requests
+    """
+    Return an empty list if scheduled requests fulfill the waiting conditions, otherwise return the original context requests.
+    Waiting conditions:
+    - The number of scheduled tokens (both context and generation) is smaller than `self.batch_wait_max_tokens_ratio * self.max_num_tokens`
+    - The number of waiting iterations is smaller than `self.batch_wait_timeout_iters`.
+    """
 
-    waited_context_requests = []
-    stop_waiting = False
     num_scheduled_ctx_tokens = sum(
         len(ctx_req.get_tokens(0)) for ctx_req in context_requests)
     num_scheduled_gen_tokens = sum(1 + gen_req.num_draft_tokens
                                    for gen_req in generation_requests)
     num_scheduled_tokens = num_scheduled_ctx_tokens + num_scheduled_gen_tokens
 
-    stop_waiting = self.batch_wait_iters_count >= self.batch_wait_timeout_iters or num_scheduled_tokens >= self.batch_wait_max_tokens_ratio * self.max_num_tokens
-    if stop_waiting:
-        waited_context_requests = context_requests
-        self.batch_wait_iters_count = 0
-    else:
+    should_waiting = self.batch_wait_iters_count < self.batch_wait_timeout_iters and num_scheduled_tokens < self.batch_wait_max_tokens_ratio * self.max_num_tokens
+    if should_waiting:
         self.batch_wait_iters_count += 1
-    return waited_context_requests
+        return []
+
+    self.batch_wait_iters_count = 0
+    return context_requests
 
 @nvtx_range("_schedule")
 def _schedule(self):
@@ -1775,10 +1864,11 @@ class PyExecutor:
 scheduler_output.context_requests,
 scheduler_output.generation_requests)
 
-# if no generation requests, no need to wait, to avoid dead waiting
-if not self.enable_attention_dp and self.enable_batch_waiting and len(
-        scheduler_output.context_requests) > 0 and len(
-        scheduler_output.generation_requests) > 0:
+# If no generation requests, no need to wait, to avoid dead waiting
+should_check_waiting = not self.enable_attention_dp and self.enable_batch_waiting and len(
+    scheduler_output.context_requests) > 0 and len(
+        scheduler_output.generation_requests) > 0
+if should_check_waiting:
     scheduled_context_requests = self._waiting_requests(
         scheduler_output.context_requests,
         scheduler_output.generation_requests)
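A quick worked example of the batch-waiting rule used by `_waiting_requests` and gated by `should_check_waiting` above. All numbers below are hypothetical, chosen only to make the arithmetic concrete; `batch_wait_max_tokens_ratio`, `batch_wait_timeout_iters` and `max_num_tokens` are the attributes referenced in the diff, but the values here are not defaults.

```python
# Hypothetical settings for illustration only.
max_num_tokens = 8192
batch_wait_max_tokens_ratio = 0.75  # wait until the batch is at least 75% "full"...
batch_wait_timeout_iters = 4        # ...but never wait more than 4 iterations
token_threshold = batch_wait_max_tokens_ratio * max_num_tokens  # 6144.0

# Scheduled work this iteration: two context requests plus some generation requests.
context_prompt_lens = [1200, 900]   # token counts of the scheduled context requests
num_generation_requests = 16        # each contributes 1 token (+ draft tokens, 0 here)
num_scheduled_tokens = sum(context_prompt_lens) + num_generation_requests  # 2116

batch_wait_iters_count = 2          # we have already waited 2 iterations

should_wait = (batch_wait_iters_count < batch_wait_timeout_iters
               and num_scheduled_tokens < token_threshold)
print(should_wait)  # True: 2116 < 6144 and 2 < 4, so context requests are held back this step

# After two more iterations the timeout trips and the context requests are released
# even though the token threshold is still not met.
batch_wait_iters_count = 4
should_wait = (batch_wait_iters_count < batch_wait_timeout_iters
               and num_scheduled_tokens < token_threshold)
print(should_wait)  # False -> return the original context_requests and reset the counter
```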
@@ -2403,7 +2493,10 @@ class PyExecutor:
 
 @nvtx_range("_terminate_disagg_ctx_finished_requests")
 def _terminate_disagg_ctx_finished_requests(self):
-    for request_id in list(self.ctx_in_transmission_requests.keys()):
+    # make a copy of the keys, since we are modifying the dictionary in the loop
+    in_transmission_requests_id = list(
+        self.ctx_in_transmission_requests.keys())
+    for request_id in in_transmission_requests_id:
         request, block_id, counter = self.ctx_in_transmission_requests[
             request_id]
 
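The comment added above points at a classic Python pitfall: per that comment, the surrounding loop removes entries from `ctx_in_transmission_requests`, and iterating over a dict (or its live `.keys()` view) while its size changes raises `RuntimeError`. Snapshotting the keys into a list first, as the code now does explicitly via `in_transmission_requests_id`, keeps the iteration safe. A minimal reproduction with a toy dict (not the executor's real bookkeeping):

```python
# Toy stand-in for ctx_in_transmission_requests: request_id -> some bookkeeping value.
in_transmission = {1: "req-1", 2: "req-2", 3: "req-3"}

# Iterating over the live dict while deleting from it fails.
try:
    for request_id in in_transmission:
        del in_transmission[request_id]
except RuntimeError as err:
    print(err)  # "dictionary changed size during iteration"

# Snapshot the keys first; then it is safe to delete while looping.
in_transmission = {1: "req-1", 2: "req-2", 3: "req-3"}
for request_id in list(in_transmission.keys()):
    del in_transmission[request_id]
print(in_transmission)  # {}
```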
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from collections import namedtuple
+from dataclasses import dataclass
 from typing import Optional, Tuple
 
 from strenum import StrEnum
@@ -54,6 +55,70 @@ class RequestScheduler(ABC):
         # to be aligned with RequestScheduler::scheduleRequests in cpp/tensorrt_llm/batch_manager/requestScheduler.h
         raise NotImplementedError
 
+    @abstractmethod
+    def can_schedule(self, requests: RequestList) -> bool:
+        """
+        Check if current rank can schedule the requests.
+        :param requests: list of requests to be scheduled
+        :return: True if current rank can schedule the requests, False otherwise
+        """
+        raise NotImplementedError
+
+
+@dataclass
+class SerializableSchedulerOutput:
+    """
+    Serializable version of SchedulerOutput, used for sending schedule result to other ranks. Need this class because LlmRequest is not serializable by pickle.
+    """
+    context_requests: list[int]  # request ids of context requests
+    generation_requests: list[int]  # request ids of generation requests
+    paused_requests: list[int]  # request ids of paused requests
+    fitting_disagg_gen_init_requests: list[
+        int]  # request ids of fitting disaggregated generation initialization requests
+    num_fitting_requests: int  # number of fitting requests
+
+    @classmethod
+    def from_scheduler_result(
+            cls, scheduled_requests: ScheduledRequests,
+            fitting_disagg_gen_init_requests: RequestList,
+            num_fitting_requests: int) -> "SerializableSchedulerOutput":
+        return cls(context_requests=[
+            req.request_id for req in scheduled_requests.context_requests
+        ],
+                   generation_requests=[
+                       req.request_id
+                       for req in scheduled_requests.generation_requests
+                   ],
+                   paused_requests=[
+                       req.request_id
+                       for req in scheduled_requests.paused_requests
+                   ],
+                   fitting_disagg_gen_init_requests=[
+                       req.request_id
+                       for req in fitting_disagg_gen_init_requests
+                   ],
+                   num_fitting_requests=num_fitting_requests)
+
+    def to_scheduler_result(
+            self, active_requests: RequestList
+    ) -> Tuple[ScheduledRequests, RequestList, int]:
+        id_to_request = {req.request_id: req for req in active_requests}
+        scheduled_requests = ScheduledRequests()
+        scheduled_requests.context_requests = [
+            id_to_request[req_id] for req_id in self.context_requests
+        ]
+        scheduled_requests.generation_requests = [
+            id_to_request[req_id] for req_id in self.generation_requests
+        ]
+        scheduled_requests.paused_requests = [
+            id_to_request[req_id] for req_id in self.paused_requests
+        ]
+        fitting_disagg_gen_init_requests = [
+            id_to_request[req_id]
+            for req_id in self.fitting_disagg_gen_init_requests
+        ]
+        return scheduled_requests, fitting_disagg_gen_init_requests, self.num_fitting_requests
+
 
 class CapacityScheduler(ABC):
 
@@ -216,3 +281,8 @@ class SimpleScheduler(RequestScheduler):
                                list(generation_requests), list(paused_requests),
                                list(fitting_disagg_gen_init_requests),
                                len(fitting_requests))
+
+    def can_schedule(self, requests: RequestList) -> bool:
+        fitting_requests, _, _ = self.capacity_scheduler.schedule_request(
+            requests)
+        return len(fitting_requests) == len(requests)
@@ -20,6 +20,7 @@ l0_a10:
 - unittest/_torch/modeling/test_modeling_mistral.py
 - unittest/_torch/modeling/test_modeling_pixtral.py
 - unittest/_torch/sampler/test_trtllm_sampler.py
+- unittest/_torch/executor/test_scheduler_serializable_output.py
 # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
 # test list either).
 - unittest/_torch/models/checkpoints/hf/test_weight_loader.py
@@ -0,0 +1,59 @@
+import pickle
+
+from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, SamplingConfig
+from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests, SerializableSchedulerOutput
+
+
+def _make_request(request_id: int) -> LlmRequest:
+    return LlmRequest(
+        request_id=request_id,
+        max_new_tokens=5,
+        input_tokens=[request_id],
+        sampling_config=SamplingConfig(),
+        is_streaming=False,
+    )
+
+
+def _request_ids(requests):
+    return [req.request_id for req in requests]
+
+
+def test_serializable_scheduler_output_round_trip():
+    # Create all requests and put them in a pool
+    request_pool = {idx: _make_request(idx) for idx in range(1, 8)}
+
+    # Create scheduler result: scheduled_requests, fitting_disagg_gen_init_requests, num_fitting_requests
+    scheduled_requests = ScheduledRequests()
+    scheduled_requests.context_requests = [request_pool[1], request_pool[2]]
+    scheduled_requests.generation_requests = [request_pool[3]]
+    scheduled_requests.paused_requests = [request_pool[4]]
+    fitting_disagg_gen_init_requests = [request_pool[5], request_pool[6]]
+    num_fitting_requests = 3
+
+    # Create serializable scheduler output from scheduler result
+    serializable_output = SerializableSchedulerOutput.from_scheduler_result(
+        scheduled_requests, fitting_disagg_gen_init_requests, num_fitting_requests
+    )
+
+    # Serialize and deserialize the serializable scheduler output
+    serialized_bytes = pickle.dumps(serializable_output)
+    restored_output: SerializableSchedulerOutput = pickle.loads(serialized_bytes)
+
+    # Restore the scheduler result from the deserialized serializable scheduler output
+    active_requests = list(request_pool.values())
+    restored_schedule, restored_fitting, restored_num_fitting = restored_output.to_scheduler_result(
+        active_requests
+    )
+
+    # Verify the restored scheduler result is correct
+    assert restored_num_fitting == num_fitting_requests
+    assert _request_ids(restored_schedule.context_requests) == _request_ids(
+        scheduled_requests.context_requests
+    )
+    assert _request_ids(restored_schedule.generation_requests) == _request_ids(
+        scheduled_requests.generation_requests
+    )
+    assert _request_ids(restored_schedule.paused_requests) == _request_ids(
+        scheduled_requests.paused_requests
+    )
+    assert _request_ids(restored_fitting) == _request_ids(fitting_disagg_gen_init_requests)