diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 973765b2f0..6a70628f90 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -602,9 +602,9 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness): @parametrize_with_ids("gen_tp", [1, 2]) @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"]) def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset): - if ctx_pp * gen_tp * 2 > get_device_count(): + if ctx_pp + gen_tp > get_device_count(): pytest.skip( - f"Not enough devices for ctx_pp={ctx_pp}*gen_tp={gen_tp} test") + f"Not enough devices for ctx_pp={ctx_pp}+gen_tp={gen_tp} test") return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1, gen_tp, 1, 1, [get_accuracy_task(testset)])