[None][feat] Add support of scheduling attention dp request (#6246)

Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Authored by Shunkangz on 2025-08-02 08:38:01 +08:00; committed by GitHub.
parent 31802de0b0
commit 67a3fd858b
8 changed files with 701 additions and 161 deletions

View File

@ -87,27 +87,68 @@ class ExecutorRequestQueue:
self,
waiting_queue: deque[RequestQueueItem],
max_req_count: int,
enable_attention_dp: bool,
all_ranks_num_active_requests: Optional[List[int]] = None,
) -> List[RequestQueueItem]:
"""Safely extracts up to max_req_count items from a deque.
"""
Args:
waiting_queue: The queue to pop items from.
max_req_count: Maximum items to retrieve. Returns empty list if <=0.
enable_attention_dp: Whether to enable attention DP scheduling.
all_ranks_num_active_requests: Number of active requests for each rank.
Returns:
List of retrieved items (may be shorter than max_req_count if queue empties first).
List of requests that can be processed.
"""
# Edge case handling
if max_req_count <= 0:
return []
req_count = 0
items = []
pending_requests = []
# Track per-rank capacity for requests with strict (non-relaxed) rank requirements
scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
) if enable_attention_dp else None
while req_count < max_req_count and waiting_queue:
req_item = waiting_queue.popleft()
can_process = self._can_process_attention_dp_request(
req_item, scheduling_all_ranks_num_active_requests
) if enable_attention_dp else True
if can_process:
items.append(req_item)
req_count += 1
else:
pending_requests.append(req_item)
# Put the pending requests back into the waiting queue;
# all ranks must keep an identical waiting queue
waiting_queue.extendleft(reversed(pending_requests))
return items
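A minimal, self-contained sketch of the pop-and-requeue pattern used above, assuming a plain collections.deque and a simple predicate (the hypothetical take_schedulable and can_process below stand in for the real method and the attention-DP admission check). Deferred items go back to the front in their original order via extendleft(reversed(...)).

from collections import deque

def take_schedulable(waiting, max_count, can_process):
    """Pop up to max_count items that pass can_process; requeue the rest in order."""
    taken, deferred = [], []
    while len(taken) < max_count and waiting:
        item = waiting.popleft()
        (taken if can_process(item) else deferred).append(item)
    # Reinsert deferred items at the front, preserving their original order
    waiting.extendleft(reversed(deferred))
    return taken

q = deque(range(6))
print(take_schedulable(q, 3, lambda x: x % 2 == 0))  # [0, 2, 4]
print(list(q))                                       # [1, 3, 5]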
def _can_process_attention_dp_request(
self, req_item: RequestQueueItem,
all_ranks_num_active_requests: List[int]) -> bool:
"""Return True if the request can be processed immediately, else False."""
scheduling_params = getattr(req_item.request, 'py_scheduling_params',
None)
if scheduling_params is None:
return True
target_dp_rank = scheduling_params.attention_dp_rank
if target_dp_rank is None or scheduling_params.attention_dp_relax:
return True
if all_ranks_num_active_requests[
target_dp_rank] < self.max_num_active_requests:
all_ranks_num_active_requests[target_dp_rank] += 1
return True
return False
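The admission gate above can be exercised in isolation. The sketch below (a hypothetical helper, not part of the diff) mirrors its three outcomes: no scheduling params or relax mode always passes, a strict target passes only while its rank has spare capacity, and each strict admission reserves a slot in the per-rank counter.

def can_admit(target_rank, relax, active_per_rank, max_active):
    """Mirror of the strict-target admission check (illustrative only)."""
    if target_rank is None or relax:
        return True
    if active_per_rank[target_rank] < max_active:
        active_per_rank[target_rank] += 1  # reserve a slot for this strict request
        return True
    return False

active = [2, 1, 3, 0]
print(can_admit(None, False, active, 3))  # True: no target rank
print(can_admit(1, True, active, 3))      # True: relaxed request
print(can_admit(2, False, active, 3))     # False: rank 2 is already at capacity
print(can_admit(3, False, active, 3))     # True, and active becomes [2, 1, 3, 1]
print(active)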
def enqueue_requests(self, requests: List[ExecutorRequest]):
req_ids = []
try:
@ -166,8 +207,12 @@ class ExecutorRequestQueue:
return can_enqueue and self.dist.rank == 0
def _fetch_and_process_requests(
self,
total_num_active_requests: int,
total_max_num_active_requests: int,
enable_attention_dp: bool,
all_ranks_num_active_requests: Optional[List[int]] = None
) -> List[RequestQueueItem]:
"""Common logic for fetching and processing requests from the queue."""
# Calculate timeout
timeout = None if (total_num_active_requests == 0) and len(
@ -195,7 +240,8 @@ class ExecutorRequestQueue:
new_requests = self._get_from_waiting_queue(
self.waiting_queue,
total_max_num_active_requests - total_num_active_requests,
enable_attention_dp, all_ranks_num_active_requests)
# Update performance metrics
if self.enable_iter_perf_stats and self.dist.rank == 0:
@ -218,9 +264,11 @@ class ExecutorRequestQueue:
total_num_active_requests = num_active_requests
total_max_num_active_requests = self.max_num_active_requests
# Fetch and process requests into the waiting queue
new_requests = self._fetch_and_process_requests(
total_num_active_requests,
total_max_num_active_requests,
enable_attention_dp=False)
# Merge requests and add to active list
merged_requests = self._merge_requests(new_requests)
@ -238,20 +286,17 @@ class ExecutorRequestQueue:
total_num_active_requests = sum(all_ranks_num_active_requests)
total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
# Fetch and process requests into the waiting queue
new_requests = self._fetch_and_process_requests(
total_num_active_requests,
total_max_num_active_requests,
enable_attention_dp=True,
all_ranks_num_active_requests=all_ranks_num_active_requests)
# Schedule attention dp requests
all_ranks_new_requests = self._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]
# Update performance metrics
if self.enable_iter_perf_stats and self.start_times:
@ -259,13 +304,66 @@ class ExecutorRequestQueue:
new_requests_cur_rank)
# Update counters
self.num_fetch_requests += len(new_requests)
self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)
# Merge requests and add to active list
new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
return new_requests_cur_rank
def _schedule_attention_dp_requests(
self, new_requests: List[RequestQueueItem],
all_ranks_num_active_requests: List[int]) -> Dict[int, List[RequestQueueItem]]:
"""Schedule attention dp requests."""
# Map from ranks to new requests
all_ranks_new_requests = {
tp_rank: []
for tp_rank in range(self.dist.tp_size)
}
# Prioritize the requests that are not in relax mode
def get_relax_value(req_item):
scheduling_params = getattr(req_item.request,
'py_scheduling_params', None)
if scheduling_params is None:
return True
return scheduling_params.attention_dp_relax
new_requests = sorted(new_requests, key=get_relax_value, reverse=True)
# Try to place each request on its target DP rank until that rank reaches max_num_active_requests
remaining_unscheduled = []
for req_item in new_requests:
scheduled = False
scheduling_params = getattr(req_item.request,
'py_scheduling_params', None)
if scheduling_params is not None:
target_dp_rank = scheduling_params.attention_dp_rank
if target_dp_rank is not None and all_ranks_num_active_requests[
target_dp_rank] < self.max_num_active_requests:
all_ranks_num_active_requests[target_dp_rank] += 1
scheduled = True
all_ranks_new_requests[target_dp_rank].append(req_item)
if not scheduled:
remaining_unscheduled.append(req_item)
# Balance the remaining unscheduled requests across ranks
num_new_requests_all_ranks = len(remaining_unscheduled)
total_num_active_requests = sum(all_ranks_num_active_requests)
self.expected_num_active_requests = max(
(total_num_active_requests + num_new_requests_all_ranks +
self.dist.tp_size - 1) // self.dist.tp_size,
max(all_ranks_num_active_requests),
)
all_ranks_new_requests = self._balance_requests_across_ranks(
remaining_unscheduled, all_ranks_new_requests,
all_ranks_num_active_requests)
return all_ranks_new_requests
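A condensed sketch of the two-phase policy implemented above, assuming plain (req_id, target_rank) tuples instead of RequestQueueItem and ignoring the relax-priority sort: phase one pins requests to their requested rank while that rank has capacity, phase two hands everything else to a load balancer (here a trivial fill-the-least-loaded loop standing in for _balance_requests_across_ranks). The helper name schedule_dp is hypothetical.

def schedule_dp(requests, active, max_active):
    """requests: list of (req_id, target_rank or None); returns rank -> [req_id]."""
    placed = {rank: [] for rank in range(len(active))}
    leftover = []
    # Phase 1: honor explicit target ranks while capacity remains
    for req_id, target in requests:
        if target is not None and active[target] < max_active:
            active[target] += 1
            placed[target].append(req_id)
        else:
            leftover.append(req_id)
    # Phase 2: spread the rest onto the currently least-loaded ranks
    for req_id in leftover:
        rank = min(range(len(active)), key=lambda r: active[r])
        active[rank] += 1
        placed[rank].append(req_id)
    return placed

print(schedule_dp([(0, 1), (1, None), (2, 1)], active=[2, 1, 3, 0], max_active=3))
# {0: [], 1: [0, 2], 2: [], 3: [1]}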
def _handle_request_broadcasting(self,
new_requests: List[RequestQueueItem]):
"""Handle broadcasting of requests and Python objects across ranks."""
@ -274,8 +372,13 @@ class ExecutorRequestQueue:
new_requests, "py_logits_post_processors")
py_multimodal_data = self._collect_py_objects_from_requests(
new_requests, "py_multimodal_data")
py_scheduling_params = self._collect_py_objects_from_requests(
new_requests, "py_scheduling_params")
py_request_objects = tuple(
filter(None, [
py_logits_post_processors, py_multimodal_data,
py_scheduling_params
]))
else:
py_request_objects = None
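The tuple(filter(None, [...])) idiom above simply drops the attribute groups that no request carried. A tiny illustration, assuming (hypothetically) that each collected group is an (attr_name, {req_id: value}) pair or None when the attribute was absent:

py_logits_post_processors = None
py_multimodal_data = ("py_multimodal_data", {7: {"image": "..."}})
py_scheduling_params = ("py_scheduling_params", {7: {"attention_dp_rank": 1}})

py_request_objects = tuple(
    filter(None, [py_logits_post_processors, py_multimodal_data, py_scheduling_params]))
print(py_request_objects)  # only the two non-empty groups survive and get broadcast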
@ -314,28 +417,30 @@ class ExecutorRequestQueue:
def _balance_requests_across_ranks(
self, new_requests: List[RequestQueueItem],
all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
all_ranks_num_active_requests: List[int]) -> Dict[int, List[RequestQueueItem]]:
"""Balance requests across ranks for attention DP."""
if new_requests:
# Balance context tokens across ranks using heap
HeapVal = namedtuple(
'HeapVal',
['num_tokens', 'num_requests', 'rank', 'request_list'])
all_ranks_new_requests_heap = [
HeapVal(0, val, tp_rank, [])
for tp_rank, val in enumerate(all_ranks_num_active_requests)
]
all_ranks_new_requests_heap = [
val for val in all_ranks_new_requests_heap
if val.num_requests < self.expected_num_active_requests
]
all_ranks_new_scheduled_requests = {
val.rank: val.request_list
for val in all_ranks_new_requests_heap
}
heapq.heapify(all_ranks_new_requests_heap)
# Sort by token count (descending) for better load balancing
@ -351,17 +456,22 @@ class ExecutorRequestQueue:
token_count = len(
getattr(req_item.request, 'input_token_ids',
[])) if req_item.request else 0
# Update the heap value with the new request
val = val._replace(
num_tokens=val.num_tokens + token_count,
num_requests=val.num_requests + 1,
)
val.request_list.append(req_item)
# If rank still has room for new requests, push back into heap
if val.num_requests < self.expected_num_active_requests:
heapq.heappush(all_ranks_new_requests_heap, val)
# Extend all_ranks_new_requests with the new requests that have been scheduled
for rank, reqs in all_ranks_new_scheduled_requests.items():
all_ranks_new_requests[rank].extend(reqs)
return all_ranks_new_requests
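The heap keeps the rank with the fewest accumulated context tokens on top. A small self-contained sketch of that token-balancing loop, assuming requests are given as (req_id, num_tokens) pairs; balance_by_tokens is a hypothetical, simplified stand-in that only returns ranks that still had room.

import heapq
from collections import namedtuple

HeapVal = namedtuple('HeapVal', ['num_tokens', 'num_requests', 'rank', 'request_list'])

def balance_by_tokens(requests, num_ranks, active_per_rank, cap):
    """requests: list of (req_id, num_tokens); returns rank -> [req_id] for ranks with room."""
    heap = [HeapVal(0, active_per_rank[r], r, []) for r in range(num_ranks)
            if active_per_rank[r] < cap]
    placed = {val.rank: val.request_list for val in heap}
    heapq.heapify(heap)
    # Largest requests first so the heap can even out token totals
    for req_id, num_tokens in sorted(requests, key=lambda x: x[1], reverse=True):
        if not heap:
            break  # every remaining rank is at capacity
        val = heapq.heappop(heap)
        val = val._replace(num_tokens=val.num_tokens + num_tokens,
                           num_requests=val.num_requests + 1)
        val.request_list.append(req_id)
        if val.num_requests < cap:
            heapq.heappush(heap, val)
    return placed

print(balance_by_tokens([(0, 500), (1, 300), (2, 200)], 4, [2, 1, 3, 0], cap=3))
# {0: [2], 1: [1], 3: [0]}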
def _collect_py_objects_from_requests(
self, requests: List[RequestQueueItem],

View File

@ -29,6 +29,7 @@ from ..llmapi.utils import (AsyncQueue, enable_llm_debug,
print_colored_debug)
from ..sampling_params import (BatchedLogitsProcessor, LogprobParams,
SamplingParams)
from ..scheduling_params import SchedulingParams
from .ipc import FusedIpcQueue
from .postproc_worker import PostprocParams, PostprocWorkerConfig
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
@ -120,6 +121,7 @@ class GenerationExecutor(ABC):
disaggregated_params: Optional[DisaggregatedParams] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
) -> GenerationResult:
"""Generate output for the given prompt token ids in the asynchronous mode.
Asynchronous generation accepts single prompt only.
@ -142,7 +144,8 @@ class GenerationExecutor(ABC):
streaming=streaming,
kv_cache_retention_config=kv_cache_retention_config,
disaggregated_params=disaggregated_params,
multimodal_params=multimodal_params)
multimodal_params=multimodal_params,
scheduling_params=scheduling_params)
result = self.submit(request)
# release memory in time
if hasattr(request, "multimodal_params"):

View File

@ -10,6 +10,7 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
from .postproc_worker import PostprocParams
__all__ = [
@ -95,6 +96,7 @@ class GenerationRequest:
disaggregated_params: Optional[DisaggregatedParams] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
):
if isinstance(prompt_token_ids, list):
self.prompt_token_ids = prompt_token_ids
@ -119,6 +121,7 @@ class GenerationRequest:
self.kv_cache_retention_config = kv_cache_retention_config
self.id: Optional[int] = None
self.disaggregated_params = disaggregated_params
self.scheduling_params = scheduling_params
def set_id(self, id):
assert self.id is None, f"Request ID is already set: {self.id}"

View File

@ -520,6 +520,10 @@ class GenerationExecutorWorker(GenerationExecutor):
executor_request.py_logits_post_processors = lp if isinstance(
lp, list) else [lp]
executor_request.py_scheduling_params = None
if self._is_pytorch_backend and request.scheduling_params is not None:
executor_request.py_scheduling_params = request.scheduling_params
if request.query_token_ids is not None:
# pytorch star attention workflow
# a workaround to avoid public interface update
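On the PyTorch backend the worker simply forwards the user-supplied object as a py_-prefixed attribute on the executor request, like the other Python-only side channels. A minimal illustration of that pattern, with SimpleNamespace objects standing in (hypothetically) for the real request and executor-request objects:

from types import SimpleNamespace

executor_request = SimpleNamespace(py_scheduling_params=None)  # stand-in for the binding object
request = SimpleNamespace(scheduling_params=SimpleNamespace(attention_dp_rank=2,
                                                            attention_dp_relax=True))
is_pytorch_backend = True

# Default to None, then attach only when the backend supports it and the user set it
if is_pytorch_backend and request.scheduling_params is not None:
    executor_request.py_scheduling_params = request.scheduling_params

print(executor_request.py_scheduling_params.attention_dp_rank)  # 2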

View File

@ -30,6 +30,7 @@ from ..inputs import (PromptInputs, create_input_processor,
create_input_processor_with_hash, prompt_inputs)
from ..logger import logger
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
@ -236,6 +237,8 @@ class BaseLLM:
KvCacheRetentionConfig, Sequence[KvCacheRetentionConfig]]] = None,
disaggregated_params: Optional[Union[
DisaggregatedParams, Sequence[DisaggregatedParams]]] = None,
scheduling_params: Optional[Union[SchedulingParams,
List[SchedulingParams]]] = None,
) -> Union[RequestOutput, List[RequestOutput]]:
"""Generate output for the given prompts in the synchronous mode.
Synchronous generation accepts either single prompt or batched prompts.
@ -254,6 +257,8 @@ class BaseLLM:
Configuration for the request's retention in the KV Cache. Defaults to None.
disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, Sequence[tensorrt_llm.disaggregated_params.DisaggregatedParams], optional):
Disaggregated parameters. Defaults to None.
scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, List[tensorrt_llm.scheduling_params.SchedulingParams], optional):
Scheduling parameters. Defaults to None.
Returns:
Union[tensorrt_llm.llmapi.RequestOutput, List[tensorrt_llm.llmapi.RequestOutput]]: The output data of the completion request to the LLM.
"""
@ -283,6 +288,7 @@ class BaseLLM:
kv_cache_retention_config=_item_at(kv_cache_retention_config,
i),
disaggregated_params=_item_at(disaggregated_params, i),
scheduling_params=_item_at(scheduling_params, i),
streaming=False)
futures.append(future)
@ -308,6 +314,7 @@ class BaseLLM:
kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
disaggregated_params: Optional[DisaggregatedParams] = None,
_postproc_params: Optional[PostprocParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
) -> RequestOutput:
"""Generate output for the given prompt in the asynchronous mode.
Asynchronous generation accepts single prompt only.
@ -321,6 +328,7 @@ class BaseLLM:
streaming (bool): Whether to use the streaming mode for the generation. Defaults to False.
kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, optional): Configuration for the request's retention in the KV Cache. Defaults to None.
disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Disaggregated parameters. Defaults to None.
scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, optional): Scheduling parameters. Defaults to None.
Returns:
tensorrt_llm.llmapi.RequestOutput: The output data of the completion request to the LLM.
@ -426,6 +434,7 @@ class BaseLLM:
disaggregated_params=disaggregated_params,
postproc_params=_postproc_params,
multimodal_params=multimodal_params,
scheduling_params=scheduling_params,
)
return RequestOutput._from_generation_result(result, prompt,
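End to end, the new parameter is passed per prompt. A hedged usage sketch, assuming TensorRT-LLM built with this change, the import path shown in this diff, and a placeholder model path; attention DP itself must also be enabled in the LLM configuration, which is only hinted at here.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.scheduling_params import SchedulingParams

# Placeholder model path; attention DP assumes a multi-rank, attention-DP-enabled setup.
llm = LLM(model="/path/to/model", tensor_parallel_size=4, enable_attention_dp=True)
prompts = ["Hello, my name is", "The capital of France is"]

# Pin the first prompt to attention DP rank 0; let the second float for throughput.
scheduling = [
    SchedulingParams(attention_dp_rank=0, attention_dp_relax=False),
    SchedulingParams(attention_dp_rank=1, attention_dp_relax=True),
]

outputs = llm.generate(prompts,
                       SamplingParams(max_tokens=32),
                       scheduling_params=scheduling)
for out in outputs:
    print(out.outputs[0].text)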

View File

@ -0,0 +1,15 @@
from dataclasses import dataclass
from typing import Optional
@dataclass(slots=True, kw_only=True)
class SchedulingParams:
"""Schedule parameters.
Args:
attention_dp_rank (int): The rank of target attention dp
attention_dp_relax (bool): Whether to allow the request to be scheduled to other attention dp for better throughput
"""
attention_dp_rank: Optional[int] = None
attention_dp_relax: Optional[bool] = None
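For reference, constructing the dataclass (import path as introduced in this diff): a strict request (attention_dp_relax=False) is only admitted when its target rank has a free slot, while a relaxed one may be re-balanced onto any rank, and omitting the rank leaves placement entirely to the scheduler.

from tensorrt_llm.scheduling_params import SchedulingParams

strict = SchedulingParams(attention_dp_rank=0, attention_dp_relax=False)
relaxed = SchedulingParams(attention_dp_rank=0, attention_dp_relax=True)
floating = SchedulingParams()  # no target rank: scheduler places it freely

print(strict, relaxed, floating)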

View File

@ -196,7 +196,7 @@ def test_get_from_waiting_queue(executor_queue):
# Get 3 items
result = executor_queue._get_from_waiting_queue(
executor_queue.waiting_queue, 3, enable_attention_dp=False)
assert len(result) == 3
assert result == items[:3]
@ -221,7 +221,7 @@ def test_get_from_waiting_queue_edge_cases(executor_queue, queue_size,
executor_queue.waiting_queue.extend(items)
result = executor_queue._get_from_waiting_queue(
executor_queue.waiting_queue, request_count, enable_attention_dp=False)
assert len(result) == expected_result
assert len(executor_queue.waiting_queue) == expected_remaining
@ -316,141 +316,530 @@ def test_clear_canceled_req_ids(executor_queue):
assert len(executor_queue.canceled_req_ids) == 0
def test_thread_safety(executor_queue):
"""Test thread safety of enqueue operations."""
results = []
errors = []
def enqueue_worker():
try:
for i in range(10):
req_id = executor_queue.enqueue_request(Mock())
results.append(req_id)
except Exception as e:
errors.append(e)
# Create multiple threads
threads = []
for _ in range(3):
thread = threading.Thread(target=enqueue_worker)
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join()
# Check results
assert len(errors) == 0
assert len(results) == 30
assert len(set(results)) == 30 # All IDs should be unique
@pytest.fixture
def mock_dist_attention_dp():
"""Create a mock Distributed instance for testing."""
mock_dist = Mock()
mock_dist.rank = 0
mock_dist.tp_size = 4
mock_dist.pp_size = 1
mock_dist.has_pp = False
mock_dist.tp_rank = 0
mock_dist.cp_rank = 0
mock_dist.cp_size = 1
mock_dist.cp_config = {}
mock_dist.is_first_pp_rank = True
mock_dist.is_last_pp_rank = True
mock_dist.next_pp_rank = 1
mock_dist.prev_pp_rank = 0
mock_dist.broadcast = Mock(return_value=([], None))
return mock_dist
@patch('tensorrt_llm._torch.pyexecutor.executor_request_queue.time.time')
def test_update_new_active_requests_queue_latency(mock_time, executor_queue):
"""Test updating queue latency metrics."""
mock_time.return_value = 1000.0
# Set up start times
executor_queue.start_times = {1: 998.0, 2: 999.0}
requests = [RequestQueueItem(1, Mock()), RequestQueueItem(2, Mock())]
executor_queue._update_new_active_requests_queue_latency(requests)
# Check latency was updated (1000.0 - 998.0) + (1000.0 - 999.0) = 3.0
assert executor_queue.new_active_requests_queue_latency_ms == 3.0
# Check start times were removed
assert len(executor_queue.start_times) == 0
@pytest.fixture
def attention_dp_queue(mock_dist_attention_dp):
"""Create an ExecutorRequestQueue instance for attention DP testing."""
queue = ExecutorRequestQueue(dist=mock_dist_attention_dp,
enable_attention_dp=True,
max_batch_size=4,
max_beam_width=2,
max_num_active_requests=8,
enable_iter_perf_stats=True,
is_disaggregated=False)
# Initialize all_ranks_num_active_requests
return queue
@pytest.mark.parametrize("enable_attention_dp", [False, True])
def test_fetch_new_requests_routing(executor_queue, enable_attention_dp):
"""Test that fetch_new_requests routes correctly based on attention_dp setting."""
mock_active_requests = []
executor_queue.enable_attention_dp = enable_attention_dp
@pytest.fixture
def all_ranks_num_active_requests():
return [2, 1, 3, 0] # 4 ranks
if enable_attention_dp:
with patch.object(executor_queue,
'_fetch_new_requests_attention_dp') as mock_dp:
mock_dp.return_value = []
executor_queue.fetch_new_requests(len(mock_active_requests))
mock_dp.assert_called_once_with(len(mock_active_requests))
def create_mock_request_with_py_schedule_params(attention_dp_rank=None,
attention_dp_relax=False):
mock_request = Mock()
if attention_dp_rank is not None:
mock_schedule_params = Mock()
mock_schedule_params.attention_dp_rank = attention_dp_rank
mock_schedule_params.attention_dp_relax = attention_dp_relax
mock_schedule_params.configure_mock(
attention_dp_rank=attention_dp_rank,
attention_dp_relax=attention_dp_relax)
mock_request.py_scheduling_params = mock_schedule_params
else:
with patch.object(executor_queue,
'_fetch_new_requests_attention_tp') as mock_tp:
mock_tp.return_value = []
executor_queue.fetch_new_requests(len(mock_active_requests))
mock_tp.assert_called_once_with(len(mock_active_requests))
mock_request.py_scheduling_params = None
mock_request.input_token_ids = [1, 2, 3]
return mock_request
# Integration tests
def test_full_workflow(integration_queue):
"""Test a complete workflow from enqueue to processing."""
# Enqueue some requests - create mocks without sampling_config to avoid beam validation
mock_requests = []
for _ in range(3):
mock_req = Mock()
delattr(mock_req, 'sampling_config') if hasattr(
mock_req, 'sampling_config') else None
mock_requests.append(mock_req)
req_ids = integration_queue.enqueue_requests(mock_requests) # type: ignore
# Unit tests for _schedule_attention_dp_requests
def test_schedule_attention_dp_requests_scheduled_requests(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
# Enqueue a cancel request
integration_queue.enqueue_cancel_request(req_ids[1])
new_requests = [req1, req2]
# Simulate fetching from request queue
items = []
while not integration_queue.request_queue.empty():
try:
items.append(integration_queue.request_queue.get_nowait())
except queue.Empty:
break
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]
assert len(items) == 4 # 3 requests + 1 cancel
assert len(result) == 2
assert req1 in result
assert req2 in result
# Filter and validate
valid_items = integration_queue._validate_and_filter_requests(items)
assert len(valid_items) == 3
assert req_ids[1] in integration_queue.canceled_req_ids
assert all_ranks_num_active_requests[0] == 4
@patch(
'tensorrt_llm._torch.pyexecutor.executor_request_queue.executor_request_to_llm_request'
)
def test_merge_requests_with_beam_validation(mock_convert, integration_queue):
"""Test request merging with beam width validation."""
# Create mock requests with different beam widths
mock_req1 = Mock()
mock_req1.sampling_config = Mock()
mock_req1.sampling_config.beam_width = 2 # Matches max_beam_width
def test_schedule_attention_dp_requests_scheduled_requests_other_ranks(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=False))
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=2,
attention_dp_relax=False))
mock_req2 = Mock()
mock_req2.sampling_config = Mock()
mock_req2.sampling_config.beam_width = 3 # Doesn't match max_beam_width
new_requests = [req1, req2]
requests = [RequestQueueItem(1, mock_req1), RequestQueueItem(2, mock_req2)]
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
# First request should pass validation
valid_requests = integration_queue._validate_and_filter_requests(
[requests[0]])
assert len(valid_requests) == 1
result = all_ranks_new_requests[0]
assert len(result) == 0
# Second request should fail validation
with pytest.raises(AssertionError):
integration_queue._validate_and_filter_requests([requests[1]])
assert all_ranks_num_active_requests[1] == 2
assert all_ranks_num_active_requests[2] == 4
def test_beam_width_validation_success(integration_queue):
"""Test that beam width validation passes for correct beam width."""
mock_req = Mock()
mock_req.sampling_config = Mock()
mock_req.sampling_config.beam_width = 2 # Matches integration test max_beam_width
def test_schedule_attention_dp_requests_unscheduled_requests(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=True))
request = RequestQueueItem(1, mock_req)
valid_requests = integration_queue._validate_and_filter_requests([request])
new_requests = [req1, req2]
assert len(valid_requests) == 1
assert valid_requests[0] == request
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]
assert len(result) == 1 # Only req1 for current rank
assert req1 in result
def test_schedule_attention_dp_requests_unscheduled_no_capacity(
attention_dp_queue, all_ranks_num_active_requests):
all_ranks_num_active_requests[0] = 8
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
new_requests = [req1]
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]
assert len(result) == 0 # No capacity
def test_schedule_attention_dp_requests_mixed_scenarios(
attention_dp_queue, all_ranks_num_active_requests):
req_scheduled_current = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req_scheduled_other = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=False))
req_unscheduled_current = RequestQueueItem(
3,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
req_unscheduled_other = RequestQueueItem(
4,
create_mock_request_with_py_schedule_params(attention_dp_rank=2,
attention_dp_relax=True))
new_requests = [
req_scheduled_current, req_scheduled_other, req_unscheduled_current,
req_unscheduled_other
]
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]
assert len(result) == 2
assert req_scheduled_current in result
assert req_unscheduled_current in result
def test_schedule_attention_dp_requests_empty_lists(
attention_dp_queue, all_ranks_num_active_requests):
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
[], all_ranks_num_active_requests)
result = all_ranks_new_requests[0]
assert len(result) == 0
def test_schedule_attention_dp_requests_expected_num_active_calculation(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=True))
new_requests = [req1, req2]
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
all_ranks_new_requests[0]
# 2 + 1 + 3 + 0 = 6, 6 + 2 = 8, (8 + 3) // 4 = 2, max(2, 2, 1, 3, 0) = 3
assert attention_dp_queue.expected_num_active_requests == 3
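The comment's arithmetic in the test above follows the ceil-division formula used by the scheduler; spelled out under the same assumptions (tp_size=4, two new relaxed requests, per-rank actives [2, 1, 3, 0]):

tp_size = 4
all_ranks_num_active_requests = [2, 1, 3, 0]
num_new_requests = 2

total_active = sum(all_ranks_num_active_requests)                        # 6
ceil_share = (total_active + num_new_requests + tp_size - 1) // tp_size  # (6 + 2 + 3) // 4 = 2
expected = max(ceil_share, max(all_ranks_num_active_requests))           # max(2, 3) = 3
assert expected == 3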
def test_schedule_attention_dp_requests_balance_requests_called(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
new_requests = [req1]
with patch.object(attention_dp_queue,
'_balance_requests_across_ranks') as mock_balance:
mock_balance.return_value = {0: req1}
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
all_ranks_new_requests[0]
# Check that _balance_requests_across_ranks was called
mock_balance.assert_called_once()
call_args = mock_balance.call_args[0]
assert isinstance(call_args[0], list)
assert isinstance(call_args[1], dict)
assert call_args[2] == all_ranks_num_active_requests # Third arg
def test_schedule_attention_dp_requests_no_scheduling_when_capacity_exceeded(
attention_dp_queue, all_ranks_num_active_requests):
all_ranks_num_active_requests[0] = 8
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
new_requests = [req1]
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]
assert len(result) == 0 # No requests scheduled
assert all_ranks_num_active_requests[0] == 8 # Capacity unchanged
# Integration tests combining both methods
def test_filter_and_schedule_integration(attention_dp_queue,
all_ranks_num_active_requests):
req_schedulable = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req_schedulable.request.input_token_ids = [1, 2, 3, 4]
req_relax = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
req_relax.request.input_token_ids = [1, 2]
req_no_params = RequestQueueItem(
3, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
new_requests = [req_schedulable, req_relax, req_no_params]
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]
assert len(result) == 2
assert req_schedulable in result
assert req_relax in result
def test_filter_and_schedule_with_capacity_limits(
attention_dp_queue, all_ranks_num_active_requests):
all_ranks_num_active_requests[0] = 7
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req1.request.input_token_ids = [1, 2, 3, 4]
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req2.request.input_token_ids = [1, 2, 3]
new_requests = [req1, req2]
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]
assert len(result) == 1
assert req1 in result
def test_get_from_waiting_queue_with_attention_dp(
attention_dp_queue, all_ranks_num_active_requests):
items = [RequestQueueItem(i, Mock()) for i in range(5)]
attention_dp_queue.waiting_queue.extend(items)
result = attention_dp_queue._get_from_waiting_queue(
attention_dp_queue.waiting_queue, 3, True,
all_ranks_num_active_requests)
assert len(result) == 3
assert result == items[:3]
assert len(attention_dp_queue.waiting_queue) == 2
def test_get_from_waiting_queue_with_attention_dp_filtering(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=True))
req3 = RequestQueueItem(3,
create_mock_request_with_py_schedule_params(
attention_dp_rank=None)) # No scheduling params
attention_dp_queue.waiting_queue.extend([req1, req2, req3])
# Set rank 0 to full capacity to test filtering
all_ranks_num_active_requests[0] = 8
result = attention_dp_queue._get_from_waiting_queue(
attention_dp_queue.waiting_queue, 3, True,
all_ranks_num_active_requests)
assert len(result) == 2
assert req2 in result
assert req3 in result
assert req1 not in result
def test_can_process_attention_dp_request(attention_dp_queue):
req_no_params = RequestQueueItem(1, Mock())
assert attention_dp_queue._can_process_attention_dp_request(
req_no_params, [0, 0, 0, 0]) == True
req_relax = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
assert attention_dp_queue._can_process_attention_dp_request(
req_relax, [0, 0, 0, 0]) == True
req_target = RequestQueueItem(
3,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=False))
all_ranks = [0, 0, 0, 0]
assert attention_dp_queue._can_process_attention_dp_request(
req_target, all_ranks) == True
assert all_ranks[1] == 1
req_no_capacity = RequestQueueItem(
4,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
all_ranks_full = [8, 0, 0, 0] # Rank 0 is at capacity
assert attention_dp_queue._can_process_attention_dp_request(
req_no_capacity, all_ranks_full) == False
def test_achieve_max_num_active_requests(attention_dp_queue):
req_list = []
req_id = 0
for rank in range(4):
for i in range(5):
req_list.append(
RequestQueueItem(
req_id,
create_mock_request_with_py_schedule_params(
attention_dp_rank=rank, attention_dp_relax=False)))
req_id += 1
req_list.append(
RequestQueueItem(
req_id,
create_mock_request_with_py_schedule_params(
attention_dp_rank=rank, attention_dp_relax=True)))
req_id += 1
all_ranks_num_active_requests = [5, 6, 3, 7]
attention_dp_queue.waiting_queue.extend(req_list)
available_active_requests = attention_dp_queue.max_num_active_requests * 4 - sum(
all_ranks_num_active_requests)
result = attention_dp_queue._get_from_waiting_queue(
attention_dp_queue.waiting_queue, available_active_requests, True,
all_ranks_num_active_requests)
assert len(result) == available_active_requests
def append_to_waiting_queue(waiting_queue, rank, attention_dp_relax):
req_id = len(waiting_queue)
waiting_queue.append(
RequestQueueItem(
req_id,
create_mock_request_with_py_schedule_params(
attention_dp_rank=rank, attention_dp_relax=attention_dp_relax)))
@pytest.mark.parametrize(
"max_num_active_requests,all_ranks_num_active_requests,request_configs,all_ranks_expected_req_ids",
[
# Case: Balanced distribution of relaxed requests
(3, [0, 0, 0, 0], [(None, True)] * 7, {
0: [0, 4],
1: [1, 5],
2: [2, 6],
3: [3]
}),
# Case: Relaxed requests balanced against uneven pre-existing load
(3, [1, 2, 3, 0], [(None, True)] * 13, {
0: [1, 4],
1: [2],
2: [],
3: [0, 3, 5]
}),
# Case: Limited by max active
(3, [0, 0, 0, 0], [(None, True)] * 13, {
0: [0, 4, 8],
1: [1, 5, 9],
2: [2, 6, 10],
3: [3, 7, 11]
}),
# Case: Empty new requests
(3, [3, 3, 3, 0], [], {
0: [],
1: [],
2: [],
3: []
}),
# Case: Rank 0 is full, so the strict request targeting it cannot be scheduled
(3, [3, 1, 3, 0], [(0, False), (0, True)], {
0: [],
1: [],
2: [],
3: [1]
}),
# Case: Only room for 1 request; req0 with a strict attention dp rank targeting a full rank is skipped
(3, [3, 2, 3, 3], [(0, False), (0, True)], {
0: [],
1: [1],
2: [],
3: []
}),
# Case: Targeting ranks 1 and 3 that have room
(3, [2, 1, 3, 0], [(1, False), (3, False)], {
0: [],
1: [0],
2: [],
3: [1]
}),
# Case: Target dp rank specified, but relax is True
(3, [3, 3, 3, 1], [(0, True), (1, True), (2, True)], {
0: [],
1: [],
2: [],
3: [0, 1]
}),
# Case: Strict request to a full rank stays queued; targeted and relaxed requests land on rank 3
(3, [3, 3, 3, 0], [(0, False), (1, True), (3, False)], {
0: [],
1: [],
2: [],
3: [2, 1]
}),
])
def test_attention_dp_scheduling_cases(attention_dp_queue,
max_num_active_requests,
all_ranks_num_active_requests,
request_configs,
all_ranks_expected_req_ids):
"""Test attention DP scheduling with various scenarios."""
attention_dp_queue.max_num_active_requests = max_num_active_requests
waiting_queue = deque()
for rank, relax in request_configs:
append_to_waiting_queue(waiting_queue, rank, relax)
run_test_attention_dp_scheduling(attention_dp_queue, waiting_queue,
all_ranks_num_active_requests,
all_ranks_expected_req_ids)
def run_test_attention_dp_scheduling(attention_dp_queue, waiting_queue,
all_ranks_num_active_requests,
all_ranks_expected_req_ids):
num_ranks = len(all_ranks_num_active_requests)
total_num_active_requests = sum(all_ranks_num_active_requests)
total_max_num_active_requests = attention_dp_queue.max_num_active_requests * num_ranks
enable_attention_dp = True
new_requests = attention_dp_queue._get_from_waiting_queue(
waiting_queue,
total_max_num_active_requests - total_num_active_requests,
enable_attention_dp, all_ranks_num_active_requests)
# Schedule attention dp requests
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
assert len(all_ranks_new_requests) == num_ranks
print("all_ranks_new_requests:", all_ranks_new_requests)
for rank, reqs in all_ranks_new_requests.items():
req_ids = [req.id for req in reqs]
assert req_ids == all_ranks_expected_req_ids[rank]
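Working through one of the parametrized cases by hand helps read the expected tables: with max_num_active_requests=3, per-rank actives [3, 2, 3, 3] and requests [(0, strict), (0, relaxed)], only one global slot is free, the strict request targeting full rank 0 stays in the waiting queue, and the relaxed one lands on rank 1, the only rank below the expected load. A sketch of that bookkeeping, assuming plain tuples:

max_active, tp_size = 3, 4
active = [3, 2, 3, 3]

free_slots = max_active * tp_size - sum(active)         # 12 - 11 = 1: fetch at most one request
# The strict request targets full rank 0, so it is deferred; the relaxed one is fetched instead.
fetched = [("req1", 0, True)]                           # (id, target_rank, relax)

expected_active = max((sum(active) + len(fetched) + tp_size - 1) // tp_size,
                      max(active))                      # max(3, 3) = 3
ranks_with_room = [r for r in range(tp_size) if active[r] < expected_active]
print(free_slots, expected_active, ranks_with_room)     # 1 3 [1]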

View File

@ -156,6 +156,9 @@ methods:
kv_cache_retention_config:
annotation: Union[tensorrt_llm.bindings.executor.KvCacheRetentionConfig, Sequence[tensorrt_llm.bindings.executor.KvCacheRetentionConfig], NoneType]
default: null
scheduling_params:
annotation: Union[tensorrt_llm.scheduling_params.SchedulingParams, List[tensorrt_llm.scheduling_params.SchedulingParams], NoneType]
default: null
return_annotation: Union[tensorrt_llm.llmapi.llm.RequestOutput, List[tensorrt_llm.llmapi.llm.RequestOutput]]
generate_async:
parameters:
@ -165,6 +168,10 @@ methods:
kv_cache_retention_config:
annotation: Optional[tensorrt_llm.bindings.executor.KvCacheRetentionConfig]
default: null
scheduling_params:
annotation: Optional[tensorrt_llm.scheduling_params.SchedulingParams]
default: null
status: prototype
return_annotation: tensorrt_llm.llmapi.llm.RequestOutput
get_kv_cache_events:
parameters: