diff --git a/tests/unittest/_torch/test_attention.py b/tests/unittest/_torch/test_attention.py
index ab0f1e6a7a..76782e60a9 100644
--- a/tests/unittest/_torch/test_attention.py
+++ b/tests/unittest/_torch/test_attention.py
@@ -547,11 +547,11 @@ def generate_causal_mask(seq_lens, qo_lens, batch_size, dtype):
 @pytest.mark.parametrize("s", [
-    PagedScenario(num_layers=32, num_generations=5),
-    PagedScenario(num_layers=32, num_generations=5, kv_len=64, causal=False),
+    PagedScenario(num_layers=4, num_generations=5),
+    PagedScenario(num_layers=4, num_generations=5, kv_len=64, causal=False),
     PagedScenario(
-        num_layers=32, num_generations=5, kvcache_dtype=torch.float8_e4m3fn),
-    PagedScenario(num_layers=32,
+        num_layers=4, num_generations=5, kvcache_dtype=torch.float8_e4m3fn),
+    PagedScenario(num_layers=4,
                   num_generations=5,
                   kv_len=64,
                   causal=False,
@@ -705,8 +705,3 @@ def test_attention_backend_ifb(s: PagedScenario):
     del ref_kv_cache
     del vanilla_kv_cache
     torch.cuda.empty_cache()
-
-
-if __name__ == "__main__":
-    test_attention_backend(Scenario(num_layers=1))
-    # test_attention_backend(Scenario(num_layers=1, qo_len=32, kv_len=32, causal=False))
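
With the ad-hoc `if __name__ == "__main__"` entry point removed, the file is exercised only through the pytest runner. A minimal sketch of an equivalent programmatic invocation is below; the `-k` filter on `test_attention_backend` is illustrative and assumes the repository-root working directory, not something specified by this diff:

```python
# Sketch: run the parametrized attention-backend tests via pytest,
# replacing the deleted manual __main__ entry point.
import pytest

if __name__ == "__main__":
    # "-k test_attention_backend" selects both test_attention_backend and
    # test_attention_backend_ifb by substring match (illustrative filter).
    raise SystemExit(
        pytest.main([
            "tests/unittest/_torch/test_attention.py",
            "-k", "test_attention_backend",
        ]))
```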