Mirror of https://github.com/NVIDIA/TensorRT-LLM.git

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>

parent b64052539d
commit 4d2916d683
@@ -280,7 +280,8 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         # if not set, use heuristic
         if self.cuda_graph_batch_sizes is None:
             cg_bs = {1, self.max_batch_size}
-            cg_bs.update(range(1, 128 + 1, 16))
+            # Only add batch sizes up to max_batch_size
+            cg_bs.update(range(1, min(128, self.max_batch_size) + 1, 16))
             cg_bs.update(range(128, self.max_batch_size + 1, 128))
         else:
             cg_bs = [b for b in self.cuda_graph_batch_sizes if b <= self.max_batch_size]
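To make the effect of the capped heuristic concrete, here is a minimal standalone sketch of the logic above. The helper name is hypothetical, and the final descending sort is an assumption inferred from test_batch_sizes_sorted_descending in the tests below, not shown in this hunk:

def heuristic_cuda_graph_batch_sizes(max_batch_size, explicit=None):
    # Mirrors the diff: seed with 1 and max_batch_size, add the 16-strided
    # range capped at max_batch_size, then the 128-strided sizes.
    if explicit is None:
        cg_bs = {1, max_batch_size}
        cg_bs.update(range(1, min(128, max_batch_size) + 1, 16))
        cg_bs.update(range(128, max_batch_size + 1, 128))
    else:
        # Explicitly provided sizes are filtered to max_batch_size.
        cg_bs = [b for b in explicit if b <= max_batch_size]
    return sorted(cg_bs, reverse=True)  # descending order (assumed, per tests)

print(heuristic_cuda_graph_batch_sizes(4))    # [4, 1]
print(heuristic_cuda_graph_batch_sizes(64))   # [64, 49, 33, 17, 1]
print(heuristic_cuda_graph_batch_sizes(256))  # [256, 128, 113, 97, 81, 65, 49, 33, 17, 1]
print(heuristic_cuda_graph_batch_sizes(16, [1, 4, 8, 16, 32, 64, 128]))  # [16, 8, 4, 1]

Before the fix, range(1, 128 + 1, 16) unconditionally added 17, 33, ..., 113, so a config with max_batch_size=4 would capture CUDA graphs for batch sizes it could never serve.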
@@ -232,6 +232,110 @@ def test_attention_backend_page_size_logic(backend, expected_attn_page_size):
     assert args.attn_page_size == expected_attn_page_size
 
 
+# ================================
+# CUDA Graph Batch Sizes Tests
+# ================================
+
+
+class TestCudaGraphBatchSizesHeuristic:
+    """Test that cuda_graph_batch_sizes heuristic respects max_batch_size."""
+
+    def test_small_max_batch_size_caps_heuristic(self):
+        """Test that heuristic batch sizes are capped at small max_batch_size.
+
+        When max_batch_size is small (e.g., 4), the heuristic should NOT include
+        batch sizes like 17, 33, 49, 65, 81, 97, 113 which exceed max_batch_size.
+        """
+        args = LlmArgs(
+            model="test-model",
+            max_batch_size=4,
+        )
+
+        # All batch sizes should be <= max_batch_size
+        assert all(bs <= 4 for bs in args.cuda_graph_batch_sizes), (
+            f"Expected all batch sizes <= 4, got {args.cuda_graph_batch_sizes}"
+        )
+        # Should include 1 and max_batch_size
+        assert 1 in args.cuda_graph_batch_sizes
+        assert 4 in args.cuda_graph_batch_sizes
+        # Should NOT include heuristic values that exceed max_batch_size
+        assert 17 not in args.cuda_graph_batch_sizes
+        assert 113 not in args.cuda_graph_batch_sizes
+
+    def test_medium_max_batch_size_caps_heuristic(self):
+        """Test heuristic with medium max_batch_size (e.g., 64)."""
+        args = LlmArgs(
+            model="test-model",
+            max_batch_size=64,
+        )
+
+        # All batch sizes should be <= max_batch_size
+        assert all(bs <= 64 for bs in args.cuda_graph_batch_sizes), (
+            f"Expected all batch sizes <= 64, got {args.cuda_graph_batch_sizes}"
+        )
+        # Should include some heuristic values up to 64
+        assert 1 in args.cuda_graph_batch_sizes
+        assert 17 in args.cuda_graph_batch_sizes
+        assert 33 in args.cuda_graph_batch_sizes
+        assert 49 in args.cuda_graph_batch_sizes
+        assert 64 in args.cuda_graph_batch_sizes
+        # Should NOT include values > 64
+        assert 65 not in args.cuda_graph_batch_sizes
+        assert 81 not in args.cuda_graph_batch_sizes
+
+    def test_large_max_batch_size_includes_all_heuristic_values(self):
+        """Test heuristic with large max_batch_size (e.g., 256)."""
+        args = LlmArgs(
+            model="test-model",
+            max_batch_size=256,
+        )
+
+        # All batch sizes should be <= max_batch_size
+        assert all(bs <= 256 for bs in args.cuda_graph_batch_sizes), (
+            f"Expected all batch sizes <= 256, got {args.cuda_graph_batch_sizes}"
+        )
+        # Should include heuristic values from range(1, 129, 16)
+        for bs in [1, 17, 33, 49, 65, 81, 97, 113]:
+            assert bs in args.cuda_graph_batch_sizes, f"Expected {bs} in batch sizes"
+        # Should include 128 from range(128, max_batch_size+1, 128)
+        assert 128 in args.cuda_graph_batch_sizes
+        assert 256 in args.cuda_graph_batch_sizes
+
+    def test_explicit_cuda_graph_batch_sizes_filtered(self):
+        """Test that explicitly provided batch sizes are filtered to max_batch_size."""
+        args = LlmArgs(
+            model="test-model",
+            max_batch_size=16,
+            cuda_graph_batch_sizes=[1, 4, 8, 16, 32, 64, 128],
+        )
+
+        # Should only include values <= max_batch_size
+        assert all(bs <= 16 for bs in args.cuda_graph_batch_sizes), (
+            f"Expected all batch sizes <= 16, got {args.cuda_graph_batch_sizes}"
+        )
+        # Values <= 16 should be present
+        assert 1 in args.cuda_graph_batch_sizes
+        assert 4 in args.cuda_graph_batch_sizes
+        assert 8 in args.cuda_graph_batch_sizes
+        assert 16 in args.cuda_graph_batch_sizes
+        # Values > 16 should be filtered out
+        assert 32 not in args.cuda_graph_batch_sizes
+        assert 64 not in args.cuda_graph_batch_sizes
+        assert 128 not in args.cuda_graph_batch_sizes
+
+    def test_batch_sizes_sorted_descending(self):
+        """Test that cuda_graph_batch_sizes are sorted in descending order."""
+        args = LlmArgs(
+            model="test-model",
+            max_batch_size=64,
+        )
+
+        # Should be sorted in descending order
+        assert args.cuda_graph_batch_sizes == sorted(args.cuda_graph_batch_sizes, reverse=True), (
+            f"Expected descending order, got {args.cuda_graph_batch_sizes}"
+        )
+
+
+class TestSequenceInfoExampleBatchSize:
+    """Test that SequenceInfo generates proper example batch sizes for export."""
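Assuming this file is collected by pytest as usual, the new test class can be run in isolation with pytest's -k filter, e.g. pytest -k TestCudaGraphBatchSizesHeuristic.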