mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm][CI] Specifying time outs for the lm eval models (#44255)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -2,4 +2,5 @@ model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
|
||||
accuracy_threshold: 0.72
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
rocm_request_timeout_seconds: 1800
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@@ -2,4 +2,5 @@ model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
|
||||
accuracy_threshold: 0.45
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
rocm_request_timeout_seconds: 1800
|
||||
server_args: "--enforce-eager --max-model-len 4096"
|
||||
|
||||
@@ -106,7 +106,7 @@ async def call_vllm_api(
|
||||
completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
|
||||
return text, completion_tokens
|
||||
except Exception as e:
|
||||
print(f"Error calling vLLM API: {e}")
|
||||
print(f"Error calling vLLM API ({type(e).__name__}): {e}")
|
||||
return "", 0
|
||||
|
||||
|
||||
@@ -177,6 +177,7 @@ def evaluate_gsm8k(
|
||||
port: int = 8000,
|
||||
temperature: float = 0.0,
|
||||
seed: int | None = 42,
|
||||
request_timeout_seconds: float = 600,
|
||||
) -> dict[str, float | int]:
|
||||
"""
|
||||
Evaluate GSM8K accuracy using vLLM serve endpoint.
|
||||
@@ -205,9 +206,8 @@ def evaluate_gsm8k(
|
||||
output_tokens[i] = tokens
|
||||
return answer, tokens
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=600)
|
||||
) as session:
|
||||
timeout = aiohttp.ClientTimeout(total=request_timeout_seconds)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
tasks = [get_answer(session, i) for i in range(num_questions)]
|
||||
await tqdm.gather(*tasks, desc="Evaluating")
|
||||
|
||||
|
||||
@@ -39,11 +39,18 @@ def run_gsm8k_eval(eval_config: dict, server_url: str) -> dict:
|
||||
host = f"http://{host}"
|
||||
|
||||
# Run GSM8K evaluation
|
||||
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
|
||||
if current_platform.is_rocm():
|
||||
request_timeout_seconds = eval_config.get(
|
||||
"rocm_request_timeout_seconds", request_timeout_seconds
|
||||
)
|
||||
|
||||
results = evaluate_gsm8k(
|
||||
num_questions=eval_config["num_questions"],
|
||||
num_shots=eval_config["num_fewshot"],
|
||||
host=host,
|
||||
port=port,
|
||||
request_timeout_seconds=request_timeout_seconds,
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -90,6 +97,12 @@ def test_gsm8k_correctness(config_filename):
|
||||
print(f"Expected metric threshold: {eval_config['accuracy_threshold']}")
|
||||
print(f"Number of questions: {eval_config['num_questions']}")
|
||||
print(f"Number of few-shot examples: {eval_config['num_fewshot']}")
|
||||
request_timeout_seconds = eval_config.get("request_timeout_seconds", 600)
|
||||
if current_platform.is_rocm():
|
||||
request_timeout_seconds = eval_config.get(
|
||||
"rocm_request_timeout_seconds", request_timeout_seconds
|
||||
)
|
||||
print(f"Request timeout: {request_timeout_seconds}s")
|
||||
print(f"Server args: {' '.join(server_args)}")
|
||||
print(f"Environment variables: {env_dict}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user