# MIT License
#
# Copyright (c) 2023 OpenBMB
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# reference: https://github.com/OpenBMB/InfiniteBench/blob/main/src/compute_scores.py

import json
import re
import string
from collections import Counter
from pathlib import Path

from tqdm import tqdm

from .args import parse_args


def normalize_answer(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth) -> tuple[float, float, float]:
    """Token-level F1, precision, and recall between two token sequences."""
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0, 0, 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall


def load_json(fname):
    with open(fname, "r", encoding="utf8") as fin:
        return json.load(fin)


def iter_jsonl(fname, cnt=None):
    """Yield up to `cnt` JSON objects from a JSONL file, skipping empty lines."""
    i = 0
    with open(fname, "r", encoding="utf8") as fin:
        for line in fin:
            if line.strip() == "":  # Skip empty lines
                continue
            if i == cnt:
                break
            yield json.loads(line)
            i += 1


def first_int_match(prediction):
    """Return the first run of digits in `prediction`, or "" if there is none."""
    pred_list = re.split("[^0-9]", prediction)
    pred_value = ""
    for item in pred_list:
        if item != "":
            pred_value = item
            break
    return pred_value


def split_retrieval_answer(pred: str):
    """Split a retrieval answer into words after stripping common separators."""
    for c in ["\n", ":", '"', "'", ".", ",", "?", "!", "{", "}"]:
        pred = pred.replace(c, " ")
    words = pred.split()
    return words


def get_score_one_kv_retrieval(pred, label) -> bool:
    return label in split_retrieval_answer(pred)


def get_score_one_passkey(pred, label) -> bool:
    if isinstance(label, list):
        label = label[0]
    return label == first_int_match(pred)


def get_score_one(pred: str, label: str, task_name: str) -> float:
    """
    Computes the score for one prediction.
    Returns one float (zero or one for boolean values).
    """
    NAME_TO_SCORE_GETTER = {
        # Retrieve
        "kv_retrieval": get_score_one_kv_retrieval,
        "kv_retrieval_prefix": get_score_one_kv_retrieval,
        "kv_retrieval_both": get_score_one_kv_retrieval,
        "passkey": get_score_one_passkey,
    }
    assert task_name in NAME_TO_SCORE_GETTER, f"Invalid task name: {task_name}"
    score = NAME_TO_SCORE_GETTER[task_name](pred, label)
    return float(score)


def get_labels(preds: list) -> list[str]:
    possible_label_keys = ["ground_truth", "label"]
    for label_key in possible_label_keys:
        if label_key in preds[0]:
            return [x.get(label_key, "XXXXXXXXXX") for x in preds]
    raise ValueError(f"Cannot find label in {preds[0]}")


def get_preds(preds: list, data_name: str) -> list[str]:
    pred_strings = []
    possible_pred_keys = ["prediction", "pred"]
    for pred in preds:
        this_pred = "NO PREDICTION"
        for pred_key in possible_pred_keys:
            if pred_key in pred:
                this_pred = pred[pred_key]
                break
        else:
            # The for-else fires only if no prediction key was found.
            raise ValueError(f"Cannot find prediction in {pred}")
        pred_strings.append(this_pred)
    return pred_strings


def get_score(labels: list, preds: list, data_name: str) -> float:
    """
    Computes the average score for a task.
    """
    assert len(labels) == len(preds)
    scores = []
    for label, pred in tqdm(zip(labels, preds), total=len(labels)):
        score = get_score_one(pred, label, data_name)
        scores.append(score)
    return sum(scores) / len(scores)


def load_preds(preds_path):
    """Load prediction records from a JSONL file."""
    assert preds_path.exists(), f"Predictions not found in: {preds_path}"
    print("Loading prediction results from", preds_path)
    return list(iter_jsonl(preds_path))


def compute_scores(preds, data_name: str):
    labels = get_labels(preds)
    preds = get_preds(preds, data_name)
    acc = get_score(labels, preds, data_name)
    return acc


ALL_TASKS = [
    "passkey",
    "kv_retrieval",
]


if __name__ == "__main__":
    arguments = parse_args()
    tasks = [arguments.task]
    for task in tasks:
        preds_path = Path(arguments.preds_file)
        preds = load_preds(preds_path)
        acc = compute_scores(preds, task)
        print(acc)