diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py
index 7c46a48df3..f4f05355d8 100644
--- a/tensorrt_llm/_torch/auto_deploy/llm_args.py
+++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -280,7 +280,8 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         # if not set, use heuristic
         if self.cuda_graph_batch_sizes is None:
             cg_bs = {1, self.max_batch_size}
-            cg_bs.update(range(1, 128 + 1, 16))
+            # Only add batch sizes up to max_batch_size
+            cg_bs.update(range(1, min(128, self.max_batch_size) + 1, 16))
             cg_bs.update(range(128, self.max_batch_size + 1, 128))
         else:
             cg_bs = [b for b in self.cuda_graph_batch_sizes if b <= self.max_batch_size]
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
index d797e07d3a..8589253a04 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
@@ -232,6 +232,110 @@ def test_attention_backend_page_size_logic(backend, expected_attn_page_size):
     assert args.attn_page_size == expected_attn_page_size
 
 
+# ================================
+# CUDA Graph Batch Sizes Tests
+# ================================
+
+
+class TestCudaGraphBatchSizesHeuristic:
+    """Test that cuda_graph_batch_sizes heuristic respects max_batch_size."""
+
+    def test_small_max_batch_size_caps_heuristic(self):
+        """Test that heuristic batch sizes are capped at small max_batch_size.
+
+        When max_batch_size is small (e.g., 4), the heuristic should NOT include
+        batch sizes like 17, 33, 49, 65, 81, 97, 113 which exceed max_batch_size.
+        """
+        args = LlmArgs(
+            model="test-model",
+            max_batch_size=4,
+        )
+
+        # All batch sizes should be <= max_batch_size
+        assert all(bs <= 4 for bs in args.cuda_graph_batch_sizes), (
+            f"Expected all batch sizes <= 4, got {args.cuda_graph_batch_sizes}"
+        )
+        # Should include 1 and max_batch_size
+        assert 1 in args.cuda_graph_batch_sizes
+        assert 4 in args.cuda_graph_batch_sizes
+        # Should NOT include heuristic values that exceed max_batch_size
+        assert 17 not in args.cuda_graph_batch_sizes
+        assert 113 not in args.cuda_graph_batch_sizes
+
+    def test_medium_max_batch_size_caps_heuristic(self):
+        """Test heuristic with medium max_batch_size (e.g., 64)."""
+        args = LlmArgs(
+            model="test-model",
+            max_batch_size=64,
+        )
+
+        # All batch sizes should be <= max_batch_size
+        assert all(bs <= 64 for bs in args.cuda_graph_batch_sizes), (
+            f"Expected all batch sizes <= 64, got {args.cuda_graph_batch_sizes}"
+        )
+        # Should include some heuristic values up to 64
+        assert 1 in args.cuda_graph_batch_sizes
+        assert 17 in args.cuda_graph_batch_sizes
+        assert 33 in args.cuda_graph_batch_sizes
+        assert 49 in args.cuda_graph_batch_sizes
+        assert 64 in args.cuda_graph_batch_sizes
+        # Should NOT include values > 64
+        assert 65 not in args.cuda_graph_batch_sizes
+        assert 81 not in args.cuda_graph_batch_sizes
+
+    def test_large_max_batch_size_includes_all_heuristic_values(self):
+        """Test heuristic with large max_batch_size (e.g., 256)."""
+        args = LlmArgs(
+            model="test-model",
+            max_batch_size=256,
+        )
+
+        # All batch sizes should be <= max_batch_size
+        assert all(bs <= 256 for bs in args.cuda_graph_batch_sizes), (
+            f"Expected all batch sizes <= 256, got {args.cuda_graph_batch_sizes}"
+        )
+        # Should include heuristic values from range(1, 129, 16)
+        for bs in [1, 17, 33, 49, 65, 81, 97, 113]:
+            assert bs in args.cuda_graph_batch_sizes, f"Expected {bs} in batch sizes"
+        # Should include 128 from range(128, max_batch_size+1, 128)
+        assert 128 in args.cuda_graph_batch_sizes
+        assert 256 in args.cuda_graph_batch_sizes
+
+    def test_explicit_cuda_graph_batch_sizes_filtered(self):
+        """Test that explicitly provided batch sizes are filtered to max_batch_size."""
+        args = LlmArgs(
+            model="test-model",
+            max_batch_size=16,
+            cuda_graph_batch_sizes=[1, 4, 8, 16, 32, 64, 128],
+        )
+
+        # Should only include values <= max_batch_size
+        assert all(bs <= 16 for bs in args.cuda_graph_batch_sizes), (
+            f"Expected all batch sizes <= 16, got {args.cuda_graph_batch_sizes}"
+        )
+        # Values <= 16 should be present
+        assert 1 in args.cuda_graph_batch_sizes
+        assert 4 in args.cuda_graph_batch_sizes
+        assert 8 in args.cuda_graph_batch_sizes
+        assert 16 in args.cuda_graph_batch_sizes
+        # Values > 16 should be filtered out
+        assert 32 not in args.cuda_graph_batch_sizes
+        assert 64 not in args.cuda_graph_batch_sizes
+        assert 128 not in args.cuda_graph_batch_sizes
+
+    def test_batch_sizes_sorted_descending(self):
+        """Test that cuda_graph_batch_sizes are sorted in descending order."""
+        args = LlmArgs(
+            model="test-model",
+            max_batch_size=64,
+        )
+
+        # Should be sorted in descending order
+        assert args.cuda_graph_batch_sizes == sorted(args.cuda_graph_batch_sizes, reverse=True), (
+            f"Expected descending order, got {args.cuda_graph_batch_sizes}"
+        )
+
+
 class TestSequenceInfoExampleBatchSize:
     """Test that SequenceInfo generates proper example batch sizes for export."""