From a11f0dbd6118a71b0d15ea67510eb4ec9493c913 Mon Sep 17 00:00:00 2001
From: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Date: Sun, 18 Jan 2026 10:42:49 +0200
Subject: [PATCH] [#10696][fix] AutoDeploy prevent torch.export from
 specializing batch dimension when max_batch_size=1 (#10697)

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
---
 .../custom_ops/attention_interface.py      |  4 +-
 .../unit/singlegpu/shim/test_llm_config.py | 46 +++++++++++++++++++
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
index ea583d84aa..3131d87cf8 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -725,7 +725,9 @@ class SequenceInfo:
         """Set an example sequence useful for testing and export purposes without cache history."""
         # use a best guess default for input_ids if not provided
         if input_ids is None:
-            bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len)
+            # Use batch_size >= 2 for export to prevent torch.export from specializing
+            # the batch dimension when max_batch_size=1 (dimension value 1 triggers static optimization)
+            bs, seq_len = max(2, min(2, self.max_batch_size)), min(4, self.max_seq_len)
             input_ids = torch.ones(bs, seq_len, dtype=torch.int).tolist()
 
         # figure out page assignments
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
index 04fa1f91fb..d797e07d3a 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_llm_config.py
@@ -230,3 +230,49 @@ def test_attention_backend_page_size_logic(backend, expected_attn_page_size):
         transforms={"insert_cached_attention": {"stage": "cache_init", "backend": backend}},
     )
     assert args.attn_page_size == expected_attn_page_size
+
+
+class TestSequenceInfoExampleBatchSize:
+    """Test that SequenceInfo generates proper example batch sizes for export."""
+
+    def test_example_batch_size_at_least_2_when_max_batch_size_1(self):
+        """Test that example batch size is at least 2 even when max_batch_size=1.
+
+        This is critical because torch.export specializes dimensions when the
+        example input has a dimension value of 1, breaking dynamic batching.
+ """ + from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo + + seq_info = SequenceInfo( + max_batch_size=1, + max_seq_len=128, + max_num_tokens=128, + page_size=64, + ) + + # Set example sequence (this is what's used during export) + seq_info.set_example_sequence() + + # The example batch size should be at least 2 to prevent torch.export + # from specializing the batch dimension + assert len(seq_info.named_args["input_ids"]) >= 2, ( + f"Example batch size should be >= 2 for export, got {len(seq_info.named_args['input_ids'])}" + ) + + def test_example_batch_size_normal_max_batch_size(self): + """Test example batch size with normal max_batch_size.""" + from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo + + seq_info = SequenceInfo( + max_batch_size=32, + max_seq_len=128, + max_num_tokens=128, + page_size=64, + ) + + seq_info.set_example_sequence() + + # With larger max_batch_size, example should still be 2 + assert len(seq_info.named_args["input_ids"]) == 2, ( + f"Expected example batch size of 2, got {len(seq_info.named_args['input_ids'])}" + )