[TRTLLM-5188] fix: [AutoDeploy] unwaive AD build test (#4273)

* unwaive small build test

Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com>

* unwaive mutigpu/integration tests

Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com>

* fix for torch.compile+flashinfer attention

Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com>

---------

Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com>
This commit is contained in:
Fridah-nv 2025-05-13 19:40:12 -07:00 committed by GitHub
parent 23b9705bf4
commit 21dbd163a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 5 additions and 7 deletions

View File

@ -210,20 +210,21 @@ def prepare_flashinfer_metadata(
)
# TODO: Move the truncation of seq_len out of this custom op
# As SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_flashinfer_metadata.register_fake
def prepare_flashinfer_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
):
seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
batch_indices = torch.empty_like(cache_loc)
positions = torch.empty_like(cache_loc)
return (
qo_indptr, # qo_indptr
torch.empty_like(qo_indptr), # paged_kv_indptr
torch.empty_like(cache_loc), # paged_kv_indices
torch.empty_like(seq_len), # paged_kv_last_page_len
batch_indices, # batch_indices
positions, # positions
torch.empty_like(seq_len), # batch_indices
torch.empty_like(seq_len), # positions
)

View File

@ -123,7 +123,6 @@ from utils.llm_data import llm_models_root
],
)
def test_build_ad(world_size: Optional[int], config: Dict):
pytest.skip("https://nvbugs/5271004")
simple_config = SimpleConfig(**config)
simple_config.world_size = world_size
main(simple_config)

View File

@ -57,7 +57,6 @@ from utils.llm_data import llm_models_root
],
)
def test_build_ad(world_size: Optional[int], config: Dict):
pytest.skip("https://nvbugs/5271004")
simple_config = SimpleConfig(**config)
simple_config.world_size = world_size
main(simple_config)

View File

@ -71,7 +71,6 @@ from utils.llm_data import llm_models_root
],
)
def test_build_ad(world_size: Optional[int], config: Dict):
pytest.skip("https://nvbugs/5271004")
simple_config = SimpleConfig(**config)
simple_config.world_size = world_size
main(simple_config)