From 3e7703676829a2698d8f33027c64296e5135832b Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Thu, 4 Jun 2026 09:35:00 -0500 Subject: [PATCH] [ROCm][CI] Specifying time outs for the lm eval models (#44255) Signed-off-by: Andreas Karatzas --- .../configs/DeepSeek-V2-Lite-Instruct-FP8.yaml | 1 + tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml | 1 + tests/evals/gsm8k/gsm8k_eval.py | 8 ++++---- tests/evals/gsm8k/test_gsm8k_correctness.py | 13 +++++++++++++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml b/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml index 72fa7e8a38c..dde67727bc6 100644 --- a/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml +++ b/tests/evals/gsm8k/configs/DeepSeek-V2-Lite-Instruct-FP8.yaml @@ -2,4 +2,5 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8" accuracy_threshold: 0.72 num_questions: 1319 num_fewshot: 5 +rocm_request_timeout_seconds: 1800 server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml index 4a1b1948aca..027b4ba5622 100644 --- a/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml +++ b/tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml @@ -2,4 +2,5 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16" accuracy_threshold: 0.45 num_questions: 1319 num_fewshot: 5 +rocm_request_timeout_seconds: 1800 server_args: "--enforce-eager --max-model-len 4096" diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py index 647c149ef5f..ff0718cd2aa 100644 --- a/tests/evals/gsm8k/gsm8k_eval.py +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -106,7 +106,7 @@ async def call_vllm_api( completion_tokens = result.get("usage", {}).get("completion_tokens", 0) return text, completion_tokens except Exception as e: - print(f"Error calling vLLM API: {e}") + print(f"Error calling vLLM API ({type(e).__name__}): {e}") return "", 0 @@ -177,6 +177,7 @@ def evaluate_gsm8k( port: int = 8000, temperature: float = 0.0, seed: int | None = 42, + request_timeout_seconds: float = 600, ) -> dict[str, float | int]: """ Evaluate GSM8K accuracy using vLLM serve endpoint. @@ -205,9 +206,8 @@ def evaluate_gsm8k( output_tokens[i] = tokens return answer, tokens - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=600) - ) as session: + timeout = aiohttp.ClientTimeout(total=request_timeout_seconds) + async with aiohttp.ClientSession(timeout=timeout) as session: tasks = [get_answer(session, i) for i in range(num_questions)] await tqdm.gather(*tasks, desc="Evaluating") diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py index 57513e18aba..e7a254e760f 100644 --- a/tests/evals/gsm8k/test_gsm8k_correctness.py +++ b/tests/evals/gsm8k/test_gsm8k_correctness.py @@ -39,11 +39,18 @@ def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict: host = f"http://{host}" # Run GSM8K evaluation + request_timeout_seconds = eval_config.get("request_timeout_seconds", 600) + if current_platform.is_rocm(): + request_timeout_seconds = eval_config.get( + "rocm_request_timeout_seconds", request_timeout_seconds + ) + results = evaluate_gsm8k( num_questions=eval_config["num_questions"], num_shots=eval_config["num_fewshot"], host=host, port=port, + request_timeout_seconds=request_timeout_seconds, ) return results @@ -90,6 +97,12 @@ def test_gsm8k_correctness(config_filename): print(f"Expected metric threshold: {eval_config['accuracy_threshold']}") print(f"Number of questions: {eval_config['num_questions']}") print(f"Number of few-shot examples: {eval_config['num_fewshot']}") + request_timeout_seconds = eval_config.get("request_timeout_seconds", 600) + if current_platform.is_rocm(): + request_timeout_seconds = eval_config.get( + "rocm_request_timeout_seconds", request_timeout_seconds + ) + print(f"Request timeout: {request_timeout_seconds}s") print(f"Server args: {' '.join(server_args)}") print(f"Environment variables: {env_dict}")