Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[#10696][fix] AutoDeploy: prevent torch.export from specializing the batch dimension when max_batch_size=1 (#10697)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
commit a11f0dbd61 (parent 0af1a0e478)
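
The root cause is torch.export's 0/1 specialization: when an example input carries a dimension of size 1, the exporter treats that size as a compile-time constant rather than a symbol, so a model exported with a batch=1 example input cannot serve other batch sizes. Below is a minimal sketch of the behavior, assuming PyTorch 2.x with the torch.export dynamic_shapes API; the Toy module and the Dim bounds are illustrative and not part of this commit:

import torch
from torch.export import Dim, export


class Toy(torch.nn.Module):
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return input_ids * 2


# Illustrative dynamic-shape spec; the bounds are arbitrary for this sketch.
batch = Dim("batch", min=2, max=8)

# An example input with batch == 1 gets specialized: depending on the
# PyTorch version, export either raises a constraint error or bakes the
# batch dimension in as the constant 1.
try:
    export(Toy(), (torch.ones(1, 4, dtype=torch.int),),
           dynamic_shapes={"input_ids": {0: batch}})
except Exception as e:
    print(f"batch=1 example input: {type(e).__name__}")

# An example input with batch >= 2 keeps the dimension symbolic, which is
# what the fix below guarantees for SequenceInfo's example inputs.
ep = export(Toy(), (torch.ones(2, 4, dtype=torch.int),),
            dynamic_shapes={"input_ids": {0: batch}})
print(ep.graph)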
@@ -725,7 +725,9 @@ class SequenceInfo:
         """Set an example sequence useful for testing and export purposes without cache history."""
         # use a best guess default for input_ids if not provided
         if input_ids is None:
-            bs, seq_len = min(2, self.max_batch_size), min(4, self.max_seq_len)
+            # Use batch_size >= 2 for export to prevent torch.export from specializing
+            # the batch dimension when max_batch_size=1 (dimension value 1 triggers static optimization)
+            bs, seq_len = max(2, min(2, self.max_batch_size)), min(4, self.max_seq_len)
             input_ids = torch.ones(bs, seq_len, dtype=torch.int).tolist()

         # figure out page assignments
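Note that the new expression is constant: min(2, max_batch_size) is at most 2, so max(2, min(2, max_batch_size)) evaluates to 2 for every max_batch_size >= 1, which the second test added below asserts explicitly. A quick check, illustrative and not from the commit:

# min(2, mbs) <= 2, so wrapping it in max(2, ...) pins the result at 2.
for mbs in (1, 2, 32):
    assert max(2, min(2, mbs)) == 2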
@@ -230,3 +230,49 @@ def test_attention_backend_page_size_logic(backend, expected_attn_page_size):
         transforms={"insert_cached_attention": {"stage": "cache_init", "backend": backend}},
     )
     assert args.attn_page_size == expected_attn_page_size
+
+
+class TestSequenceInfoExampleBatchSize:
+    """Test that SequenceInfo generates proper example batch sizes for export."""
+
+    def test_example_batch_size_at_least_2_when_max_batch_size_1(self):
+        """Test that example batch size is at least 2 even when max_batch_size=1.
+
+        This is critical because torch.export specializes dimensions when the
+        example input has a dimension value of 1, breaking dynamic batching.
+        """
+        from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
+
+        seq_info = SequenceInfo(
+            max_batch_size=1,
+            max_seq_len=128,
+            max_num_tokens=128,
+            page_size=64,
+        )
+
+        # Set example sequence (this is what's used during export)
+        seq_info.set_example_sequence()
+
+        # The example batch size should be at least 2 to prevent torch.export
+        # from specializing the batch dimension
+        assert len(seq_info.named_args["input_ids"]) >= 2, (
+            f"Example batch size should be >= 2 for export, got {len(seq_info.named_args['input_ids'])}"
+        )
+
+    def test_example_batch_size_normal_max_batch_size(self):
+        """Test example batch size with normal max_batch_size."""
+        from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
+
+        seq_info = SequenceInfo(
+            max_batch_size=32,
+            max_seq_len=128,
+            max_num_tokens=128,
+            page_size=64,
+        )
+
+        seq_info.set_example_sequence()
+
+        # With larger max_batch_size, example should still be 2
+        assert len(seq_info.named_args["input_ids"]) == 2, (
+            f"Expected example batch size of 2, got {len(seq_info.named_args['input_ids'])}"
+        )