mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][opt] Balance the request based on number of tokens in AttentionDP (#7183)
Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
This commit is contained in:
parent e12868bc00
commit ff4047414b
@@ -302,12 +302,13 @@ class ExecutorRequestQueue:
         return new_requests
 
     @nvtx_range("_fetch_new_requests")
-    def fetch_new_requests(self, num_active_requests: int) -> List[LlmRequest]:
+    def fetch_new_requests(
+            self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
         if self.enable_attention_dp:
-            return self._fetch_new_requests_attention_dp(num_active_requests)
+            return self._fetch_new_requests_attention_dp(activate_requests)
         else:
-            return self._fetch_new_requests_attention_tp(num_active_requests)
+            return self._fetch_new_requests_attention_tp(len(activate_requests))
 
     def _fetch_new_requests_attention_tp(
             self, num_active_requests: int) -> List[LlmRequest]:
@@ -326,13 +327,18 @@ class ExecutorRequestQueue:
         return merged_requests
 
     def _fetch_new_requests_attention_dp(
-            self, num_active_requests: int) -> List[LlmRequest]:
+            self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
         """Handle attention DP request fetching with load balancing."""
         # Get active request counts across all ranks
         all_ranks_num_active_requests = []
-        responses_list = self.dist.tp_allgather(num_active_requests)
-        for num_active_requests in responses_list:
+        all_ranks_num_active_tokens = []
+        num_active_tokens = sum(
+            [req.py_orig_prompt_len for req in activate_requests])
+        responses_list = self.dist.tp_allgather(
+            [len(activate_requests), num_active_tokens])
+        for num_active_requests, num_active_tokens in responses_list:
             all_ranks_num_active_requests.append(num_active_requests)
+            all_ranks_num_active_tokens.append(num_active_tokens)
 
         total_num_active_requests = sum(all_ranks_num_active_requests)
         total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
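This hunk is the core of the change: instead of gathering only a per-rank active-request count, every rank now contributes a [num_requests, num_tokens] pair through the same tp_allgather call, where num_tokens sums py_orig_prompt_len over that rank's active requests. A minimal sketch of the aggregation pattern, simulated in one process; FakeRequest, fake_tp_allgather, and gather_active_stats are stand-ins invented for illustration, not TensorRT-LLM APIs:

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeRequest:
    # Stand-in for LlmRequest; py_orig_prompt_len mirrors the field the
    # diff sums over.
    py_orig_prompt_len: int


def fake_tp_allgather(per_rank_payloads: List[List[int]]) -> List[List[int]]:
    # The real dist.tp_allgather takes one rank's payload and returns every
    # rank's; here all ranks are simulated at once.
    return per_rank_payloads


def gather_active_stats(
        all_ranks_active: List[List[FakeRequest]]
) -> Tuple[List[int], List[int]]:
    payloads = [[len(reqs),
                 sum(r.py_orig_prompt_len for r in reqs)]
                for reqs in all_ranks_active]
    num_requests, num_tokens = [], []
    for n_reqs, n_toks in fake_tp_allgather(payloads):
        num_requests.append(n_reqs)
        num_tokens.append(n_toks)
    return num_requests, num_tokens


ranks = [
    [FakeRequest(8), FakeRequest(2)],  # rank 0: 2 requests, 10 tokens
    [FakeRequest(5)],                  # rank 1: 1 request, 5 tokens
]
print(gather_active_stats(ranks))  # ([2, 1], [10, 5])

Gathering both numbers in a single collective keeps the extra balancing signal free of additional communication rounds.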
@@ -346,7 +352,8 @@
 
         # Schedule attention dp requests
         all_ranks_new_requests = self._schedule_attention_dp_requests(
-            new_requests, all_ranks_num_active_requests)
+            new_requests, all_ranks_num_active_requests,
+            all_ranks_num_active_tokens)
         new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]
 
         # Update performance metrics
@@ -364,7 +371,8 @@
 
     def _schedule_attention_dp_requests(
             self, new_requests: List[RequestQueueItem],
-            all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
+            all_ranks_num_active_requests: List[int],
+            all_ranks_num_active_tokens: List[int]) -> List[RequestQueueItem]:
         """Schedule attention dp requests."""
 
         # Map from ranks to new requests
@@ -411,7 +419,7 @@
 
         all_ranks_new_requests = self._balance_requests_across_ranks(
             remaining_unscheduled, all_ranks_new_requests,
-            all_ranks_num_active_requests)
+            all_ranks_num_active_requests, all_ranks_num_active_tokens)
 
         return all_ranks_new_requests
 
@@ -469,7 +477,8 @@
     def _balance_requests_across_ranks(
             self, new_requests: List[RequestQueueItem],
             all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
-            all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
+            all_ranks_num_active_requests: List[int],
+            all_ranks_num_active_tokens: List[int]) -> List[RequestQueueItem]:
         """Balance requests across ranks for attention DP."""
         if new_requests:
             # Balance context tokens across ranks using heap
@@ -478,7 +487,7 @@
                 ['num_tokens', 'num_requests', 'rank', 'request_list'])
 
             all_ranks_new_requests_heap = [
-                HeapVal(0, val, tp_rank, [])
+                HeapVal(all_ranks_num_active_tokens[tp_rank], val, tp_rank, [])
                 for tp_rank, val in enumerate(all_ranks_num_active_requests)
             ]
 
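Seeding the heap with all_ranks_num_active_tokens[tp_rank] instead of 0 is what makes the balancing token-aware: a rank that already holds long prompts starts out "heavy" and receives fewer new requests. A self-contained sketch of the heap mechanics using the same HeapVal field order as the diff; the balance helper, its (request_id, token_count) input format, and the simple capacity check are assumptions made for illustration:

import heapq
from collections import namedtuple

# Same field order as the diff: namedtuples compare field by field, so the
# heap orders primarily by num_tokens and breaks ties on num_requests, then
# rank.
HeapVal = namedtuple('HeapVal',
                     ['num_tokens', 'num_requests', 'rank', 'request_list'])


def balance(new_requests, active_tokens, active_requests, capacity):
    """new_requests: (request_id, token_count) pairs (hypothetical format)."""
    heap = [
        HeapVal(active_tokens[rank], num, rank, [])
        for rank, num in enumerate(active_requests)
    ]
    heapq.heapify(heap)
    # The diff's balancing also processes incoming requests largest-first.
    for req_id, token_count in sorted(new_requests,
                                      key=lambda item: item[1],
                                      reverse=True):
        val = heapq.heappop(heap)
        # Stand-in for the expected_num_active_requests capacity check:
        # ranks that are already full drop out of the heap.
        while val.num_requests >= capacity and heap:
            val = heapq.heappop(heap)
        val.request_list.append(req_id)
        heapq.heappush(
            heap,
            val._replace(num_tokens=val.num_tokens + token_count,
                         num_requests=val.num_requests + 1))
    return {val.rank: val.request_list for val in heap}


# Rank 2 starts lightest (5 tokens), so it absorbs the 100-token request;
# the 10-token request then lands on rank 1 (next lightest at 10 tokens).
print(balance([(0, 100), (1, 10)], [20, 10, 5, 15], [1, 1, 0, 1], capacity=2))
# -> {0: [], 1: [1], 2: [0], 3: []}

Because namedtuples compare element-wise, putting num_tokens first makes token load the primary balancing key and request count only a tie-breaker.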
@@ -503,6 +512,7 @@
 
             # Distribute requests across ranks
             for req_item in new_requests:
 
                 val = heapq.heappop(all_ranks_new_requests_heap)
                 token_count = len(
                     getattr(req_item.request, 'input_token_ids',
@@ -1185,7 +1185,7 @@ class PyExecutor:
             return True
 
         new_requests_cur_rank = self.executor_request_queue.fetch_new_requests(
-            len(self.active_requests))
+            self.active_requests)
         self.is_shutdown = self.executor_request_queue.is_shutdown
         self.expected_num_active_requests = self.executor_request_queue.get_expected_num_active_requests(
         )
@@ -453,6 +453,11 @@ def all_ranks_num_active_requests():
     return [2, 1, 3, 0]  # 4 ranks
 
 
+@pytest.fixture
+def all_ranks_num_active_tokens():
+    return [10, 5, 15, 8]  # 4 ranks
+
+
 def create_mock_request_with_py_schedule_params(attention_dp_rank=None,
                                                 attention_dp_relax=False):
     mock_request = Mock()
@@ -477,7 +482,8 @@ def create_mock_request_with_py_schedule_params(attention_dp_rank=None,
 
 # Unit tests for _schedule_attention_dp_requests
 def test_schedule_attention_dp_requests_scheduled_requests(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req1 = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -490,7 +496,8 @@
     new_requests = [req1, req2]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
     result = all_ranks_new_requests[0]
 
     assert len(result) == 2
@@ -501,7 +508,8 @@
 
 
 def test_schedule_attention_dp_requests_scheduled_requests_other_ranks(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req1 = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=1,
@@ -514,7 +522,8 @@
     new_requests = [req1, req2]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     result = all_ranks_new_requests[0]
     assert len(result) == 0
@@ -524,7 +533,8 @@
 
 
 def test_schedule_attention_dp_requests_unscheduled_requests(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req1 = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -537,7 +547,8 @@
     new_requests = [req1, req2]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
     result = all_ranks_new_requests[0]
 
     assert len(result) == 1  # Only req1 for current rank
@@ -545,7 +556,8 @@
 
 
 def test_schedule_attention_dp_requests_unscheduled_no_capacity(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     all_ranks_num_active_requests[0] = 8
 
     req1 = RequestQueueItem(
@@ -556,14 +568,16 @@
     new_requests = [req1]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
     result = all_ranks_new_requests[0]
 
     assert len(result) == 0  # No capacity
 
 
 def test_schedule_attention_dp_requests_mixed_scenarios(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req_scheduled_current = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -587,7 +601,8 @@
     ]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
     result = all_ranks_new_requests[0]
 
     assert len(result) == 2
@@ -596,16 +611,18 @@
 
 
 def test_schedule_attention_dp_requests_empty_lists(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        [], all_ranks_num_active_requests)
+        [], all_ranks_num_active_requests, all_ranks_num_active_tokens)
     result = all_ranks_new_requests[0]
 
     assert len(result) == 0
 
 
 def test_schedule_attention_dp_requests_expected_num_active_calculation(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req1 = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -618,15 +635,18 @@
     new_requests = [req1, req2]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
     all_ranks_new_requests[0]
 
     # 2 + 1 + 3 + 0 = 6, 6 + 2 = 8, (8 + 3) // 4 = 2, max(2, 2, 1, 3, 0) = 3
+    # expected_num_active_requests = max((6 + 2 + 3) // 4, 3) = max(2, 3) = 3
     assert attention_dp_queue.expected_num_active_requests == 3
 
 
 def test_schedule_attention_dp_requests_balance_requests_called(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req1 = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -639,7 +659,8 @@
     mock_balance.return_value = {0: req1}
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
     all_ranks_new_requests[0]
 
     # Check that _balance_requests_across_ranks was called
@@ -648,10 +669,12 @@
     assert isinstance(call_args[0], list)
     assert isinstance(call_args[1], dict)
     assert call_args[2] == all_ranks_num_active_requests  # Third arg
+    assert call_args[3] == all_ranks_num_active_tokens  # Fourth arg
 
 
 def test_schedule_attention_dp_requests_no_scheduling_when_capacity_exceeded(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     all_ranks_num_active_requests[0] = 8
 
     req1 = RequestQueueItem(
@@ -662,7 +685,8 @@
     new_requests = [req1]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
     result = all_ranks_new_requests[0]
 
     assert len(result) == 0  # No requests scheduled
@@ -671,7 +695,8 @@
 
 # Integration tests combining both methods
 def test_filter_and_schedule_integration(attention_dp_queue,
-                                         all_ranks_num_active_requests):
+                                         all_ranks_num_active_requests,
+                                         all_ranks_num_active_tokens):
     req_schedulable = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -689,7 +714,8 @@
     new_requests = [req_schedulable, req_relax, req_no_params]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
     result = all_ranks_new_requests[0]
 
     assert len(result) == 2
@@ -697,8 +723,9 @@
     assert req_relax in result
 
 
-def test_filter_and_schedule_with_capacity_limits(
-        attention_dp_queue, all_ranks_num_active_requests):
+def test_filter_and_schedule_with_capacity_limits(attention_dp_queue,
+                                                  all_ranks_num_active_requests,
+                                                  all_ranks_num_active_tokens):
     all_ranks_num_active_requests[0] = 7
 
     req1 = RequestQueueItem(
@@ -715,7 +742,8 @@
     new_requests = [req1, req2]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
     result = all_ranks_new_requests[0]
 
     assert len(result) == 1
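The capacity arithmetic asserted in test_schedule_attention_dp_requests_expected_num_active_calculation can be checked by hand. The snippet below reproduces the two comment lines from the diff; the formula shape (ceiling division over ranks with a floor at the busiest rank) is inferred from those comments and may not match the production code exactly:

tp_size = 4
all_ranks_num_active_requests = [2, 1, 3, 0]  # the fixture value
num_new_requests = 2

total_active = sum(all_ranks_num_active_requests)  # 2 + 1 + 3 + 0 = 6
# Ceiling division over ranks: (6 + 2 + 3) // 4 = 11 // 4 = 2.
balanced = (total_active + num_new_requests + tp_size - 1) // tp_size
# No rank can end up below its current load, so floor at the busiest rank:
# max(2, max(2, 1, 3, 0)) = 3.
expected = max(balanced, max(all_ranks_num_active_requests))
assert expected == 3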
@@ -838,26 +866,38 @@ def append_to_waiting_queue(waiting_queue, rank, attention_dp_relax):
     "max_num_active_requests,all_ranks_num_active_requests,request_configs,all_ranks_expected_req_ids",
     [
         # Case: Balanced distribution of relaxed requests
-        (3, [0, 0, 0, 0], [(None, True)] * 7, {
-            0: [0, 4],
-            1: [1, 5],
-            2: [2, 6],
-            3: [3]
-        }),
-        # Case: Balanced distribution of relaxed requests
-        (3, [1, 2, 3, 0], [(None, True)] * 13, {
-            0: [1, 4],
-            1: [2],
-            2: [],
-            3: [0, 3, 5]
-        }),
+        (
+            3,
+            [0, 0, 0, 0],
+            [(None, True)] * 7,
+            {
+                0: [0, 1],  # First 2 requests go to rank 0
+                1: [2, 3],  # Next 2 requests go to rank 1
+                2: [4, 5],  # Next 2 requests go to rank 2
+                3: [6]  # Last request goes to rank 3
+            }),
+        # Case: Balanced distribution of relaxed requests with existing load
+        (
+            3,
+            [1, 2, 3, 0],
+            [(None, True)] * 13,
+            {
+                0: [0, 1],  # Rank 0 gets first 2 requests
+                1: [2],  # Rank 1 gets 1 request (already has 2)
+                2: [],  # Rank 2 is at capacity (3)
+                3: [3, 4, 5]  # Rank 3 gets 3 requests (starts with 0)
+            }),
         # Case: Limited by max active
-        (3, [0, 0, 0, 0], [(None, True)] * 13, {
-            0: [0, 4, 8],
-            1: [1, 5, 9],
-            2: [2, 6, 10],
-            3: [3, 7, 11]
-        }),
+        (
+            3,
+            [0, 0, 0, 0],
+            [(None, True)] * 13,
+            {
+                0: [0, 1, 3],  # First 3 requests (0, 1, 3)
+                1: [2, 4, 6],  # Next 3 requests (2, 4, 6)
+                2: [5, 7, 9],  # Next 3 requests (5, 7, 9)
+                3: [8, 10, 11]  # Last 3 requests (8, 10, 11)
+            }),
         # Case: Empty new requests
         (3, [3, 3, 3, 0], [], {
             0: [],
@@ -866,40 +906,60 @@ def append_to_waiting_queue(waiting_queue, rank, attention_dp_relax):
             3: []
         }),
         # Case: Rank 0 is full and cannot schedule attention_dp rank request
-        (3, [3, 1, 3, 0], [(0, False), (0, True)], {
-            0: [],
-            1: [],
-            2: [],
-            3: [1]
-        }),
+        (
+            3,
+            [3, 1, 3, 0],
+            [(0, False), (0, True)],
+            {
+                0: [],  # Rank 0 is full
+                1: [1],  # Rank 1 gets the relaxed request (req1)
+                2: [],  # No relaxed requests assigned here
+                3: []  # No relaxed requests assigned here
+            }),
         # Case: Only room for 1 request, need to skip req0 with attention dp rank
-        (3, [3, 2, 3, 3], [(0, False), (0, True)], {
-            0: [],
-            1: [1],
-            2: [],
-            3: []
-        }),
+        (
+            3,
+            [3, 2, 3, 3],
+            [(0, False), (0, True)],
+            {
+                0: [],  # Rank 0 is full
+                1: [1],  # Rank 1 gets the relaxed request
+                2: [],  # Rank 2 is at capacity
+                3: []  # Rank 3 is at capacity
+            }),
         # Case: Targeting ranks 1 and 3 that have room
-        (3, [2, 1, 3, 0], [(1, False), (3, False)], {
-            0: [],
-            1: [0],
-            2: [],
-            3: [1]
-        }),
-        # Case: Target dp rank specified, by relax is True
-        (3, [3, 3, 3, 1], [(0, True), (1, True), (2, True)], {
-            0: [],
-            1: [],
-            2: [],
-            3: [0, 1]
-        }),
-        # Case:
-        (3, [3, 3, 3, 0], [(0, False), (1, True), (3, False)], {
-            0: [],
-            1: [],
-            2: [],
-            3: [2, 1]
-        }),
+        (
+            3,
+            [2, 1, 3, 0],
+            [(1, False), (3, False)],
+            {
+                0: [],  # No requests assigned to rank 0
+                1: [0],  # Request 0 targets rank 1
+                2: [],  # No requests assigned to rank 2
+                3: [1]  # Request 1 targets rank 3
+            }),
+        # Case: Target dp rank specified, but relax is True
+        (
+            3,
+            [3, 3, 3, 1],
+            [(0, True), (1, True), (2, True)],
+            {
+                0: [],  # Rank 0 is at capacity
+                1: [],  # Rank 1 is at capacity
+                2: [],  # Rank 2 is at capacity
+                3: [0, 1]  # Rank 3 gets both relaxed requests
+            }),
+        # Case: Mixed targeting and relaxed
+        (
+            3,
+            [3, 3, 3, 0],
+            [(0, False), (1, True), (3, False)],
+            {
+                0: [],  # Rank 0 is at capacity
+                1: [],  # Rank 1 is at capacity
+                2: [],  # Rank 2 is at capacity
+                3: [2, 1]  # Rank 3 gets both requests (targeted + relaxed)
+            }),
     ])
 def test_attention_dp_scheduling_cases(attention_dp_queue,
                                        max_num_active_requests,
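Note how the expected ids change from a round-robin pattern ({0: [0, 4], ...}) to contiguous chunks ({0: [0, 1], ...}). One way to read this: once the heap is seeded with unequal token loads (the updated run_test_attention_dp_scheduling below uses [10 + i * 5]), the lightest rank keeps winning pops until it reaches its request cap. A sketch of that dynamic for the first case, assuming each mock request adds a small constant token count; the per-rank cap of 2 follows the ceil(7 new / 4 ranks) pattern from the diff's own comments:

import heapq

# Seeded per-rank token loads, as in the updated test helper.
loads = [10, 15, 20, 25]
heap = [(tokens, rank) for rank, tokens in enumerate(loads)]
heapq.heapify(heap)

assignment = {rank: [] for rank in range(4)}
for req_id in range(7):
    tokens, rank = heapq.heappop(heap)
    while len(assignment[rank]) >= 2:  # full ranks drop out of the heap
        tokens, rank = heapq.heappop(heap)
    assignment[rank].append(req_id)
    heapq.heappush(heap, (tokens + 1, rank))  # each mock request adds 1 token

print(assignment)  # {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6]}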
@@ -932,12 +992,226 @@ def run_test_attention_dp_scheduling(attention_dp_queue, waiting_queue,
         total_max_num_active_requests - total_num_active_requests,
         enable_attention_dp, all_ranks_num_active_requests)
 
+    # Create mock token counts for testing
+    all_ranks_num_active_tokens = [10 + i * 5 for i in range(num_ranks)]
+
     # Schedule attention dp requests
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     assert len(all_ranks_new_requests) == num_ranks
     print("all_ranks_new_requests:", all_ranks_new_requests)
     for rank, reqs in all_ranks_new_requests.items():
         req_ids = [req.id for req in reqs]
         assert req_ids == all_ranks_expected_req_ids[rank]
+
+
+# New tests for _balance_requests_across_ranks method
+def test_balance_requests_across_ranks_empty_requests(attention_dp_queue):
+    """Test _balance_requests_across_ranks with an empty requests list."""
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    all_ranks_num_active_requests = [2, 1, 3, 0]
+    all_ranks_num_active_tokens = [20, 10, 30, 5]
+
+    # Set expected_num_active_requests for testing
+    attention_dp_queue.expected_num_active_requests = 3
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        [], all_ranks_new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
+
+    # Should return the original structure unchanged
+    assert result == all_ranks_new_requests
+    for rank in range(4):
+        assert len(result[rank]) == 0
+
+
+def test_balance_requests_across_ranks_single_request(attention_dp_queue):
+    """Test _balance_requests_across_ranks with a single request."""
+    req = RequestQueueItem(
+        1, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req.request.input_token_ids = [1, 2, 3, 4, 5]  # 5 tokens
+
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    all_ranks_num_active_requests = [1, 2, 0, 1]  # Rank 2 has lowest count
+    all_ranks_num_active_tokens = [10, 20, 5, 15]
+
+    # Set expected_num_active_requests for testing
+    attention_dp_queue.expected_num_active_requests = 2
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        [req], all_ranks_new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
+
+    # Request should be assigned to rank 2 (lowest token load)
+    assert len(result[0]) == 0
+    assert len(result[1]) == 0
+    assert len(result[2]) == 1
+    assert len(result[3]) == 0
+    assert result[2][0] == req
+
+
+def test_balance_requests_across_ranks_multiple_requests(attention_dp_queue):
+    """Test _balance_requests_across_ranks with multiple requests."""
+    # Create requests with different token counts
+    req1 = RequestQueueItem(
+        1, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req1.request.input_token_ids = [1, 2, 3]  # 3 tokens
+
+    req2 = RequestQueueItem(
+        2, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req2.request.input_token_ids = [1, 2, 3, 4, 5, 6]  # 6 tokens
+
+    req3 = RequestQueueItem(
+        3, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req3.request.input_token_ids = [1, 2]  # 2 tokens
+
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    all_ranks_num_active_requests = [0, 1, 2, 1]
+    all_ranks_num_active_tokens = [5, 15, 25, 10]
+
+    # Set expected_num_active_requests for testing
+    attention_dp_queue.expected_num_active_requests = 2
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        [req1, req2, req3], all_ranks_new_requests,
+        all_ranks_num_active_requests, all_ranks_num_active_tokens)
+
+    # Requests are sorted by token count (descending), then each goes to the
+    # rank with the lowest token load that still has capacity:
+    # req2 (6 tokens) -> rank 0 (5 tokens)  -> rank 0 now: 1 new, 11 tokens
+    # req1 (3 tokens) -> rank 3 (10 tokens) -> rank 3 now: 1 new, 13 tokens
+    # req3 (2 tokens) -> rank 0 (11 tokens) -> rank 0 now: 2 new, 13 tokens
+    # Rank 1 (15 tokens): gets nothing
+    # Rank 2 (25 tokens): gets nothing (already at capacity)
+
+    assert len(result[0]) == 2  # req2 and req3 (rank 0 has capacity for 2)
+    assert len(result[1]) == 0  # no requests
+    assert len(result[2]) == 0  # at capacity
+    assert len(result[3]) == 1  # req1
+
+    # Verify the requests are assigned correctly
+    assert result[0][0] == req2  # First request (highest token count)
+    assert result[0][1] == req3  # Second request
+    assert result[3][0] == req1
+
+
+def test_balance_requests_across_ranks_capacity_limits(attention_dp_queue):
+    """Test _balance_requests_across_ranks respects capacity limits."""
+    # Create multiple requests
+    requests = []
+    for i in range(4):
+        req = RequestQueueItem(
+            i,
+            create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+        req.request.input_token_ids = [1] * (i + 1)  # Variable token counts
+        requests.append(req)
+
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    all_ranks_num_active_requests = [1, 1, 1, 1]  # All ranks start with 1
+    all_ranks_num_active_tokens = [10, 10, 10, 10]
+
+    # Set expected_num_active_requests to limit capacity
+    attention_dp_queue.expected_num_active_requests = 2
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        requests, all_ranks_new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
+
+    # Each rank can only take 1 more request (1 + 1 = 2, which equals
+    # expected_num_active_requests)
+    total_assigned = sum(
+        len(rank_requests) for rank_requests in result.values())
+    assert total_assigned == 4  # 4 ranks × 1 additional request each
+
+    # Verify no rank exceeds capacity
+    for rank in range(4):
+        assert len(result[rank]) <= 1
+
+
+def test_balance_requests_across_ranks_heap_ordering(attention_dp_queue):
+    """Test that _balance_requests_across_ranks uses heap ordering correctly."""
+    # Create requests with same token count to test heap ordering
+    req1 = RequestQueueItem(
+        1, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req1.request.input_token_ids = [1, 2, 3]  # 3 tokens
+
+    req2 = RequestQueueItem(
+        2, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req2.request.input_token_ids = [1, 2, 3]  # 3 tokens
+
+    req3 = RequestQueueItem(
+        3, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req3.request.input_token_ids = [1, 2, 3]  # 3 tokens
+
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    # Rank 0 has the highest load, should get requests last
+    all_ranks_num_active_requests = [3, 1, 0, 2]
+    all_ranks_num_active_tokens = [30, 10, 5, 20]
+
+    # Set expected_num_active_requests for testing
+    attention_dp_queue.expected_num_active_requests = 4
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        [req1, req2, req3], all_ranks_new_requests,
+        all_ranks_num_active_requests, all_ranks_num_active_tokens)
+
+    # With equal request sizes, the rank with the lowest token load wins
+    # each pop (request count only breaks ties):
+    # Rank 2 (5 tokens, 0 active)  -> gets req1, then req2 (still lightest)
+    # Rank 1 (10 tokens, 1 active) -> gets req3 (10 < rank 2's 11 after two)
+    # Rank 3 (20 tokens, 2 active) -> gets nothing
+    # Rank 0 (30 tokens, 3 active) -> gets nothing
+
+    assert len(result[0]) == 0  # heaviest rank
+    assert len(result[1]) == 1  # req3
+    assert len(result[2]) == 2  # req1 and req2
+    assert len(result[3]) == 0  # no requests
+
+    # Verify the requests are assigned correctly
+    assert result[1][0] == req3  # Third request
+    assert result[2][0] == req1  # First request
+    assert result[2][1] == req2  # Second request
+
+
+def test_balance_requests_across_ranks_token_count_sorting(attention_dp_queue):
+    """Test that requests are sorted by token count before distribution."""
+    # Create requests with different token counts
+    req1 = RequestQueueItem(
+        1, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req1.request.input_token_ids = [1]  # 1 token (smallest)
+
+    req2 = RequestQueueItem(
+        2, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req2.request.input_token_ids = [1, 2, 3, 4, 5]  # 5 tokens (largest)
+
+    req3 = RequestQueueItem(
+        3, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req3.request.input_token_ids = [1, 2, 3]  # 3 tokens (medium)
+
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    all_ranks_num_active_requests = [0, 0, 0, 0]  # All ranks start empty
+    all_ranks_num_active_tokens = [5, 5, 5, 5]
+
+    # Set expected_num_active_requests for testing
+    attention_dp_queue.expected_num_active_requests = 2
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        [req1, req2, req3], all_ranks_new_requests,
+        all_ranks_num_active_requests, all_ranks_num_active_tokens)
+
+    # Requests should be sorted by token count (descending) before
+    # distribution, then assigned to the least-loaded ranks first
+    # (ties broken by rank index):
+    # req2 (5 tokens) -> rank 0
+    # req3 (3 tokens) -> rank 1
+    # req1 (1 token) -> rank 2
+
+    assert len(result[0]) == 1  # req2 (highest token count)
+    assert len(result[1]) == 1  # req3
+    assert len(result[2]) == 1  # req1 (lowest token count)
+    assert len(result[3]) == 0
+
+    # Verify the requests are assigned correctly
+    assert result[0][0] == req2
+    assert result[1][0] == req3
+    assert result[2][0] == req1