[None][opt] Balance the request based on number of tokens in AttentionDP (#7183)

Signed-off-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
Co-authored-by: Shunkang <182541032+Shunkangz@users.noreply.github.com>
Shunkangz 2025-08-27 11:16:12 +08:00 committed by GitHub
parent e12868bc00
commit ff4047414b
3 changed files with 371 additions and 87 deletions

@@ -302,12 +302,13 @@ class ExecutorRequestQueue:
         return new_requests
 
     @nvtx_range("_fetch_new_requests")
-    def fetch_new_requests(self, num_active_requests: int) -> List[LlmRequest]:
+    def fetch_new_requests(
+            self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
         if self.enable_attention_dp:
-            return self._fetch_new_requests_attention_dp(num_active_requests)
+            return self._fetch_new_requests_attention_dp(activate_requests)
         else:
-            return self._fetch_new_requests_attention_tp(num_active_requests)
+            return self._fetch_new_requests_attention_tp(len(activate_requests))
 
     def _fetch_new_requests_attention_tp(
             self, num_active_requests: int) -> List[LlmRequest]:
@@ -326,13 +327,18 @@ class ExecutorRequestQueue:
         return merged_requests
 
     def _fetch_new_requests_attention_dp(
-            self, num_active_requests: int) -> List[LlmRequest]:
+            self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
         """Handle attention DP request fetching with load balancing."""
         # Get active request counts across all ranks
         all_ranks_num_active_requests = []
-        responses_list = self.dist.tp_allgather(num_active_requests)
-        for num_active_requests in responses_list:
+        all_ranks_num_active_tokens = []
+        num_active_tokens = sum(
+            [req.py_orig_prompt_len for req in activate_requests])
+        responses_list = self.dist.tp_allgather(
+            [len(activate_requests), num_active_tokens])
+        for num_active_requests, num_active_tokens in responses_list:
             all_ranks_num_active_requests.append(num_active_requests)
+            all_ranks_num_active_tokens.append(num_active_tokens)
 
         total_num_active_requests = sum(all_ranks_num_active_requests)
         total_max_num_active_requests = self.dist.tp_size * self.max_num_active_requests
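
The hunk above changes the allgather payload from a bare request count to a [num_requests, num_tokens] pair per rank, with the token figure summed from each request's original prompt length. A minimal sketch of the resulting shapes; fake_tp_allgather is a hypothetical stand-in for self.dist.tp_allgather:

# Hypothetical stand-in for self.dist.tp_allgather: every rank receives the
# payload contributed by each TP rank, in rank order.
def fake_tp_allgather(per_rank_payloads):
    return per_rank_payloads


# Suppose the four ranks hold active requests with these original prompt
# lengths (req.py_orig_prompt_len in the real code):
active_prompt_lens = [[32, 64], [128], [16, 16, 16], []]

payloads = [[len(lens), sum(lens)] for lens in active_prompt_lens]
responses_list = fake_tp_allgather(payloads)

all_ranks_num_active_requests = [n for n, _ in responses_list]
all_ranks_num_active_tokens = [t for _, t in responses_list]
assert all_ranks_num_active_requests == [2, 1, 3, 0]
assert all_ranks_num_active_tokens == [96, 128, 48, 0]
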
@@ -346,7 +352,8 @@ class ExecutorRequestQueue:
 
         # Schedule attention dp requests
         all_ranks_new_requests = self._schedule_attention_dp_requests(
-            new_requests, all_ranks_num_active_requests)
+            new_requests, all_ranks_num_active_requests,
+            all_ranks_num_active_tokens)
         new_requests_cur_rank = all_ranks_new_requests[self.dist.tp_rank]
 
         # Update performance metrics
@@ -364,7 +371,8 @@ class ExecutorRequestQueue:
 
     def _schedule_attention_dp_requests(
             self, new_requests: List[RequestQueueItem],
-            all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
+            all_ranks_num_active_requests: List[int],
+            all_ranks_num_active_tokens: List[int]) -> List[RequestQueueItem]:
         """Schedule attention dp requests."""
         # Map from ranks to new requests
@@ -411,7 +419,7 @@ class ExecutorRequestQueue:
 
         all_ranks_new_requests = self._balance_requests_across_ranks(
             remaining_unscheduled, all_ranks_new_requests,
-            all_ranks_num_active_requests)
+            all_ranks_num_active_requests, all_ranks_num_active_tokens)
 
         return all_ranks_new_requests
@@ -469,7 +477,8 @@ class ExecutorRequestQueue:
     def _balance_requests_across_ranks(
             self, new_requests: List[RequestQueueItem],
             all_ranks_new_requests: Dict[int, List[RequestQueueItem]],
-            all_ranks_num_active_requests: List[int]) -> List[RequestQueueItem]:
+            all_ranks_num_active_requests: List[int],
+            all_ranks_num_active_tokens: List[int]) -> List[RequestQueueItem]:
         """Balance requests across ranks for attention DP."""
         if new_requests:
             # Balance context tokens across ranks using heap
@@ -478,7 +487,7 @@ class ExecutorRequestQueue:
                 ['num_tokens', 'num_requests', 'rank', 'request_list'])
 
             all_ranks_new_requests_heap = [
-                HeapVal(0, val, tp_rank, [])
+                HeapVal(all_ranks_num_active_tokens[tp_rank], val, tp_rank, [])
                 for tp_rank, val in enumerate(all_ranks_num_active_requests)
             ]
@@ -503,6 +512,7 @@ class ExecutorRequestQueue:
 
             # Distribute requests across ranks
             for req_item in new_requests:
                 val = heapq.heappop(all_ranks_new_requests_heap)
+                token_count = len(
+                    getattr(req_item.request, 'input_token_ids',

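The heap change above is the heart of this commit: each rank's heap entry now starts from its existing token load rather than zero, so new context tokens flow to the least-loaded rank. Below is a self-contained sketch of the technique, with request objects and the per-rank capacity check stripped out; balance() and the bare integer token counts are illustrative, not the real method:

import heapq
from collections import namedtuple

# Sketch of the balancing loop: seed each rank with its current token load,
# then hand the largest pending request to the least-loaded rank.
HeapVal = namedtuple('HeapVal',
                     ['num_tokens', 'num_requests', 'rank', 'request_list'])


def balance(new_request_token_counts, active_tokens, active_requests):
    assignments = {rank: [] for rank in range(len(active_tokens))}
    heap = [
        HeapVal(active_tokens[rank], active_requests[rank], rank,
                assignments[rank]) for rank in range(len(active_tokens))
    ]
    heapq.heapify(heap)
    # Largest requests first, so big prompts land on the emptiest ranks
    for tokens in sorted(new_request_token_counts, reverse=True):
        val = heapq.heappop(heap)
        val.request_list.append(tokens)
        heapq.heappush(
            heap,
            val._replace(num_tokens=val.num_tokens + tokens,
                         num_requests=val.num_requests + 1))
    return assignments


# Mirrors test_balance_requests_across_ranks_multiple_requests below:
print(balance([3, 6, 2], active_tokens=[5, 15, 25, 10],
              active_requests=[0, 1, 2, 1]))
# -> {0: [6, 2], 1: [], 2: [], 3: [3]}

The real method additionally drops a rank out of rotation once it reaches expected_num_active_requests, which is what the capacity tests in the third file exercise.
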
@@ -1185,7 +1185,7 @@ class PyExecutor:
             return True
 
         new_requests_cur_rank = self.executor_request_queue.fetch_new_requests(
-            len(self.active_requests))
+            self.active_requests)
         self.is_shutdown = self.executor_request_queue.is_shutdown
         self.expected_num_active_requests = self.executor_request_queue.get_expected_num_active_requests(
         )

@@ -453,6 +453,11 @@ def all_ranks_num_active_requests():
     return [2, 1, 3, 0]  # 4 ranks
 
 
+@pytest.fixture
+def all_ranks_num_active_tokens():
+    return [10, 5, 15, 8]  # 4 ranks
+
+
 def create_mock_request_with_py_schedule_params(attention_dp_rank=None,
                                                 attention_dp_relax=False):
     mock_request = Mock()
@@ -477,7 +482,8 @@ def create_mock_request_with_py_schedule_params(attention_dp_rank=None,
 
 # Unit tests for _schedule_attention_dp_requests
 def test_schedule_attention_dp_requests_scheduled_requests(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req1 = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -490,7 +496,8 @@ def test_schedule_attention_dp_requests_scheduled_requests(
     new_requests = [req1, req2]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     result = all_ranks_new_requests[0]
     assert len(result) == 2
@@ -501,7 +508,8 @@ def test_schedule_attention_dp_requests_scheduled_requests(
 
 
 def test_schedule_attention_dp_requests_scheduled_requests_other_ranks(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req1 = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=1,
@@ -514,7 +522,8 @@ def test_schedule_attention_dp_requests_scheduled_requests_other_ranks(
     new_requests = [req1, req2]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     result = all_ranks_new_requests[0]
     assert len(result) == 0
@@ -524,7 +533,8 @@ def test_schedule_attention_dp_requests_scheduled_requests_other_ranks(
 
 
 def test_schedule_attention_dp_requests_unscheduled_requests(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req1 = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -537,7 +547,8 @@ def test_schedule_attention_dp_requests_unscheduled_requests(
     new_requests = [req1, req2]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     result = all_ranks_new_requests[0]
     assert len(result) == 1  # Only req1 for current rank
@@ -545,7 +556,8 @@ def test_schedule_attention_dp_requests_unscheduled_requests(
 
 
 def test_schedule_attention_dp_requests_unscheduled_no_capacity(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     all_ranks_num_active_requests[0] = 8
 
     req1 = RequestQueueItem(
@@ -556,14 +568,16 @@ def test_schedule_attention_dp_requests_unscheduled_no_capacity(
     new_requests = [req1]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     result = all_ranks_new_requests[0]
     assert len(result) == 0  # No capacity
 
 
 def test_schedule_attention_dp_requests_mixed_scenarios(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req_scheduled_current = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -587,7 +601,8 @@ def test_schedule_attention_dp_requests_mixed_scenarios(
     ]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     result = all_ranks_new_requests[0]
     assert len(result) == 2
@@ -596,16 +611,18 @@ def test_schedule_attention_dp_requests_mixed_scenarios(
 
 
 def test_schedule_attention_dp_requests_empty_lists(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        [], all_ranks_num_active_requests)
+        [], all_ranks_num_active_requests, all_ranks_num_active_tokens)
 
     result = all_ranks_new_requests[0]
     assert len(result) == 0
 
 
 def test_schedule_attention_dp_requests_expected_num_active_calculation(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req1 = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -618,15 +635,18 @@ def test_schedule_attention_dp_requests_expected_num_active_calculation(
     new_requests = [req1, req2]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
     all_ranks_new_requests[0]
 
     # 2 + 1 + 3 + 0 = 6, 6 + 2 = 8, (8 + 3) // 4 = 2, max(2, 2, 1, 3, 0) = 3
     # expected_num_active_requests = max((6 + 2 + 3) // 4, 3) = max(2, 3) = 3
     assert attention_dp_queue.expected_num_active_requests == 3
 
 
 def test_schedule_attention_dp_requests_balance_requests_called(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     req1 = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -639,7 +659,8 @@ def test_schedule_attention_dp_requests_balance_requests_called(
         mock_balance.return_value = {0: req1}
 
         all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-            new_requests, all_ranks_num_active_requests)
+            new_requests, all_ranks_num_active_requests,
+            all_ranks_num_active_tokens)
         all_ranks_new_requests[0]
 
         # Check that _balance_requests_across_ranks was called
@@ -648,10 +669,12 @@ def test_schedule_attention_dp_requests_balance_requests_called(
         assert isinstance(call_args[0], list)
         assert isinstance(call_args[1], dict)
         assert call_args[2] == all_ranks_num_active_requests  # Third arg
+        assert call_args[3] == all_ranks_num_active_tokens  # Fourth arg
 
 
 def test_schedule_attention_dp_requests_no_scheduling_when_capacity_exceeded(
-        attention_dp_queue, all_ranks_num_active_requests):
+        attention_dp_queue, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens):
     all_ranks_num_active_requests[0] = 8
 
     req1 = RequestQueueItem(
@@ -662,7 +685,8 @@ def test_schedule_attention_dp_requests_no_scheduling_when_capacity_exceeded(
     new_requests = [req1]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     result = all_ranks_new_requests[0]
     assert len(result) == 0  # No requests scheduled
@@ -671,7 +695,8 @@ def test_schedule_attention_dp_requests_no_scheduling_when_capacity_exceeded(
 
 # Integration tests combining both methods
 def test_filter_and_schedule_integration(attention_dp_queue,
-                                         all_ranks_num_active_requests):
+                                         all_ranks_num_active_requests,
+                                         all_ranks_num_active_tokens):
     req_schedulable = RequestQueueItem(
         1,
         create_mock_request_with_py_schedule_params(attention_dp_rank=0,
@@ -689,7 +714,8 @@ def test_filter_and_schedule_integration(attention_dp_queue,
     new_requests = [req_schedulable, req_relax, req_no_params]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     result = all_ranks_new_requests[0]
     assert len(result) == 2
@@ -697,8 +723,9 @@ def test_filter_and_schedule_integration(attention_dp_queue,
     assert req_relax in result
 
 
-def test_filter_and_schedule_with_capacity_limits(
-        attention_dp_queue, all_ranks_num_active_requests):
+def test_filter_and_schedule_with_capacity_limits(attention_dp_queue,
+                                                  all_ranks_num_active_requests,
+                                                  all_ranks_num_active_tokens):
     all_ranks_num_active_requests[0] = 7
 
     req1 = RequestQueueItem(
@@ -715,7 +742,8 @@ def test_filter_and_schedule_with_capacity_limits(attention_dp_queue,
     new_requests = [req1, req2]
 
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     result = all_ranks_new_requests[0]
     assert len(result) == 1
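
Before the parametrized cases, a quick check of the capacity arithmetic these tests rely on, assuming the formula spelled out in the in-test comments (ceil-divide the total load across ranks, floored at the busiest rank's current count):

# Arithmetic behind expected_num_active_requests, as stated in the test
# comments: ceil-divide total load across TP ranks, but never go below the
# busiest rank's existing count.
tp_size = 4
all_ranks_num_active_requests = [2, 1, 3, 0]
num_new_requests = 2

total = sum(all_ranks_num_active_requests) + num_new_requests  # 6 + 2 = 8
ceil_div = (total + tp_size - 1) // tp_size  # (8 + 3) // 4 = 2
expected = max(ceil_div, max(all_ranks_num_active_requests))  # max(2, 3) = 3
assert expected == 3
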
@@ -838,26 +866,38 @@ def append_to_waiting_queue(waiting_queue, rank, attention_dp_relax):
     "max_num_active_requests,all_ranks_num_active_requests,request_configs,all_ranks_expected_req_ids",
     [
         # Case: Balanced distribution of relaxed requests
-        (3, [0, 0, 0, 0], [(None, True)] * 7, {
-            0: [0, 4],
-            1: [1, 5],
-            2: [2, 6],
-            3: [3]
-        }),
-        # Case: Balanced distribution of relaxed requests
-        (3, [1, 2, 3, 0], [(None, True)] * 13, {
-            0: [1, 4],
-            1: [2],
-            2: [],
-            3: [0, 3, 5]
-        }),
+        (
+            3,
+            [0, 0, 0, 0],
+            [(None, True)] * 7,
+            {
+                0: [0, 1],  # First 2 requests go to rank 0
+                1: [2, 3],  # Next 2 requests go to rank 1
+                2: [4, 5],  # Next 2 requests go to rank 2
+                3: [6]  # Last request goes to rank 3
+            }),
+        # Case: Balanced distribution of relaxed requests with existing load
+        (
+            3,
+            [1, 2, 3, 0],
+            [(None, True)] * 13,
+            {
+                0: [0, 1],  # Rank 0 gets first 2 requests
+                1: [2],  # Rank 1 gets 1 request (already has 2)
+                2: [],  # Rank 2 is at capacity (3)
+                3: [3, 4, 5]  # Rank 3 gets 3 requests (starts with 0)
+            }),
         # Case: Limited by max active
-        (3, [0, 0, 0, 0], [(None, True)] * 13, {
-            0: [0, 4, 8],
-            1: [1, 5, 9],
-            2: [2, 6, 10],
-            3: [3, 7, 11]
-        }),
+        (
+            3,
+            [0, 0, 0, 0],
+            [(None, True)] * 13,
+            {
+                0: [0, 1, 3],  # First 3 requests (0, 1, 3)
+                1: [2, 4, 6],  # Next 3 requests (2, 4, 6)
+                2: [5, 7, 9],  # Next 3 requests (5, 7, 9)
+                3: [8, 10, 11]  # Last 3 requests (8, 10, 11)
+            }),
         # Case: Empty new requests
         (3, [3, 3, 3, 0], [], {
             0: [],
@@ -866,40 +906,60 @@ def append_to_waiting_queue(waiting_queue, rank, attention_dp_relax):
             3: []
         }),
         # Case: Rank 0 is full and cannot schedule attention_dp rank request
-        (3, [3, 1, 3, 0], [(0, False), (0, True)], {
-            0: [],
-            1: [],
-            2: [],
-            3: [1]
-        }),
+        (
+            3,
+            [3, 1, 3, 0],
+            [(0, False), (0, True)],
+            {
+                0: [],  # Rank 0 is full
+                1: [1],  # Rank 1 gets the relaxed request (req1)
+                2: [],  # No relaxed requests assigned here
+                3: []  # No relaxed requests assigned here
+            }),
         # Case: Only room for 1 request, need to skip req0 with attention dp rank
-        (3, [3, 2, 3, 3], [(0, False), (0, True)], {
-            0: [],
-            1: [1],
-            2: [],
-            3: []
-        }),
+        (
+            3,
+            [3, 2, 3, 3],
+            [(0, False), (0, True)],
+            {
+                0: [],  # Rank 0 is full
+                1: [1],  # Rank 1 gets the relaxed request
+                2: [],  # Rank 2 is at capacity
+                3: []  # Rank 3 is at capacity
+            }),
         # Case: Targeting ranks 1 and 3 that have room
-        (3, [2, 1, 3, 0], [(1, False), (3, False)], {
-            0: [],
-            1: [0],
-            2: [],
-            3: [1]
-        }),
-        # Case: Target dp rank specified, by relax is True
-        (3, [3, 3, 3, 1], [(0, True), (1, True), (2, True)], {
-            0: [],
-            1: [],
-            2: [],
-            3: [0, 1]
-        }),
-        # Case:
-        (3, [3, 3, 3, 0], [(0, False), (1, True), (3, False)], {
-            0: [],
-            1: [],
-            2: [],
-            3: [2, 1]
-        }),
+        (
+            3,
+            [2, 1, 3, 0],
+            [(1, False), (3, False)],
+            {
+                0: [],  # No requests assigned to rank 0
+                1: [0],  # Request 0 targets rank 1
+                2: [],  # No requests assigned to rank 2
+                3: [1]  # Request 1 targets rank 3
+            }),
+        # Case: Target dp rank specified, but relax is True
+        (
+            3,
+            [3, 3, 3, 1],
+            [(0, True), (1, True), (2, True)],
+            {
+                0: [],  # Rank 0 is at capacity
+                1: [],  # Rank 1 is at capacity
+                2: [],  # Rank 2 is at capacity
+                3: [0, 1]  # Rank 3 gets both relaxed requests
+            }),
+        # Case: Mixed targeting and relaxed
+        (
+            3,
+            [3, 3, 3, 0],
+            [(0, False), (1, True), (3, False)],
+            {
+                0: [],  # Rank 0 is at capacity
+                1: [],  # Rank 1 is at capacity
+                2: [],  # Rank 2 is at capacity
+                3: [2, 1]  # Rank 3 gets both requests (targeted + relaxed)
+            }),
     ])
 def test_attention_dp_scheduling_cases(attention_dp_queue,
                                        max_num_active_requests,
@@ -932,12 +992,226 @@ def run_test_attention_dp_scheduling(attention_dp_queue, waiting_queue,
         total_max_num_active_requests - total_num_active_requests,
         enable_attention_dp, all_ranks_num_active_requests)
 
+    # Create mock token counts for testing
+    all_ranks_num_active_tokens = [10 + i * 5 for i in range(num_ranks)]
+
     # Schedule attention dp requests
     all_ranks_new_requests = attention_dp_queue._schedule_attention_dp_requests(
-        new_requests, all_ranks_num_active_requests)
+        new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
 
     assert len(all_ranks_new_requests) == num_ranks
     print("all_ranks_new_requests:", all_ranks_new_requests)
     for rank, reqs in all_ranks_new_requests.items():
         req_ids = [req.id for req in reqs]
         assert req_ids == all_ranks_expected_req_ids[rank]
+
+
+# New tests for _balance_requests_across_ranks method
+def test_balance_requests_across_ranks_empty_requests(attention_dp_queue):
+    """Test _balance_requests_across_ranks with empty requests list."""
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    all_ranks_num_active_requests = [2, 1, 3, 0]
+    all_ranks_num_active_tokens = [20, 10, 30, 5]
+
+    # Set expected_num_active_requests for testing
+    attention_dp_queue.expected_num_active_requests = 3
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        [], all_ranks_new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
+
+    # Should return the original structure unchanged
+    assert result == all_ranks_new_requests
+    for rank in range(4):
+        assert len(result[rank]) == 0
+
+
+def test_balance_requests_across_ranks_single_request(attention_dp_queue):
+    """Test _balance_requests_across_ranks with a single request."""
+    req = RequestQueueItem(
+        1, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req.request.input_token_ids = [1, 2, 3, 4, 5]  # 5 tokens
+
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    all_ranks_num_active_requests = [1, 2, 0, 1]  # Rank 2 has lowest count
+    all_ranks_num_active_tokens = [10, 20, 5, 15]
+
+    # Set expected_num_active_requests for testing
+    attention_dp_queue.expected_num_active_requests = 2
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        [req], all_ranks_new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
+
+    # Request should be assigned to rank 2 (lowest token load and active count)
+    assert len(result[0]) == 0
+    assert len(result[1]) == 0
+    assert len(result[2]) == 1
+    assert len(result[3]) == 0
+    assert result[2][0] == req
+
+
+def test_balance_requests_across_ranks_multiple_requests(attention_dp_queue):
+    """Test _balance_requests_across_ranks with multiple requests."""
+    # Create requests with different token counts
+    req1 = RequestQueueItem(
+        1, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req1.request.input_token_ids = [1, 2, 3]  # 3 tokens
+
+    req2 = RequestQueueItem(
+        2, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req2.request.input_token_ids = [1, 2, 3, 4, 5, 6]  # 6 tokens
+
+    req3 = RequestQueueItem(
+        3, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req3.request.input_token_ids = [1, 2]  # 2 tokens
+
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    all_ranks_num_active_requests = [0, 1, 2, 1]
+    all_ranks_num_active_tokens = [5, 15, 25, 10]
+
+    # Set expected_num_active_requests for testing
+    attention_dp_queue.expected_num_active_requests = 2
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        [req1, req2, req3], all_ranks_new_requests,
+        all_ranks_num_active_requests, all_ranks_num_active_tokens)
+
+    # Requests are sorted by token count (descending), then each goes to the
+    # rank whose heap entry has the lowest total token count:
+    # req2 (6 tokens) -> rank 0 (5 tokens, 0 active) -> 11 tokens, 1 active
+    # req1 (3 tokens) -> rank 3 (10 tokens, 1 active) -> 13 tokens, 2 active
+    # req3 (2 tokens) -> rank 0 (11 tokens, 1 active) -> 13 tokens, 2 active
+    # Rank 1 (15 tokens) and rank 2 (25 tokens) never reach the top of the heap
+    assert len(result[0]) == 2  # req2 and req3
+    assert len(result[1]) == 0  # 15 starting tokens, never the minimum
+    assert len(result[2]) == 0  # at capacity
+    assert len(result[3]) == 1  # req1
+
+    # Verify the requests are assigned correctly
+    assert result[0][0] == req2  # First request (highest token count)
+    assert result[0][1] == req3  # Second request
+    assert result[3][0] == req1
+
+
+def test_balance_requests_across_ranks_capacity_limits(attention_dp_queue):
+    """Test _balance_requests_across_ranks respects capacity limits."""
+    # Create multiple requests
+    requests = []
+    for i in range(4):
+        req = RequestQueueItem(
+            i,
+            create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+        req.request.input_token_ids = [1] * (i + 1)  # Variable token counts
+        requests.append(req)
+
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    all_ranks_num_active_requests = [1, 1, 1, 1]  # All ranks start with 1
+    all_ranks_num_active_tokens = [10, 10, 10, 10]
+
+    # Set expected_num_active_requests to limit capacity
+    attention_dp_queue.expected_num_active_requests = 2
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        requests, all_ranks_new_requests, all_ranks_num_active_requests,
+        all_ranks_num_active_tokens)
+
+    # Each rank can only take 1 more request (1 + 1 = 2, which equals
+    # expected_num_active_requests)
+    total_assigned = sum(
+        len(rank_requests) for rank_requests in result.values())
+    assert total_assigned == 4  # 4 ranks × 1 additional request each
+
+    # Verify no rank exceeds capacity
+    for rank in range(4):
+        assert len(result[rank]) <= 1
+
+
+def test_balance_requests_across_ranks_heap_ordering(attention_dp_queue):
+    """Test that _balance_requests_across_ranks uses heap ordering correctly."""
+    # Create requests with same token count to test heap ordering
+    req1 = RequestQueueItem(
+        1, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req1.request.input_token_ids = [1, 2, 3]  # 3 tokens
+
+    req2 = RequestQueueItem(
+        2, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req2.request.input_token_ids = [1, 2, 3]  # 3 tokens
+
+    req3 = RequestQueueItem(
+        3, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req3.request.input_token_ids = [1, 2, 3]  # 3 tokens
+
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    # Rank 0 carries the highest token load, so it should get requests last
+    all_ranks_num_active_requests = [3, 1, 0, 2]
+    all_ranks_num_active_tokens = [30, 10, 5, 20]
+
+    # Set expected_num_active_requests for testing
+    attention_dp_queue.expected_num_active_requests = 4
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        [req1, req2, req3], all_ranks_new_requests,
+        all_ranks_num_active_requests, all_ranks_num_active_tokens)
+
+    # All requests are the same size, so assignment follows each rank's
+    # existing token load (lowest first):
+    # Rank 2 (5 tokens) -> gets req1 (8 tokens) and req2 (11 tokens)
+    # Rank 1 (10 tokens) -> gets req3 (13 tokens)
+    # Rank 3 (20 tokens) and rank 0 (30 tokens) never reach the top of the heap
+    assert len(result[0]) == 0
+    assert len(result[1]) == 1  # req3
+    assert len(result[2]) == 2  # req1 and req2
+    assert len(result[3]) == 0
+
+    # Verify the requests are assigned correctly
+    assert result[1][0] == req3  # Third request
+    assert result[2][0] == req1  # First request
+    assert result[2][1] == req2  # Second request
+
+
+def test_balance_requests_across_ranks_token_count_sorting(attention_dp_queue):
+    """Test that requests are sorted by token count before distribution."""
+    # Create requests with different token counts
+    req1 = RequestQueueItem(
+        1, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req1.request.input_token_ids = [1]  # 1 token (smallest)
+
+    req2 = RequestQueueItem(
+        2, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req2.request.input_token_ids = [1, 2, 3, 4, 5]  # 5 tokens (largest)
+
+    req3 = RequestQueueItem(
+        3, create_mock_request_with_py_schedule_params(attention_dp_rank=None))
+    req3.request.input_token_ids = [1, 2, 3]  # 3 tokens (medium)
+
+    all_ranks_new_requests = {0: [], 1: [], 2: [], 3: []}
+    all_ranks_num_active_requests = [0, 0, 0, 0]  # All ranks start empty
+    all_ranks_num_active_tokens = [5, 5, 5, 5]
+
+    # Set expected_num_active_requests for testing
+    attention_dp_queue.expected_num_active_requests = 2
+
+    result = attention_dp_queue._balance_requests_across_ranks(
+        [req1, req2, req3], all_ranks_new_requests,
+        all_ranks_num_active_requests, all_ranks_num_active_tokens)
+
+    # Requests are sorted by token count (descending) before distribution;
+    # with equal token loads the heap breaks ties by active count, then rank:
+    # req2 (5 tokens) -> rank 0
+    # req3 (3 tokens) -> rank 1
+    # req1 (1 token) -> rank 2
+    assert len(result[0]) == 1  # req2 (highest token count)
+    assert len(result[1]) == 1  # req3
+    assert len(result[2]) == 1  # req1 (lowest token count)
+    assert len(result[3]) == 0
+
+    # Verify the requests are assigned correctly
+    assert result[0][0] == req2
+    assert result[1][0] == req3
+    assert result[2][0] == req1
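
A final note on the last two tests: HeapVal is a namedtuple, so heap ordering compares its fields left to right, num_tokens first, then num_requests, then rank. A tiny standalone check of that tie-breaking, assuming the field order shown in the diff above:

import heapq
from collections import namedtuple

HeapVal = namedtuple('HeapVal',
                     ['num_tokens', 'num_requests', 'rank', 'request_list'])

# With equal token loads and active counts, ordering falls through to rank,
# which is why the token-count-sorting test assigns ranks 0, 1, 2 in order.
heap = [HeapVal(5, 0, rank, []) for rank in range(4)]
heapq.heapify(heap)
assert [heapq.heappop(heap).rank for _ in range(4)] == [0, 1, 2, 3]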