mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* Update TensorRT-LLM --------- Co-authored-by: IbrahimAmin <ibrahimamin532@gmail.com> Co-authored-by: Fabian Joswig <fjosw@users.noreply.github.com> Co-authored-by: Pzzzzz <hello-cd.plus@hotmail.com> Co-authored-by: CoderHam <hemant@cohere.com> Co-authored-by: Konstantin Lopuhin <kostia.lopuhin@gmail.com>
130 lines
3.7 KiB
Python
130 lines
3.7 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# reference: https://github.com/OpenBMB/InfiniteBench/blob/main/src/eval_utils.py
|
|
|
|
import os
|
|
import sys
|
|
|
|
# Put this file's own directory on the import path so that sibling modules
# can be imported regardless of the working directory the script runs from.
sys.path.append(os.path.dirname(__file__))
|
|
|
|
import json
|
|
from pathlib import Path
|
|
|
|
# Task name -> JSONL file name (joined onto ``data_dir`` by load_data).
DATA_NAME_TO_PATH = dict(
    # Retrieval tasks
    passkey="passkey.jsonl",
    kv_retrieval="kv_retrieval.jsonl",
)

# Task name -> maximum number of new tokens to generate for that task.
DATA_NAME_TO_MAX_NEW_TOKENS = dict(
    passkey=6,
    kv_retrieval=50,
)
|
|
|
|
|
|
def iter_jsonl(fname, cnt=None):
    """Lazily yield parsed records from a JSON Lines file.

    Args:
        fname: path of the ``.jsonl`` file to read.
        cnt: maximum number of records to yield. ``None`` (or any value the
            counter never reaches, e.g. a negative number) means no limit.

    Yields:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    yielded = 0
    # Decode explicitly as UTF-8 (the encoding dump_jsonl/dump_json write)
    # instead of depending on the platform's locale default.
    with open(fname, "r", encoding="utf8") as fin:
        for line in fin:
            if yielded == cnt:
                break
            # Tolerate blank lines (e.g. an extra newline at EOF) rather than
            # failing with a JSONDecodeError on an empty string.
            if not line.strip():
                continue
            yield json.loads(line)
            yielded += 1
|
|
|
|
|
|
def load_json(fname):
    """Load and return the JSON document stored in ``fname``.

    The file is decoded as UTF-8 (the encoding dump_json writes) and is
    closed deterministically by the ``with`` block rather than being left
    for the garbage collector.
    """
    with open(fname, "r", encoding="utf8") as fin:
        return json.load(fin)
|
|
|
|
|
|
def dump_jsonl(data, fname):
    """Write every item of ``data`` to ``fname`` as one JSON Lines record.

    The file is (over)written as UTF-8; non-ASCII characters are emitted
    verbatim instead of as ``\\u`` escapes, and each record ends with a
    newline.
    """
    with open(fname, "w", encoding="utf8") as fout:
        fout.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in data)
|
|
|
|
|
|
def dump_json(data, fname):
    """Write ``data`` to ``fname`` as pretty-printed UTF-8 JSON.

    Output uses a 2-space indent and keeps non-ASCII characters unescaped;
    an existing file is overwritten.
    """
    encode_options = {"indent": 2, "ensure_ascii": False}
    with open(fname, mode="w", encoding="utf8") as handle:
        json.dump(data, handle, **encode_options)
|
|
|
|
|
|
def load_data(data_name: str, data_dir: str = "../data/InfiniteBench/"):
    """Read every record of a task's JSONL file into a list.

    Args:
        data_name: task key looked up in ``DATA_NAME_TO_PATH`` (raises
            ``KeyError`` if unknown).
        data_dir: directory containing the task files. The default is a
            relative path, so it is resolved against the current working
            directory.

    Returns:
        A list of the decoded records, in file order.
    """
    jsonl_path = Path(data_dir) / DATA_NAME_TO_PATH[data_name]
    records = iter_jsonl(jsonl_path)
    return list(records)
|
|
|
|
|
|
def create_prompt(eg: dict, data_name: str, data_dir=None) -> str:
    """
    Create prompt for a given example.

    Args:
        eg: example dict. Must contain ``"input"`` and either ``"context"``
            or ``"content"``. NOTE: ``eg`` is modified in place: a
            ``"content"`` entry is renamed to ``"context"`` (kept for
            backward compatibility with the previous behaviour).
        data_name: name of the dataset/task; must be one of the template
            keys below (``"passkey"`` or ``"kv_retrieval"``).
        data_dir: currently unused; still accepted so existing callers that
            pass it keep working.

    Returns:
        The task template with ``{context}`` and ``{input}`` filled in.

    Raises:
        KeyError: if ``data_name`` has no template, or ``eg`` lacks a
            required field.
    """
    templates = {
        "passkey":
        "There is an important info hidden inside a lot of irrelevant text. Find it and memorize it. I will quiz you about the important information.\n\n{context}\n\n{input}\n\nThe pass key is",  # noqa
        "kv_retrieval":
        "Extract the value corresponding to the specified key in the JSON object below.\n\n{context}\n\n{input}",  # noqa
    }

    # Fail with a descriptive message (and before touching eg) instead of a
    # bare KeyError from the template lookup.
    if data_name not in templates:
        raise KeyError(f"No prompt template for data_name={data_name!r}; "
                       f"expected one of {sorted(templates)}")

    # Normalise the long-document field name: "content" -> "context".
    if "content" in eg:
        eg["context"] = eg.pop("content")

    return templates[data_name].format(context=eg["context"],
                                       input=eg["input"])
|
|
|
|
|
|
def get_answer(eg: dict, data_name: str):
    """Return the ground-truth answer for an example.

    For the multiple-choice tasks ``"code_debug"`` and
    ``"longbook_choice_eng"`` the answer is normalised to
    ``[answer_text, option_letter]``. For every other task ``eg["answer"]``
    is returned unchanged.

    Args:
        eg: example dict with an ``"answer"`` field. Multiple-choice examples
            whose answer is given as text also need an ``"options"`` sequence
            (presumably a list of up to four choices; not checked here).
        data_name: name of the dataset/task.

    Returns:
        ``eg["answer"]`` for non-choice tasks, otherwise a two-element list.

    Raises:
        ValueError: if a multiple-choice answer is not a string, a
            one-element list, or an ``[answer_text, letter]`` pair, or if the
            answer text is not found in ``eg["options"]``.
    """
    if data_name not in ["code_debug", "longbook_choice_eng"]:
        return eg["answer"]

    OPTIONS = "ABCD"
    answer = eg["answer"]

    def _with_letter(text):
        # Pair the answer text with the letter of its position among the
        # options; list.index raises ValueError when the text is absent.
        return [text, OPTIONS[eg["options"].index(text)]]

    if isinstance(answer, str):
        return _with_letter(answer)
    if isinstance(answer, list):
        if len(answer) == 1:
            return _with_letter(answer[0])
        # Compare against individual letters (not the "ABCD" string) so that
        # substrings such as "AB" or "" are not accepted as a letter.
        if len(answer) == 2 and answer[1] in list(OPTIONS):
            return answer
    raise ValueError(
        f"Unsupported answer format for {data_name!r}: {answer!r} "
        "(expected a string, a one-element list, or an "
        "[answer_text, option_letter] pair)")
|
|
|
|
|
|
def truncate_input(input, max_length, manner="middle"):
    """Shorten a sequence to at most ``max_length`` items.

    Sequences that already fit are returned unchanged. With
    ``manner="middle"`` the first ``max_length // 2`` and the last
    ``max_length - max_length // 2`` items are kept and concatenated, so the
    result has exactly ``max_length`` items.

    Args:
        input: any sequence supporting ``len``, slicing and ``+`` (e.g. a
            string or a list of token ids). The name shadows the builtin but
            is kept for keyword-argument compatibility.
        max_length: maximum length of the result; must be non-negative.
        manner: truncation strategy; only ``"middle"`` is implemented.

    Returns:
        The (possibly truncated) sequence, or ``None`` when truncation is
        needed but ``manner`` is not supported.

    Raises:
        ValueError: if ``max_length`` is negative.
    """
    if max_length < 0:
        raise ValueError(f"max_length must be non-negative, got {max_length}")
    if len(input) <= max_length:
        return input
    if manner != "middle":
        return None
    head = max_length // 2
    tail = max_length - head
    # Slice the tail from an explicit start index: a negative-start slice
    # like input[-k:] returns the WHOLE sequence when k == 0.
    return input[:head] + input[len(input) - tail:]
|