[#10696][fix] AutoDeploy prevent torch.export from specializing batch dimension when max_batch_size=1 (#10697)

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
This commit is contained in:
Eran Geva 2026-01-18 10:42:49 +02:00 committed by GitHub
parent 0af1a0e478
commit a11f0dbd61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 1 deletions

View File

@@ -725,7 +725,9 @@ class SequenceInfo:
"""Set an example sequence useful for testing and export purposes without cache history."""
# use a best guess default for input_ids if not provided
if input_ids is None:
bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len)
# Use batch_size >= 2 for export to prevent torch.export from specializing
# the batch dimension when max_batch_size=1 (dimension value 1 triggers static optimization)
bs, seq_len = max(2, min(2, self.max_batch_size)), min(4, self.max_seq_len)
input_ids = torch.ones(bs, seq_len, dtype=torch.int).tolist()
# figure out page assignments

View File

@@ -230,3 +230,49 @@ def test_attention_backend_page_size_logic(backend, expected_attn_page_size):
transforms={"insert_cached_attention": {"stage": "cache_init", "backend": backend}},
)
assert args.attn_page_size == expected_attn_page_size
class TestSequenceInfoExampleBatchSize:
    """Test that SequenceInfo generates proper example batch sizes for export."""

    @staticmethod
    def _build_seq_info(max_batch_size):
        # Shared fixture builder: everything except the batch size is held
        # constant so the two tests differ only in the dimension under test.
        # Import is kept local, matching the original function-scope imports.
        from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo

        return SequenceInfo(
            max_batch_size=max_batch_size,
            max_seq_len=128,
            max_num_tokens=128,
            page_size=64,
        )

    def test_example_batch_size_at_least_2_when_max_batch_size_1(self):
        """Test that example batch size is at least 2 even when max_batch_size=1.

        This is critical because torch.export specializes dimensions when the
        example input has a dimension value of 1, breaking dynamic batching.
        """
        info = self._build_seq_info(max_batch_size=1)
        # Set example sequence (this is what's used during export)
        info.set_example_sequence()
        # The example batch size should be at least 2 to prevent torch.export
        # from specializing the batch dimension
        example_bs = len(info.named_args["input_ids"])
        assert example_bs >= 2, (
            f"Example batch size should be >= 2 for export, got {example_bs}"
        )

    def test_example_batch_size_normal_max_batch_size(self):
        """Test example batch size with normal max_batch_size."""
        info = self._build_seq_info(max_batch_size=32)
        info.set_example_sequence()
        # With larger max_batch_size, example should still be 2
        example_bs = len(info.named_args["input_ids"])
        assert example_bs == 2, (
            f"Expected example batch size of 2, got {example_bs}"
        )