[ROCm][CI] Specifying timeouts for the lm eval models (#44255)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-06-04 09:35:00 -05:00
committed by GitHub
parent 6f68ca3e91
commit 3e77036768
4 changed files with 19 additions and 4 deletions
@@ -2,4 +2,5 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
accuracy_threshold: 0.72
num_questions: 1319
num_fewshot: 5
rocm_request_timeout_seconds: 1800
server_args: "--enforce-eager --max-model-len 4096"
@@ -2,4 +2,5 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
accuracy_threshold: 0.45
num_questions: 1319
num_fewshot: 5
rocm_request_timeout_seconds: 1800
server_args: "--enforce-eager --max-model-len 4096"
+4 -4
View File
@@ -106,7 +106,7 @@ async def call_vllm_api(
completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
return text, completion_tokens
except Exception as e:
print(f"Error calling vLLM API: {e}")
print(f"Error calling vLLM API ({type(e).__name__}): {e}")
return "", 0
@@ -177,6 +177,7 @@ def evaluate_gsm8k(
port: int = 8000,
temperature: float = 0.0,
seed: int | None = 42,
request_timeout_seconds: float = 600,
) -> dict[str, float | int]:
"""
Evaluate GSM8K accuracy using vLLM serve endpoint.
@@ -205,9 +206,8 @@ def evaluate_gsm8k(
output_tokens[i] = tokens
return answer, tokens
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=600)
) as session:
timeout = aiohttp.ClientTimeout(total=request_timeout_seconds)
async with aiohttp.ClientSession(timeout=timeout) as session:
tasks = [get_answer(session, i) for i in range(num_questions)]
await tqdm.gather(*tasks, desc="Evaluating")
@@ -39,11 +39,18 @@ def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
host = f"http://{host}"
# Run GSM8K evaluation
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
if current_platform.is_rocm():
request_timeout_seconds = eval_config.get(
"rocm_request_timeout_seconds", request_timeout_seconds
)
results = evaluate_gsm8k(
num_questions=eval_config["num_questions"],
num_shots=eval_config["num_fewshot"],
host=host,
port=port,
request_timeout_seconds=request_timeout_seconds,
)
return results
@@ -90,6 +97,12 @@ def test_gsm8k_correctness(config_filename):
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
print(f"Number of questions: {eval_config['num_questions']}")
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
if current_platform.is_rocm():
request_timeout_seconds = eval_config.get(
"rocm_request_timeout_seconds", request_timeout_seconds
)
print(f"Request timeout: {request_timeout_seconds}s")
print(f"Server args: {' '.join(server_args)}")
print(f"Environment variables: {env_dict}")