Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[None][feat] Add support of scheduling attention dp request (#6246)
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.co>
Co-authored-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
This commit is contained in: parent 31802de0b0, commit 67a3fd858b
@@ -87,27 +87,68 @@ class ExecutorRequestQueue:
self,
waiting_queue: deque[RequestQueueItem],
max_req_count: int,
enable_attention_dp: bool,
all_ranks_num_active_requests: Optional[List[int]] = None,
) -> List[RequestQueueItem]:
"""Safely extracts up to max_req_count items from a deque.

"""
Args:
waiting_queue: The queue to pop items from.
max_req_count: Maximum items to retrieve. Returns empty list if <=0.

enable_attention_dp: Whether to enable attention DP scheduling.
all_ranks_num_active_requests: Number of active requests for each rank.
Returns:
List of retrieved items (may be shorter than max_req_count if queue empties first).
List of requests that can be processed.
"""
# Edge case handling
if max_req_count <= 0:  # Handles negative/zero counts

if max_req_count <= 0:
return []

items = []
req_count = 0
items = []
pending_requests = []

# Track the request with strict requirements
scheduling_all_ranks_num_active_requests = all_ranks_num_active_requests.copy(
) if enable_attention_dp else None
while req_count < max_req_count and waiting_queue:
items.append(waiting_queue.popleft())
req_count += 1
req_item = waiting_queue.popleft()
can_process = self._can_process_attention_dp_request(
req_item, scheduling_all_ranks_num_active_requests
) if enable_attention_dp else True

if can_process:
items.append(req_item)
req_count += 1
else:
pending_requests.append(req_item)

# Put the pending requests back to the waiting queue
# All ranks should have the same waiting queue
waiting_queue.extendleft(reversed(pending_requests))

return items

def _can_process_attention_dp_request(
self, req_item: RequestQueueItem,
all_ranks_num_active_requests: List[int]) -> bool:
"""Return True if the request can be processed immediately, else False."""

scheduling_params = getattr(req_item.request, 'py_scheduling_params',
None)
if scheduling_params is None:
return True

target_dp_rank = scheduling_params.attention_dp_rank
if target_dp_rank is None or scheduling_params.attention_dp_relax:
return True

if all_ranks_num_active_requests[
target_dp_rank] < self.max_num_active_requests:
all_ranks_num_active_requests[target_dp_rank] += 1
return True

return False

def enqueue_requests(self, requests: List[ExecutorRequest]):
req_ids = []
try:
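As a review aid, here is a small, self-contained sketch of the admission rule that `_can_process_attention_dp_request` implements in the hunk above. `FakeParams` and `can_admit` are illustrative names only, not code from this change.

```python
from dataclasses import dataclass
from typing import List, Optional


# Illustrative stand-in for py_scheduling_params; not the real class.
@dataclass
class FakeParams:
    attention_dp_rank: Optional[int] = None
    attention_dp_relax: bool = False


def can_admit(params: Optional[FakeParams],
              active_per_rank: List[int],
              max_active: int) -> bool:
    """Mirror of the admission rule: requests without a pinned rank (or with
    relax enabled) are always admitted; pinned requests are admitted only if
    their target rank still has capacity, in which case the tally is bumped."""
    if params is None or params.attention_dp_rank is None or params.attention_dp_relax:
        return True
    rank = params.attention_dp_rank
    if active_per_rank[rank] < max_active:
        active_per_rank[rank] += 1
        return True
    return False


# Example: rank 0 is full (8/8), so a pinned, non-relaxed request is deferred,
# while one targeting rank 1 is admitted and the tally becomes [8, 3, 0, 5].
active = [8, 2, 0, 5]
assert can_admit(FakeParams(attention_dp_rank=0), active, 8) is False
assert can_admit(FakeParams(attention_dp_rank=1), active, 8) is True
assert active == [8, 3, 0, 5]
```

Deferred requests are pushed back onto the shared waiting queue, so every rank sees the same ordering on the next fetch.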
@@ -166,8 +207,12 @@ class ExecutorRequestQueue:
return can_enqueue and self.dist.rank == 0

def _fetch_and_process_requests(
self, total_num_active_requests: int,
total_max_num_active_requests: int) -> List[RequestQueueItem]:
self,
total_num_active_requests: int,
total_max_num_active_requests: int,
enable_attention_dp: bool,
all_ranks_num_active_requests: Optional[List[int]] = None
) -> List[RequestQueueItem]:
"""Common logic for fetching and processing requests from the queue."""
# Calculate timeout
timeout = None if (total_num_active_requests == 0) and len(
@@ -195,7 +240,8 @@ class ExecutorRequestQueue:

new_requests = self._get_from_waiting_queue(
self.waiting_queue,
total_max_num_active_requests - total_num_active_requests)
total_max_num_active_requests - total_num_active_requests,
enable_attention_dp, all_ranks_num_active_requests)

# Update performance metrics
if self.enable_iter_perf_stats and self.dist.rank == 0:
@@ -218,9 +264,11 @@ class ExecutorRequestQueue:
total_num_active_requests = num_active_requests
total_max_num_active_requests = self.max_num_active_requests

# Use common request fetching logic
# fetch and process requests into waiting queue
new_requests = self._fetch_and_process_requests(
total_num_active_requests, total_max_num_active_requests)
total_num_active_requests,
total_max_num_active_requests,
enable_attention_dp=False)

# Merge requests and add to active list
merged_requests = self._merge_requests(new_requests)
@@ -238,20 +286,17 @@ class ExecutorRequestQueue:
total_num_active_requests = sum(all_ranks_num_active_requests)
total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests

# Use common request fetching logic
# fetch and process requests into waiting queue
new_requests = self._fetch_and_process_requests(
total_num_active_requests, total_max_num_active_requests)
total_num_active_requests,
total_max_num_active_requests,
enable_attention_dp=True,
all_ranks_num_active_requests=all_ranks_num_active_requests)

# Balance requests across ranks
num_new_requests_all_ranks = len(new_requests)
self.expected_num_active_requests = max(
(total_num_active_requests + num_new_requests_all_ranks +
self.dist.tp_size - 1) // self.dist.tp_size,
max(all_ranks_num_active_requests),
)

new_requests_cur_rank = self._balance_requests_across_ranks(
# Schedule attention dp requests
all_ranks_new_requests = self._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]

# Update performance metrics
if self.enable_iter_perf_stats and self.start_times:
@@ -259,13 +304,66 @@ class ExecutorRequestQueue:
new_requests_cur_rank)

# Update counters
self.num_fetch_requests += num_new_requests_all_ranks
self.num_fetch_requests += len(new_requests)
self.num_fetch_requests_cur_rank += len(new_requests_cur_rank)

# Merge requests and add to active list
new_requests_cur_rank = self._merge_requests(new_requests_cur_rank)
return new_requests_cur_rank

def _schedule_attention_dp_requests(
self, new_requests: List[RequestQueueItem],
all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
"""Schedule attention dp requests."""

# Map from ranks to new requests
all_ranks_new_requests = {
tp_rank: []
for tp_rank in range(self.dist.tp_size)
}

# Prioritize the requests that are not in relax mode
def get_relax_value(req_item):
scheduling_params = getattr(req_item.request,
'py_scheduling_params', None)
if scheduling_params is None:
return True
return scheduling_params.attention_dp_relax

new_requests = sorted(new_requests, key=get_relax_value, reverse=True)

# Try to put the requests to the target dp rank until the max_num_active_requests is reached
remaining_unscheduled = []
for req_item in new_requests:
scheduled = False
scheduling_params = getattr(req_item.request,
'py_scheduling_params', None)
if scheduling_params is not None:
target_dp_rank = scheduling_params.attention_dp_rank
if target_dp_rank is not None and all_ranks_num_active_requests[
target_dp_rank] < self.max_num_active_requests:
all_ranks_num_active_requests[target_dp_rank] += 1
scheduled = True
all_ranks_new_requests[target_dp_rank].append(req_item)

if not scheduled:
remaining_unscheduled.append(req_item)

# Balance the remaining unscheduled requests across ranks
num_new_requests_all_ranks = len(remaining_unscheduled)
total_num_active_requests = sum(all_ranks_num_active_requests)
self.expected_num_active_requests = max(
(total_num_active_requests + num_new_requests_all_ranks +
self.dist.tp_size - 1) // self.dist.tp_size,
max(all_ranks_num_active_requests),
)

all_ranks_new_requests = self._balance_requests_across_ranks(
remaining_unscheduled, all_ranks_new_requests,
all_ranks_num_active_requests)

return all_ranks_new_requests
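To make the two-phase flow above easier to follow, here is a rough, standalone sketch under simplified assumptions: requests are plain tuples and `schedule_sketch` is a made-up helper, not the real `_schedule_attention_dp_requests`. Phase one pins requests to their requested DP rank while capacity remains; the leftovers, together with the ceiling `expected_num_active_requests`, are what `_balance_requests_across_ranks` then distributes.

```python
import math
from typing import Dict, List, Optional, Tuple

# Illustrative request shape: (request_id, target_rank_or_None, relax_flag).
Req = Tuple[int, Optional[int], bool]


def schedule_sketch(new_requests: List[Req],
                    active_per_rank: List[int],
                    max_active: int) -> Tuple[Dict[int, List[Req]], List[Req], int]:
    """Phase 1: pin requests that name a target DP rank while that rank still
    has capacity. Everything else is left for the load balancer, along with
    the ceiling used as expected_num_active_requests."""
    tp_size = len(active_per_rank)
    pinned: Dict[int, List[Req]] = {r: [] for r in range(tp_size)}
    remaining: List[Req] = []
    for req in new_requests:
        _, target, _ = req
        if target is not None and active_per_rank[target] < max_active:
            active_per_rank[target] += 1
            pinned[target].append(req)
        else:
            remaining.append(req)
    total_active = sum(active_per_rank)
    expected = max(math.ceil((total_active + len(remaining)) / tp_size),
                   max(active_per_rank))
    return pinned, remaining, expected


# Example: two requests pinned to rank 0, one relaxed request left to balance.
pinned, remaining, expected = schedule_sketch(
    [(1, 0, False), (2, 0, False), (3, None, True)], [2, 1, 3, 0], 8)
assert [r[0] for r in pinned[0]] == [1, 2] and len(remaining) == 1
```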
def _handle_request_broadcasting(self,
new_requests: List[RequestQueueItem]):
"""Handle broadcasting of requests and Python objects across ranks."""
@@ -274,8 +372,13 @@ class ExecutorRequestQueue:
new_requests, "py_logits_post_processors")
py_multimodal_data = self._collect_py_objects_from_requests(
new_requests, "py_multimodal_data")
py_scheduling_params = self._collect_py_objects_from_requests(
new_requests, "py_scheduling_params")
py_request_objects = tuple(
filter(None, [py_logits_post_processors, py_multimodal_data]))
filter(None, [
py_logits_post_processors, py_multimodal_data,
py_scheduling_params
]))
else:
py_request_objects = None
@@ -314,28 +417,30 @@ class ExecutorRequestQueue:

def _balance_requests_across_ranks(
self, new_requests: List[RequestQueueItem],
all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
"""Balance requests across ranks for attention DP."""
new_requests_cur_rank = []

if new_requests and self.expected_num_active_requests > all_ranks_num_active_requests[
self.dist.tp_rank]:
if new_requests:
# Balance context tokens across ranks using heap
HeapVal = namedtuple(
'HeapVal',
['num_tokens', 'num_requests', 'rank', 'request_list'])

all_ranks_new_requests_heap = [
HeapVal(0, self.expected_num_active_requests - val, tp_rank, [])
HeapVal(0, val, tp_rank, [])
for tp_rank, val in enumerate(all_ranks_num_active_requests)
]

new_requests_cur_rank = all_ranks_new_requests_heap[
self.dist.tp_rank].request_list
all_ranks_new_requests_heap = [
val for val in all_ranks_new_requests_heap
if val.num_requests > 0
if val.num_requests < self.expected_num_active_requests
]

all_ranks_new_scheduled_requests = {
val.rank: val.request_list
for val in all_ranks_new_requests_heap
}

heapq.heapify(all_ranks_new_requests_heap)

# Sort by token count (descending) for better load balancing
@@ -351,17 +456,22 @@ class ExecutorRequestQueue:
token_count = len(
getattr(req_item.request, 'input_token_ids',
[])) if req_item.request else 0
# Update the heap value with the new request
val = val._replace(
num_tokens=val.num_tokens + token_count,
num_requests=val.num_requests - 1,
num_requests=val.num_requests + 1,
)
val.request_list.append(req_item)
if val.num_requests > 0:
heapq.heappush(all_ranks_new_requests_heap, val)
elif val.rank == self.dist.tp_rank:
break

return new_requests_cur_rank
val.request_list.append(req_item)
# If rank still has room for new requests, push back into heap
if val.num_requests < self.expected_num_active_requests:
heapq.heappush(all_ranks_new_requests_heap, val)

# Extend all_ranks_new_requests with the new requests that have been scheduled
for rank, reqs in all_ranks_new_scheduled_requests.items():
all_ranks_new_requests[rank].extend(reqs)

return all_ranks_new_requests

def _collect_py_objects_from_requests(
self, requests: List[RequestQueueItem],
@@ -29,6 +29,7 @@ from ..llmapi.utils import (AsyncQueue, enable_llm_debug,
print_colored_debug)
from ..sampling_params import (BatchedLogitsProcessor, LogprobParams,
SamplingParams)
from ..scheduling_params import SchedulingParams
from .ipc import FusedIpcQueue
from .postproc_worker import PostprocParams, PostprocWorkerConfig
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
@@ -120,6 +121,7 @@ class GenerationExecutor(ABC):
disaggregated_params: Optional[DisaggregatedParams] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
) -> GenerationResult:
"""Generate output for the given prompt token ids in the asynchronous mode.
Asynchronous generation accepts single prompt only.
@@ -142,7 +144,8 @@ class GenerationExecutor(ABC):
streaming=streaming,
kv_cache_retention_config=kv_cache_retention_config,
disaggregated_params=disaggregated_params,
multimodal_params=multimodal_params)
multimodal_params=multimodal_params,
scheduling_params=scheduling_params)
result = self.submit(request)
# release memory in time
if hasattr(request, "multimodal_params"):
@@ -10,6 +10,7 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
from .postproc_worker import PostprocParams

__all__ = [
@@ -95,6 +96,7 @@ class GenerationRequest:
disaggregated_params: Optional[DisaggregatedParams] = None,
postproc_params: Optional[PostprocParams] = None,
multimodal_params: Optional[MultimodalParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
):
if isinstance(prompt_token_ids, list):
self.prompt_token_ids = prompt_token_ids
@@ -119,6 +121,7 @@ class GenerationRequest:
self.kv_cache_retention_config = kv_cache_retention_config
self.id: Optional[int] = None
self.disaggregated_params = disaggregated_params
self.scheduling_params = scheduling_params

def set_id(self, id):
assert self.id is None, f"Request ID is already set: {self.id}"
@@ -520,6 +520,10 @@ class GenerationExecutorWorker(GenerationExecutor):
executor_request.py_logits_post_processors = lp if isinstance(
lp, list) else [lp]

executor_request.py_scheduling_params = None
if self._is_pytorch_backend and request.scheduling_params is not None:
executor_request.py_scheduling_params = request.scheduling_params

if request.query_token_ids is not None:
# pytorch star attention workflow
# a workaround to avoid public interface update
@@ -30,6 +30,7 @@ from ..inputs import (PromptInputs, create_input_processor,
create_input_processor_with_hash, prompt_inputs)
from ..logger import logger
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
@@ -236,6 +237,8 @@ class BaseLLM:
KvCacheRetentionConfig, Sequence[KvCacheRetentionConfig]]] = None,
disaggregated_params: Optional[Union[
DisaggregatedParams, Sequence[DisaggregatedParams]]] = None,
scheduling_params: Optional[Union[SchedulingParams,
List[SchedulingParams]]] = None,
) -> Union[RequestOutput, List[RequestOutput]]:
"""Generate output for the given prompts in the synchronous mode.
Synchronous generation accepts either single prompt or batched prompts.
@@ -254,6 +257,8 @@ class BaseLLM:
Configuration for the request's retention in the KV Cache. Defaults to None.
disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, Sequence[tensorrt_llm.disaggregated_params.DisaggregatedParams], optional):
Disaggregated parameters. Defaults to None.
scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, List[tensorrt_llm.scheduling_params.SchedulingParams], optional):
Scheduling parameters. Defaults to None.
Returns:
Union[tensorrt_llm.llmapi.RequestOutput, List[tensorrt_llm.llmapi.RequestOutput]]: The output data of the completion request to the LLM.
"""
@@ -283,6 +288,7 @@ class BaseLLM:
kv_cache_retention_config=_item_at(kv_cache_retention_config,
i),
disaggregated_params=_item_at(disaggregated_params, i),
scheduling_params=_item_at(scheduling_params, i),
streaming=False)
futures.append(future)

@@ -308,6 +314,7 @@ class BaseLLM:
kv_cache_retention_config: Optional[KvCacheRetentionConfig] = None,
disaggregated_params: Optional[DisaggregatedParams] = None,
_postproc_params: Optional[PostprocParams] = None,
scheduling_params: Optional[SchedulingParams] = None,
) -> RequestOutput:
"""Generate output for the given prompt in the asynchronous mode.
Asynchronous generation accepts single prompt only.
@@ -321,6 +328,7 @@ class BaseLLM:
streaming (bool): Whether to use the streaming mode for the generation. Defaults to False.
kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, optional): Configuration for the request's retention in the KV Cache. Defaults to None.
disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Disaggregated parameters. Defaults to None.
scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, optional): Scheduling parameters. Defaults to None.

Returns:
tensorrt_llm.llmapi.RequestOutput: The output data of the completion request to the LLM.
@@ -426,6 +434,7 @@ class BaseLLM:
disaggregated_params=disaggregated_params,
postproc_params=_postproc_params,
multimodal_params=multimodal_params,
scheduling_params=scheduling_params,
)

return RequestOutput._from_generation_result(result, prompt,
tensorrt_llm/scheduling_params.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from dataclasses import dataclass
from typing import Optional


@dataclass(slots=True, kw_only=True)
class SchedulingParams:
    """Scheduling parameters.

    Args:
        attention_dp_rank (int): The target attention DP rank
        attention_dp_relax (bool): Whether to allow the request to be scheduled to another attention DP rank for better throughput
    """

    attention_dp_rank: Optional[int] = None
    attention_dp_relax: Optional[bool] = None
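A minimal usage sketch of the new `SchedulingParams` knob, assuming the PyTorch backend with attention DP enabled on a multi-rank deployment; the model path is a placeholder and the output handling follows the usual LLM API quickstart pattern.

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.scheduling_params import SchedulingParams

# Placeholder model path; attention DP scheduling only takes effect on the
# PyTorch backend with attention DP enabled.
llm = LLM(model="/path/to/model")

output = llm.generate(
    "Hello, my name is",
    sampling_params=SamplingParams(max_tokens=16),
    # Prefer attention DP rank 0, but allow rescheduling to another rank if
    # rank 0 is at capacity (attention_dp_relax=True).
    scheduling_params=SchedulingParams(attention_dp_rank=0,
                                       attention_dp_relax=True),
)
print(output.outputs[0].text)
```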
@@ -196,7 +196,7 @@ def test_get_from_waiting_queue(executor_queue):

# Get 3 items
result = executor_queue._get_from_waiting_queue(
executor_queue.waiting_queue, 3)
executor_queue.waiting_queue, 3, enable_attention_dp=False)

assert len(result) == 3
assert result == items[:3]
@@ -221,7 +221,7 @@ def test_get_from_waiting_queue_edge_cases(executor_queue, queue_size,
executor_queue.waiting_queue.extend(items)

result = executor_queue._get_from_waiting_queue(
executor_queue.waiting_queue, request_count)
executor_queue.waiting_queue, request_count, enable_attention_dp=False)

assert len(result) == expected_result
assert len(executor_queue.waiting_queue) == expected_remaining
@@ -316,141 +316,530 @@ def test_clear_canceled_req_ids(executor_queue):
assert len(executor_queue.canceled_req_ids) == 0

def test_thread_safety(executor_queue):
"""Test thread safety of enqueue operations."""
results = []
errors = []

def enqueue_worker():
try:
for i in range(10):
req_id = executor_queue.enqueue_request(Mock())
results.append(req_id)
except Exception as e:
errors.append(e)

# Create multiple threads
threads = []
for _ in range(3):
thread = threading.Thread(target=enqueue_worker)
threads.append(thread)
thread.start()

# Wait for all threads to complete
for thread in threads:
thread.join()

# Check results
assert len(errors) == 0
assert len(results) == 30
assert len(set(results)) == 30  # All IDs should be unique
@pytest.fixture
def mock_dist_attention_dp():
"""Create a mock Distributed instance for testing."""
mock_dist = Mock()
mock_dist.rank = 0
mock_dist.tp_size = 4
mock_dist.pp_size = 1
mock_dist.has_pp = False
mock_dist.tp_rank = 0
mock_dist.cp_rank = 0
mock_dist.cp_size = 1
mock_dist.cp_config = {}
mock_dist.is_first_pp_rank = True
mock_dist.is_last_pp_rank = True
mock_dist.next_pp_rank = 1
mock_dist.prev_pp_rank = 0
mock_dist.broadcast = Mock(return_value=([], None))
return mock_dist

@patch('tensorrt_llm._torch.pyexecutor.executor_request_queue.time.time')
def test_update_new_active_requests_queue_latency(mock_time, executor_queue):
"""Test updating queue latency metrics."""
mock_time.return_value = 1000.0

# Set up start times
executor_queue.start_times = {1: 998.0, 2: 999.0}

requests = [RequestQueueItem(1, Mock()), RequestQueueItem(2, Mock())]

executor_queue._update_new_active_requests_queue_latency(requests)

# Check latency was updated (1000.0 - 998.0) + (1000.0 - 999.0) = 3.0
assert executor_queue.new_active_requests_queue_latency_ms == 3.0

# Check start times were removed
assert len(executor_queue.start_times) == 0
@pytest.fixture
def attention_dp_queue(mock_dist_attention_dp):
"""Create an ExecutorRequestQueue instance for attention DP testing."""
queue = ExecutorRequestQueue(dist=mock_dist_attention_dp,
enable_attention_dp=True,
max_batch_size=4,
max_beam_width=2,
max_num_active_requests=8,
enable_iter_perf_stats=True,
is_disaggregated=False)
# Initialize all_ranks_num_active_requests
return queue

@pytest.mark.parametrize("enable_attention_dp", [False, True])
def test_fetch_new_requests_routing(executor_queue, enable_attention_dp):
"""Test that fetch_new_requests routes correctly based on attention_dp setting."""
mock_active_requests = []
executor_queue.enable_attention_dp = enable_attention_dp
@pytest.fixture
def all_ranks_num_active_requests():
return [2, 1, 3, 0]  # 4 ranks

if enable_attention_dp:
with patch.object(executor_queue,
'_fetch_new_requests_attention_dp') as mock_dp:
mock_dp.return_value = []
executor_queue.fetch_new_requests(len(mock_active_requests))
mock_dp.assert_called_once_with(len(mock_active_requests))

def create_mock_request_with_py_schedule_params(attention_dp_rank=None,
attention_dp_relax=False):
mock_request = Mock()

if attention_dp_rank is not None:
mock_schedule_params = Mock()
mock_schedule_params.attention_dp_rank = attention_dp_rank
mock_schedule_params.attention_dp_relax = attention_dp_relax

mock_schedule_params.configure_mock(
attention_dp_rank=attention_dp_rank,
attention_dp_relax=attention_dp_relax)

mock_request.py_scheduling_params = mock_schedule_params
else:
with patch.object(executor_queue,
'_fetch_new_requests_attention_tp') as mock_tp:
mock_tp.return_value = []
executor_queue.fetch_new_requests(len(mock_active_requests))
mock_tp.assert_called_once_with(len(mock_active_requests))
mock_request.py_scheduling_params = None

mock_request.input_token_ids = [1, 2, 3]

return mock_request
# Integration tests
def test_full_workflow(integration_queue):
"""Test a complete workflow from enqueue to processing."""
# Enqueue some requests - create mocks without sampling_config to avoid beam validation
mock_requests = []
for _ in range(3):
mock_req = Mock()
delattr(mock_req, 'sampling_config') if hasattr(
mock_req, 'sampling_config') else None
mock_requests.append(mock_req)
req_ids = integration_queue.enqueue_requests(mock_requests)  # type: ignore
# Unit tests for _schedule_attention_dp_requests
def test_schedule_attention_dp_requests_scheduled_requests(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))

# Enqueue a cancel request
integration_queue.enqueue_cancel_request(req_ids[1])
new_requests = [req1, req2]

# Simulate fetching from request queue
items = []
while not integration_queue.request_queue.empty():
try:
items.append(integration_queue.request_queue.get_nowait())
except queue.Empty:
break
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]

assert len(items) == 4  # 3 requests + 1 cancel
assert len(result) == 2
assert req1 in result
assert req2 in result

# Filter and validate
valid_items = integration_queue._validate_and_filter_requests(items)

assert len(valid_items) == 3
assert req_ids[1] in integration_queue.canceled_req_ids
assert all_ranks_num_active_requests[0] == 4

@patch(
'tensorrt_llm._torch.pyexecutor.executor_request_queue.executor_request_to_llm_request'
)
def test_merge_requests_with_beam_validation(mock_convert, integration_queue):
"""Test request merging with beam width validation."""
# Create mock requests with different beam widths
mock_req1 = Mock()
mock_req1.sampling_config = Mock()
mock_req1.sampling_config.beam_width = 2  # Matches max_beam_width
def test_schedule_attention_dp_requests_scheduled_requests_other_ranks(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=False))
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=2,
attention_dp_relax=False))

mock_req2 = Mock()
mock_req2.sampling_config = Mock()
mock_req2.sampling_config.beam_width = 3  # Doesn't match max_beam_width
new_requests = [req1, req2]

requests = [RequestQueueItem(1, mock_req1), RequestQueueItem(2, mock_req2)]
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)

# First request should pass validation
valid_requests = integration_queue._validate_and_filter_requests(
[requests[0]])
assert len(valid_requests) == 1
result = all_ranks_new_requests[0]
assert len(result) == 0

# Second request should fail validation
with pytest.raises(AssertionError):
integration_queue._validate_and_filter_requests([requests[1]])
assert all_ranks_num_active_requests[1] == 2
assert all_ranks_num_active_requests[2] == 4

def test_beam_width_validation_success(integration_queue):
"""Test that beam width validation passes for correct beam width."""
mock_req = Mock()
mock_req.sampling_config = Mock()
mock_req.sampling_config.beam_width = 2  # Matches integration test max_beam_width
def test_schedule_attention_dp_requests_unscheduled_requests(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=True))

request = RequestQueueItem(1, mock_req)
valid_requests = integration_queue._validate_and_filter_requests([request])
new_requests = [req1, req2]

assert len(valid_requests) == 1
assert valid_requests[0] == request
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]

assert len(result) == 1  # Only req1 for current rank
assert req1 in result
def test_schedule_attention_dp_requests_unscheduled_no_capacity(
attention_dp_queue, all_ranks_num_active_requests):
all_ranks_num_active_requests[0] = 8

req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))

new_requests = [req1]

all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]

assert len(result) == 0  # No capacity

def test_schedule_attention_dp_requests_mixed_scenarios(
attention_dp_queue, all_ranks_num_active_requests):
req_scheduled_current = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req_scheduled_other = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=False))
req_unscheduled_current = RequestQueueItem(
3,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
req_unscheduled_other = RequestQueueItem(
4,
create_mock_request_with_py_schedule_params(attention_dp_rank=2,
attention_dp_relax=True))

new_requests = [
req_scheduled_current, req_scheduled_other, req_unscheduled_current,
req_unscheduled_other
]

all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]

assert len(result) == 2
assert req_scheduled_current in result
assert req_unscheduled_current in result

def test_schedule_attention_dp_requests_empty_lists(
attention_dp_queue, all_ranks_num_active_requests):
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
[], all_ranks_num_active_requests)
result = all_ranks_new_requests[0]

assert len(result) == 0

def test_schedule_attention_dp_requests_expected_num_active_calculation(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=True))

new_requests = [req1, req2]

all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
all_ranks_new_requests[0]

# 2 + 1 + 3 + 0 = 6, 6 + 2 = 8, (8 + 3) // 4 = 2, max(2, 2, 1, 3, 0) = 3
assert attention_dp_queue.expected_num_active_requests == 3

def test_schedule_attention_dp_requests_balance_requests_called(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))

new_requests = [req1]

with patch.object(attention_dp_queue,
'_balance_requests_across_ranks') as mock_balance:
mock_balance.return_value = {0: req1}

all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
all_ranks_new_requests[0]

# Check that _balance_requests_across_ranks was called
mock_balance.assert_called_once()
call_args = mock_balance.call_args[0]
assert isinstance(call_args[0], list)
assert isinstance(call_args[1], dict)
assert call_args[2] == all_ranks_num_active_requests  # Third arg

def test_schedule_attention_dp_requests_no_scheduling_when_capacity_exceeded(
attention_dp_queue, all_ranks_num_active_requests):
all_ranks_num_active_requests[0] = 8

req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))

new_requests = [req1]

all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]

assert len(result) == 0  # No requests scheduled
assert all_ranks_num_active_requests[0] == 8  # Capacity unchanged

# Integration tests combining both methods
def test_filter_and_schedule_integration(attention_dp_queue,
all_ranks_num_active_requests):
req_schedulable = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req_schedulable.request.input_token_ids = [1, 2, 3, 4]
req_relax = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
req_relax.request.input_token_ids = [1, 2]

req_no_params = RequestQueueItem(
3, create_mock_request_with_py_schedule_params(attention_dp_rank=None))

new_requests = [req_schedulable, req_relax, req_no_params]

all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]

assert len(result) == 2
assert req_schedulable in result
assert req_relax in result
def test_filter_and_schedule_with_capacity_limits(
attention_dp_queue, all_ranks_num_active_requests):
all_ranks_num_active_requests[0] = 7

req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req1.request.input_token_ids = [1, 2, 3, 4]
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req2.request.input_token_ids = [1, 2, 3]

new_requests = [req1, req2]

all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)
result = all_ranks_new_requests[0]

assert len(result) == 1
assert req1 in result

def test_get_from_waiting_queue_with_attention_dp(
attention_dp_queue, all_ranks_num_active_requests):
items = [RequestQueueItem(i, Mock()) for i in range(5)]
attention_dp_queue.waiting_queue.extend(items)

result = attention_dp_queue._get_from_waiting_queue(
attention_dp_queue.waiting_queue, 3, True,
all_ranks_num_active_requests)

assert len(result) == 3
assert result == items[:3]
assert len(attention_dp_queue.waiting_queue) == 2

def test_get_from_waiting_queue_with_attention_dp_filtering(
attention_dp_queue, all_ranks_num_active_requests):
req1 = RequestQueueItem(
1,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
req2 = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=True))
req3 = RequestQueueItem(3,
create_mock_request_with_py_schedule_params(
attention_dp_rank=None))  # No scheduling params

attention_dp_queue.waiting_queue.extend([req1, req2, req3])

# Set rank 0 to full capacity to test filtering
all_ranks_num_active_requests[0] = 8

result = attention_dp_queue._get_from_waiting_queue(
attention_dp_queue.waiting_queue, 3, True,
all_ranks_num_active_requests)

assert len(result) == 2
assert req2 in result
assert req3 in result
assert req1 not in result

def test_can_process_attention_dp_request(attention_dp_queue):
req_no_params = RequestQueueItem(1, Mock())
assert attention_dp_queue._can_process_attention_dp_request(
req_no_params, [0, 0, 0, 0]) == True

req_relax = RequestQueueItem(
2,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=True))
assert attention_dp_queue._can_process_attention_dp_request(
req_relax, [0, 0, 0, 0]) == True

req_target = RequestQueueItem(
3,
create_mock_request_with_py_schedule_params(attention_dp_rank=1,
attention_dp_relax=False))
all_ranks = [0, 0, 0, 0]
assert attention_dp_queue._can_process_attention_dp_request(
req_target, all_ranks) == True
assert all_ranks[1] == 1

req_no_capacity = RequestQueueItem(
4,
create_mock_request_with_py_schedule_params(attention_dp_rank=0,
attention_dp_relax=False))
all_ranks_full = [8, 0, 0, 0]  # Rank 0 is at capacity
assert attention_dp_queue._can_process_attention_dp_request(
req_no_capacity, all_ranks_full) == False

def test_achieve_max_num_active_requests(attention_dp_queue):
req_list = []
req_id = 0
for rank in range(4):
for i in range(5):
req_list.append(
RequestQueueItem(
req_id,
create_mock_request_with_py_schedule_params(
attention_dp_rank=rank, attention_dp_relax=False)))
req_id += 1
req_list.append(
RequestQueueItem(
req_id,
create_mock_request_with_py_schedule_params(
attention_dp_rank=rank, attention_dp_relax=True)))
req_id += 1

all_ranks_num_active_requests = [5, 6, 3, 7]
attention_dp_queue.waiting_queue.extend(req_list)
available_active_requests = attention_dp_queue.max_num_active_requests * 4 - sum(
all_ranks_num_active_requests)

result = attention_dp_queue._get_from_waiting_queue(
attention_dp_queue.waiting_queue, available_active_requests, True,
all_ranks_num_active_requests)

assert len(result) == available_active_requests
def append_to_waiting_queue(waiting_queue, rank, attention_dp_relax):
req_id = len(waiting_queue)
waiting_queue.append(
RequestQueueItem(
req_id,
create_mock_request_with_py_schedule_params(
attention_dp_rank=rank, attention_dp_relax=attention_dp_relax)))

@pytest.mark.parametrize(
"max_num_active_requests,all_ranks_num_active_requests,request_configs,all_ranks_expected_req_ids",
[
# Case: Balanced distribution of relaxed requests
(3, [0, 0, 0, 0], [(None, True)] * 7, {
0: [0, 4],
1: [1, 5],
2: [2, 6],
3: [3]
}),
# Case: Balanced distribution of relaxed requests
(3, [1, 2, 3, 0], [(None, True)] * 13, {
0: [1, 4],
1: [2],
2: [],
3: [0, 3, 5]
}),
# Case: Limited by max active
(3, [0, 0, 0, 0], [(None, True)] * 13, {
0: [0, 4, 8],
1: [1, 5, 9],
2: [2, 6, 10],
3: [3, 7, 11]
}),
# Case: Empty new requests
(3, [3, 3, 3, 0], [], {
0: [],
1: [],
2: [],
3: []
}),
# Case: Rank 0 is full and cannot schedule attention_dp rank request
(3, [3, 1, 3, 0], [(0, False), (0, True)], {
0: [],
1: [],
2: [],
3: [1]
}),
# Case: Only room for 1 request, need to skip req0 with attention dp rank
(3, [3, 2, 3, 3], [(0, False), (0, True)], {
0: [],
1: [1],
2: [],
3: []
}),
# Case: Targeting ranks 1 and 3 that have room
(3, [2, 1, 3, 0], [(1, False), (3, False)], {
0: [],
1: [0],
2: [],
3: [1]
}),
# Case: Target dp rank specified, but relax is True
(3, [3, 3, 3, 1], [(0, True), (1, True), (2, True)], {
0: [],
1: [],
2: [],
3: [0, 1]
}),
# Case:
(3, [3, 3, 3, 0], [(0, False), (1, True), (3, False)], {
0: [],
1: [],
2: [],
3: [2, 1]
}),
])
def test_attention_dp_scheduling_cases(attention_dp_queue,
max_num_active_requests,
all_ranks_num_active_requests,
request_configs,
all_ranks_expected_req_ids):
"""Test attention DP scheduling with various scenarios."""
attention_dp_queue.max_num_active_requests = max_num_active_requests

waiting_queue = deque()
for rank, relax in request_configs:
append_to_waiting_queue(waiting_queue, rank, relax)

run_test_attention_dp_scheduling(attention_dp_queue, waiting_queue,
all_ranks_num_active_requests,
all_ranks_expected_req_ids)

def run_test_attention_dp_scheduling(attention_dp_queue, waiting_queue,
all_ranks_num_active_requests,
all_ranks_expected_req_ids):

num_ranks = len(all_ranks_num_active_requests)
total_num_active_requests = sum(all_ranks_num_active_requests)
total_max_num_active_requests = attention_dp_queue.max_num_active_requests * num_ranks
enable_attention_dp = True

new_requests = attention_dp_queue._get_from_waiting_queue(
waiting_queue,
total_max_num_active_requests - total_num_active_requests,
enable_attention_dp, all_ranks_num_active_requests)

# Schedule attention dp requests
all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
new_requests, all_ranks_num_active_requests)

assert len(all_ranks_new_requests) == num_ranks
print("all_ranks_new_requests:", all_ranks_new_requests)
for rank, reqs in all_ranks_new_requests.items():
req_ids = [req.id for req in reqs]
assert req_ids == all_ranks_expected_req_ids[rank]
@@ -156,6 +156,9 @@ methods:
kv_cache_retention_config:
annotation: Union[tensorrt_llm.bindings.executor.KvCacheRetentionConfig, Sequence[tensorrt_llm.bindings.executor.KvCacheRetentionConfig], NoneType]
default: null
scheduling_params:
annotation: Union[tensorrt_llm.scheduling_params.SchedulingParams, List[tensorrt_llm.scheduling_params.SchedulingParams], NoneType]
default: null
return_annotation: Union[tensorrt_llm.llmapi.llm.RequestOutput, List[tensorrt_llm.llmapi.llm.RequestOutput]]
generate_async:
parameters:
@@ -165,6 +168,10 @@ methods:
kv_cache_retention_config:
annotation: Optional[tensorrt_llm.bindings.executor.KvCacheRetentionConfig]
default: null
scheduling_params:
annotation: Optional[tensorrt_llm.scheduling_params.SchedulingParams]
default: null
status: prototype
return_annotation: tensorrt_llm.llmapi.llm.RequestOutput
get_kv_cache_events:
parameters: