From 21dbd163a7fdd98ea8af7ca895f46c0dd6205b7f Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Tue, 13 May 2025 19:40:12 -0700 Subject: [PATCH] [TRTLLM-5188] fix: [AutoDeploy] unwaive AD build test (#4273) * unwaive small build test Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com> * unwaive multigpu/integration tests Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com> * fix for torch.compile+flashinfer attention Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com> --------- Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com> --- .../auto_deploy/custom_ops/flashinfer_attention.py | 9 +++++---- .../_torch/auto_deploy/integration/test_ad_build.py | 1 - .../unit/multigpu/test_ad_build_small_multi.py | 1 - .../unit/singlegpu/test_ad_build_small_single.py | 1 - 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 8ff30f6502..b24be0193a 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -210,20 +210,21 @@ def prepare_flashinfer_metadata( ) +# TODO: Move the truncation of seq_len out of this custom op +# As SequenceInfo._get_sanitized_num_sequences could break in fake mode @prepare_flashinfer_metadata.register_fake def prepare_flashinfer_metadata_fake( input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size ): + seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device) - batch_indices = torch.empty_like(cache_loc) - positions = torch.empty_like(cache_loc) return ( qo_indptr, # qo_indptr torch.empty_like(qo_indptr), # paged_kv_indptr torch.empty_like(cache_loc), # paged_kv_indices 
torch.empty_like(seq_len), # paged_kv_last_page_len - batch_indices, # batch_indices - positions, # positions + torch.empty_like(seq_len), # batch_indices + torch.empty_like(seq_len), # positions ) diff --git a/tests/unittest/_torch/auto_deploy/integration/test_ad_build.py b/tests/unittest/_torch/auto_deploy/integration/test_ad_build.py index 51af8df700..2c6cc9755d 100644 --- a/tests/unittest/_torch/auto_deploy/integration/test_ad_build.py +++ b/tests/unittest/_torch/auto_deploy/integration/test_ad_build.py @@ -123,7 +123,6 @@ from utils.llm_data import llm_models_root ], ) def test_build_ad(world_size: Optional[int], config: Dict): - pytest.skip("https://nvbugs/5271004") simple_config = SimpleConfig(**config) simple_config.world_size = world_size main(simple_config) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py index 80f002d9f2..709cec1dc9 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py @@ -57,7 +57,6 @@ from utils.llm_data import llm_models_root ], ) def test_build_ad(world_size: Optional[int], config: Dict): - pytest.skip("https://nvbugs/5271004") simple_config = SimpleConfig(**config) simple_config.world_size = world_size main(simple_config) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py index 28390fbc84..968e5f1dcb 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py @@ -71,7 +71,6 @@ from utils.llm_data import llm_models_root ], ) def test_build_ad(world_size: Optional[int], config: Dict): - pytest.skip("https://nvbugs/5271004") simple_config = SimpleConfig(**config) 
simple_config.world_size = world_size main(simple_config)