TensorRT-LLM/tensorrt_llm/evaluate/mmlu.py

# SPDX-FileCopyrightText: Copyright (c) 2020 Dan Hendrycks
# SPDX-FileCopyrightText: Copyright (c) 2023 Deep Cognition and Language Research (DeCLaRe) Lab
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 AND MIT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
from typing import Any, Iterable, List, Optional, Union

import click
import numpy as np
import pandas as pd

from .. import LLM as PyTorchLLM
from .._tensorrt_engine import LLM
from ..llmapi import RequestOutput
from ..logger import logger
from ..sampling_params import SamplingParams
from .interface import Evaluator


class MMLU(Evaluator):
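    """Evaluator for the MMLU (Massive Multitask Language Understanding) benchmark.

    Builds few-shot prompts from the Hendrycks MMLU dataset for each of the 57
    subjects, queries the model, and reports accuracy per subject, per
    subcategory, per category, and as an overall weighted average.
    """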
DATASET_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
CHOICES = ["A", "B", "C", "D"]
SUBJECT_TO_SUBCATEGORIES = {
"abstract_algebra": ["math"],
"anatomy": ["health"],
"astronomy": ["physics"],
"business_ethics": ["business"],
"clinical_knowledge": ["health"],
"college_biology": ["biology"],
"college_chemistry": ["chemistry"],
"college_computer_science": ["computer science"],
"college_mathematics": ["math"],
"college_medicine": ["health"],
"college_physics": ["physics"],
"computer_security": ["computer science"],
"conceptual_physics": ["physics"],
"econometrics": ["economics"],
"electrical_engineering": ["engineering"],
"elementary_mathematics": ["math"],
"formal_logic": ["philosophy"],
"global_facts": ["other"],
"high_school_biology": ["biology"],
"high_school_chemistry": ["chemistry"],
"high_school_computer_science": ["computer science"],
"high_school_european_history": ["history"],
"high_school_geography": ["geography"],
"high_school_government_and_politics": ["politics"],
"high_school_macroeconomics": ["economics"],
"high_school_mathematics": ["math"],
"high_school_microeconomics": ["economics"],
"high_school_physics": ["physics"],
"high_school_psychology": ["psychology"],
"high_school_statistics": ["math"],
"high_school_us_history": ["history"],
"high_school_world_history": ["history"],
"human_aging": ["health"],
"human_sexuality": ["culture"],
"international_law": ["law"],
"jurisprudence": ["law"],
"logical_fallacies": ["philosophy"],
"machine_learning": ["computer science"],
"management": ["business"],
"marketing": ["business"],
"medical_genetics": ["health"],
"miscellaneous": ["other"],
"moral_disputes": ["philosophy"],
"moral_scenarios": ["philosophy"],
"nutrition": ["health"],
"philosophy": ["philosophy"],
"prehistory": ["history"],
"professional_accounting": ["other"],
"professional_law": ["law"],
"professional_medicine": ["health"],
"professional_psychology": ["psychology"],
"public_relations": ["politics"],
"security_studies": ["politics"],
"sociology": ["culture"],
"us_foreign_policy": ["politics"],
"virology": ["health"],
"world_religions": ["philosophy"],
}
CATEGORY_TO_SUBCATEGORIES = {
"STEM": [
"physics",
"chemistry",
"biology",
"computer science",
"math",
"engineering",
],
"humanities": ["history", "philosophy", "law"],
"social sciences": [
"politics",
"culture",
"economics",
"geography",
"psychology",
],
"other (business, health, misc.)": ["other", "business", "health"],
}

    def __init__(self,
dataset_path: Optional[str] = None,
num_samples: Optional[int] = None,
num_fewshot: int = 5,
random_seed: int = 0,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None):
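        """Initialize the MMLU evaluator.

        Args:
            dataset_path: Path to the extracted MMLU ``data`` directory. If
                None, the dataset is downloaded to a temporary directory.
            num_samples: Total number of test questions to evaluate, spread
                evenly across the 57 subjects. None evaluates the full set.
            num_fewshot: Number of dev-set examples prepended to each prompt.
            random_seed: Random seed forwarded to the base Evaluator.
            apply_chat_template: Whether to apply the model's chat template to
                each prompt.
            system_prompt: Optional system prompt forwarded to the base
                Evaluator.
            chat_template_kwargs: Extra chat-template keyword arguments
                forwarded to the base Evaluator.
        """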
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
if dataset_path is None:
            dataset_path = self.download_dataset()
self.dataset_path = dataset_path
if num_samples is None:
self.num_samples_per_subject = None
else:
self.num_samples_per_subject = math.ceil(
num_samples / len(self.SUBJECT_TO_SUBCATEGORIES))
self.num_fewshot = num_fewshot

    def download_dataset(self):
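        """Download and extract the MMLU dataset into a temporary directory.

        Fetches the tarball from DATASET_URL, extracts it after checking each
        member for path traversal, and returns the path to the extracted
        ``data`` directory. The TemporaryDirectory handle is kept on ``self``
        so the files outlive this call.
        """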
import os
import tarfile
from tempfile import TemporaryDirectory
import requests
self.tempdir = TemporaryDirectory()
workspace = self.tempdir.name
        response = requests.get(self.DATASET_URL, timeout=60)
        # Surface HTTP errors instead of silently writing an error page to disk.
        response.raise_for_status()
        with open(f"{workspace}/data.tar", "wb") as f:
            f.write(response.content)
with tarfile.open(f"{workspace}/data.tar") as tar:
for member in tar.getmembers():
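                # Reject members that would extract outside the workspace
                # (tar path traversal).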
member_path = os.path.abspath(f"{workspace}/{member.name}")
if not member_path.startswith(workspace):
raise ValueError(
f"Insecure member found in tar file: {member.name}")
tar.extract(member, path=workspace, filter=tarfile.data_filter)
return f"{workspace}/data"

    def format_subject(self, subject):
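        """Convert an underscore-separated subject name into space-prefixed
        words, e.g. "college_physics" -> " college physics"."""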
line = subject.split("_")
s = ""
for entry in line:
s += " " + entry
return s

    def format_example(self, df, idx, include_answer=True):
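        """Format one row of ``df`` as a question block: the question text,
        the lettered choices, and an "Answer:" line, optionally followed by
        the gold answer (used for few-shot examples)."""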
prompt = df.iloc[idx, 0]
k = df.shape[1] - 2
for j in range(k):
prompt += "\n{}. {}".format(self.CHOICES[j], df.iloc[idx, j + 1])
prompt += "\nAnswer:"
if include_answer:
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
return prompt

    def gen_prompt(self, train_df, subject, k=-1):
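        """Build the few-shot prompt header for ``subject`` from the first
        ``k`` dev-set examples (all of them when ``k`` is -1)."""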
prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
self.format_subject(subject))
if k == -1:
k = train_df.shape[0]
for i in range(k):
prompt += self.format_example(train_df, i)
return prompt

    def generate_samples(self) -> Iterable[tuple]:
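        """Yield one sample per test question.

        Each sample is a tuple of the few-shot prompt, a placeholder (None),
        the reference answer letter, and the subject name. When a per-subject
        sample budget is set, test questions are randomly subsampled.
        """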
for subject in self.SUBJECT_TO_SUBCATEGORIES.keys():
dev_df = pd.read_csv(f"{self.dataset_path}/dev/{subject}_dev.csv",
header=None)
train_prompt = self.gen_prompt(dev_df, subject, self.num_fewshot)
test_df = pd.read_csv(
f"{self.dataset_path}/test/{subject}_test.csv", header=None)
if self.num_samples_per_subject is not None and self.num_samples_per_subject < test_df.shape[
0]:
test_df = test_df.sample(self.num_samples_per_subject)
for i in range(test_df.shape[0]):
prompt_end = self.format_example(test_df,
i,
include_answer=False)
prompt = train_prompt + prompt_end
label = test_df.iloc[i, test_df.shape[1] - 1]
yield prompt, None, label, subject

    def compute_score(self, outputs: List[RequestOutput], references: List[str],
                      subjects: List[str]) -> float:
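        """Score the generations and log accuracy breakdowns.

        A generation counts as correct when its stripped text starts with the
        reference answer letter. Accuracy is logged per subject, per
        subcategory, and per category; the overall weighted accuracy (in
        percent) is returned.
        """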
subject_corrections = {
key: []
for key in self.SUBJECT_TO_SUBCATEGORIES.keys()
}
for output, ref, sub in zip(outputs, references, subjects):
correction = output.outputs[0].text.strip().startswith(ref)
subject_corrections[sub].append(correction)
subcategory_corrections = {
key: []
for subcats in self.SUBJECT_TO_SUBCATEGORIES.values()
for key in subcats
}
category_corrections = {
key: []
for key in self.CATEGORY_TO_SUBCATEGORIES.keys()
}
all_corrections = []
for sub, corrections in subject_corrections.items():
for subcat in self.SUBJECT_TO_SUBCATEGORIES[sub]:
subcategory_corrections[subcat].extend(corrections)
for cat, subcats in self.CATEGORY_TO_SUBCATEGORIES.items():
if subcat in subcats:
category_corrections[cat].extend(corrections)
all_corrections.extend(corrections)
for subject, corrections in subject_corrections.items():
acc = np.mean(corrections) * 100
logger.info(
f"Average accuracy {acc:.2f} ({len(corrections)}) - {subject}")
for subcat, corrections in subcategory_corrections.items():
acc = np.mean(corrections) * 100
logger.info(
f"Average accuracy {acc:.2f} ({len(corrections)}) - {subcat}")
for cat, corrections in category_corrections.items():
acc = np.mean(corrections) * 100
logger.info(
f"Average accuracy {acc:.2f} ({len(corrections)}) - {cat}")
weighted_acc = np.mean(all_corrections) * 100
logger.info(
f"MMLU weighted average accuracy: {weighted_acc:.2f} ({len(all_corrections)})"
)
return weighted_acc

    @click.command("mmlu")
@click.option(
"--dataset_path",
type=str,
default=None,
help="The path to MMLU dataset. The commands to prepare the dataset: "
"wget https://people.eecs.berkeley.edu/~hendrycks/data.tar && tar -xf data.tar. "
"If unspecified, the dataset is downloaded automatically.")
@click.option(
"--num_samples",
type=int,
default=None,
help="Number of samples to run the evaluation; None means full dataset."
)
@click.option("--num_fewshot",
type=int,
default=5,
help="Number of fewshot.")
@click.option("--random_seed",
type=int,
default=0,
help="Random seed for dataset processing.")
@click.option("--apply_chat_template",
is_flag=True,
default=False,
help="Whether to apply chat template.")
@click.option(
"--chat_template_kwargs",
type=str,
default=None,
callback=lambda ctx, param, value: json.loads(value) if value else None,
help=
'Chat template kwargs as JSON string, e.g., \'{"thinking_budget": 0}\'')
@click.option("--system_prompt",
type=str,
default=None,
help="System prompt.")
@click.option("--max_input_length",
type=int,
default=4094,
help="Maximum prompt length.")
@click.option("--max_output_length",
type=int,
default=2,
help="Maximum generation length.")
@click.option("--check_accuracy", is_flag=True, default=False)
@click.option("--accuracy_threshold", type=float, default=30)
@click.pass_context
@staticmethod
def command(ctx, dataset_path: Optional[str], num_samples: int,
num_fewshot: int, random_seed: int, apply_chat_template: bool,
chat_template_kwargs: Optional[dict[str, Any]],
system_prompt: Optional[str], max_input_length: int,
max_output_length: int, check_accuracy: bool,
accuracy_threshold: float) -> None:
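        """Click entry point for the ``mmlu`` evaluation command.

        Runs the benchmark against the LLM instance stored on the click
        context and, when --check_accuracy is set, asserts that the accuracy
        meets --accuracy_threshold.
        """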
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(
max_tokens=max_output_length,
truncate_prompt_tokens=max_input_length)
evaluator = MMLU(dataset_path,
num_samples=num_samples,
num_fewshot=num_fewshot,
random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
accuracy = evaluator.evaluate(llm, sampling_params)
llm.shutdown()
if check_accuracy:
logger.warning(
"The --check_accuracy flag is not expected to be used anymore. "
"It is being used by some legacy accuracy tests that call evaluation commands via subprocess. "
"New accuracy tests should use LLM API within the pytest process; please see `tests/integration/defs/accuracy/README.md`."
)
assert accuracy >= accuracy_threshold, f"Expected accuracy >= {accuracy_threshold}, but got {accuracy}."