mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[TRTLLM-5188] fix: [AutoDeploy] unwaive AD build test (#4273)
* Unwaive the small build test. * Unwaive the multi-GPU/integration tests. * Fix torch.compile + flashinfer attention. (Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com>)
This commit is contained in:
parent
23b9705bf4
commit
21dbd163a7
@ -210,20 +210,21 @@ def prepare_flashinfer_metadata(
|
||||
)
|
||||
|
||||
|
||||
# TODO: Move the truncation of seq_len out of this custom op
# As SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_flashinfer_metadata.register_fake
def prepare_flashinfer_metadata_fake(
    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
):
    """Fake (meta) implementation of ``prepare_flashinfer_metadata``.

    Registered via ``register_fake`` so torch.compile / fake-tensor tracing can
    infer output shapes and dtypes without running the real flashinfer planning
    logic. It allocates empty tensors whose shapes/dtypes mirror the real op's
    outputs; the *values* are meaningless and must never be read.

    NOTE(review): the returned tuple's order and arity must stay in lockstep
    with the real custom op's return — confirm against the actual
    ``prepare_flashinfer_metadata`` implementation before changing anything here.
    """
    # Sanitize seq_len the same way the real op does so len(seq_len) — and thus
    # the qo_indptr length below — matches the real op's output shape.
    seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
    # qo_indptr has one entry per sequence plus a leading 0 (CSR-style offsets).
    qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
    batch_indices = torch.empty_like(cache_loc)
    positions = torch.empty_like(cache_loc)
    return (
        qo_indptr,  # qo_indptr
        torch.empty_like(qo_indptr),  # paged_kv_indptr
        torch.empty_like(cache_loc),  # paged_kv_indices
        torch.empty_like(seq_len),  # paged_kv_last_page_len
        batch_indices,  # batch_indices
        positions,  # positions
        # NOTE(review): the two labels below duplicate batch_indices/positions
        # above, and the shapes differ (seq_len-like vs cache_loc-like) — the
        # labels look copy-pasted; confirm the real op's names for these slots.
        torch.empty_like(seq_len),  # batch_indices
        torch.empty_like(seq_len),  # positions
    )
|
||||
|
||||
|
||||
|
||||
@ -123,7 +123,6 @@ from utils.llm_data import llm_models_root
|
||||
],
|
||||
)
|
||||
def test_build_ad(world_size: Optional[int], config: Dict):
    """Build-only AutoDeploy smoke test over the parametrized configs."""
    # Waived — tracked by the nvbug below; drop the skip once it is resolved.
    pytest.skip("https://nvbugs/5271004")
    cfg = SimpleConfig(**config)
    cfg.world_size = world_size
    main(cfg)
|
||||
|
||||
@ -57,7 +57,6 @@ from utils.llm_data import llm_models_root
|
||||
],
|
||||
)
|
||||
def test_build_ad(world_size: Optional[int], config: Dict):
    """Build-only AutoDeploy smoke test over the parametrized configs."""
    # Waived — tracked by the nvbug below; drop the skip once it is resolved.
    pytest.skip("https://nvbugs/5271004")
    cfg = SimpleConfig(**config)
    cfg.world_size = world_size
    main(cfg)
|
||||
|
||||
@ -71,7 +71,6 @@ from utils.llm_data import llm_models_root
|
||||
],
|
||||
)
|
||||
def test_build_ad(world_size: Optional[int], config: Dict):
    """Build-only AutoDeploy smoke test over the parametrized configs."""
    # Waived — tracked by the nvbug below; drop the skip once it is resolved.
    pytest.skip("https://nvbugs/5271004")
    cfg = SimpleConfig(**config)
    cfg.world_size = world_size
    main(cfg)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user