[#10688][fix] AutoDeploy Fix CUDA graph batch sizes exceeding max_batch_size (#10687)

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
This commit is contained in:
Eran Geva 2026-01-18 20:31:01 +02:00 committed by GitHub
parent b64052539d
commit 4d2916d683
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 106 additions and 1 deletions

View File

@ -280,7 +280,8 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
# if not set, use heuristic
if self.cuda_graph_batch_sizes is None:
cg_bs = {1, self.max_batch_size}
cg_bs.update(range(1, 128 + 1, 16))
# Only add batch sizes up to max_batch_size
cg_bs.update(range(1, min(128, self.max_batch_size) + 1, 16))
cg_bs.update(range(128, self.max_batch_size + 1, 128))
else:
cg_bs = [b for b in self.cuda_graph_batch_sizes if b <= self.max_batch_size]

View File

@ -232,6 +232,110 @@ def test_attention_backend_page_size_logic(backend, expected_attn_page_size):
assert args.attn_page_size == expected_attn_page_size
# ================================
# CUDA Graph Batch Sizes Tests
# ================================
class TestCudaGraphBatchSizesHeuristic:
    """Check that resolved cuda_graph_batch_sizes never exceed max_batch_size."""

    def test_small_max_batch_size_caps_heuristic(self):
        """A small max_batch_size (4) must cap the heuristic batch sizes.

        Heuristic candidates such as 17, 33, 49, 65, 81, 97 and 113 are larger
        than max_batch_size in this setup and therefore must not show up.
        """
        llm_args = LlmArgs(model="test-model", max_batch_size=4)
        sizes = llm_args.cuda_graph_batch_sizes

        # No entry may be larger than max_batch_size.
        assert not any(size > 4 for size in sizes), (
            f"Expected all batch sizes <= 4, got {sizes}"
        )

        # Both 1 and max_batch_size itself are expected.
        for expected in (1, 4):
            assert expected in sizes

        # Heuristic candidates above max_batch_size must be missing.
        for unexpected in (17, 113):
            assert unexpected not in sizes

    def test_medium_max_batch_size_caps_heuristic(self):
        """A medium max_batch_size (64) keeps only the candidates up to 64."""
        llm_args = LlmArgs(model="test-model", max_batch_size=64)
        sizes = llm_args.cuda_graph_batch_sizes

        # No entry may be larger than max_batch_size.
        assert not any(size > 64 for size in sizes), (
            f"Expected all batch sizes <= 64, got {sizes}"
        )

        # Candidates up to 64, plus max_batch_size itself, are expected.
        for expected in (1, 17, 33, 49, 64):
            assert expected in sizes

        # Candidates above 64 must be missing.
        for unexpected in (65, 81):
            assert unexpected not in sizes

    def test_large_max_batch_size_includes_all_heuristic_values(self):
        """A large max_batch_size (256) keeps every heuristic candidate."""
        llm_args = LlmArgs(model="test-model", max_batch_size=256)
        sizes = llm_args.cuda_graph_batch_sizes

        # No entry may be larger than max_batch_size.
        assert not any(size > 256 for size in sizes), (
            f"Expected all batch sizes <= 256, got {sizes}"
        )

        # Every value of range(1, 129, 16) is expected.
        for expected in range(1, 128 + 1, 16):
            assert expected in sizes, f"Expected {expected} in batch sizes"

        # The multiples of 128 up to max_batch_size are expected as well.
        for expected in (128, 256):
            assert expected in sizes

    def test_explicit_cuda_graph_batch_sizes_filtered(self):
        """Explicitly requested batch sizes are filtered down to max_batch_size."""
        within_limit = [1, 4, 8, 16]
        above_limit = [32, 64, 128]
        llm_args = LlmArgs(
            model="test-model",
            max_batch_size=16,
            cuda_graph_batch_sizes=within_limit + above_limit,
        )
        sizes = llm_args.cuda_graph_batch_sizes

        # No entry may be larger than max_batch_size.
        assert not any(size > 16 for size in sizes), (
            f"Expected all batch sizes <= 16, got {sizes}"
        )

        # Requested values that fit within max_batch_size are kept.
        for expected in within_limit:
            assert expected in sizes

        # Requested values above max_batch_size are dropped.
        for unexpected in above_limit:
            assert unexpected not in sizes

    def test_batch_sizes_sorted_descending(self):
        """The resolved cuda_graph_batch_sizes are ordered from largest to smallest."""
        llm_args = LlmArgs(model="test-model", max_batch_size=64)
        sizes = llm_args.cuda_graph_batch_sizes

        # Compare against a freshly sorted copy of the same values.
        descending = sorted(sizes, reverse=True)
        assert sizes == descending, (
            f"Expected descending order, got {sizes}"
        )
class TestSequenceInfoExampleBatchSize:
"""Test that SequenceInfo generates proper example batch sizes for export."""