# TensorRT-LLM/examples/mmlu.py
# SPDX-FileCopyrightText: Copyright (c) 2020 Dan Hendrycks
# SPDX-FileCopyrightText: Copyright (c) 2023 Deep Cognition and Language Research (DeCLaRe) Lab
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 AND MIT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adapted from https://github.com/declare-lab/instruct-eval
Helper script to compare TRTLLM and HF models on the MMLU dataset.
Example usage:
mkdir data; wget https://people.eecs.berkeley.edu/~hendrycks/data.tar -O data/mmlu.tar
tar -xf data/mmlu.tar -C data && mv data/data data/mmlu
python mmlu.py --hf_model_dir <HF model path> --engine_dir <TRTLLM engine path> --test_trt_llm
python mmlu.py --hf_model_dir <HF model path> --engine_dir <TRTLLM engine path> --test_hf
"""
import argparse
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoModelForSeq2SeqLM, AutoTokenizer,
GenerationConfig)
from utils import (add_common_args, load_tokenizer, prepare_enc_dec_inputs,
read_is_enc_dec, read_model_name)
import tensorrt_llm
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner
if PYTHON_BINDINGS:
from tensorrt_llm.runtime import ModelRunnerCpp
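# Disable HF tokenizers parallelism to silence fork-related warnings in worker processes.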
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DTYPE_STR_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
RAND_SEED = 1234
def get_choices():
return ["A", "B", "C", "D"]
def get_subcategories():
return {
"abstract_algebra": ["math"],
"anatomy": ["health"],
"astronomy": ["physics"],
"business_ethics": ["business"],
"clinical_knowledge": ["health"],
"college_biology": ["biology"],
"college_chemistry": ["chemistry"],
"college_computer_science": ["computer science"],
"college_mathematics": ["math"],
"college_medicine": ["health"],
"college_physics": ["physics"],
"computer_security": ["computer science"],
"conceptual_physics": ["physics"],
"econometrics": ["economics"],
"electrical_engineering": ["engineering"],
"elementary_mathematics": ["math"],
"formal_logic": ["philosophy"],
"global_facts": ["other"],
"high_school_biology": ["biology"],
"high_school_chemistry": ["chemistry"],
"high_school_computer_science": ["computer science"],
"high_school_european_history": ["history"],
"high_school_geography": ["geography"],
"high_school_government_and_politics": ["politics"],
"high_school_macroeconomics": ["economics"],
"high_school_mathematics": ["math"],
"high_school_microeconomics": ["economics"],
"high_school_physics": ["physics"],
"high_school_psychology": ["psychology"],
"high_school_statistics": ["math"],
"high_school_us_history": ["history"],
"high_school_world_history": ["history"],
"human_aging": ["health"],
"human_sexuality": ["culture"],
"international_law": ["law"],
"jurisprudence": ["law"],
"logical_fallacies": ["philosophy"],
"machine_learning": ["computer science"],
"management": ["business"],
"marketing": ["business"],
"medical_genetics": ["health"],
"miscellaneous": ["other"],
"moral_disputes": ["philosophy"],
"moral_scenarios": ["philosophy"],
"nutrition": ["health"],
"philosophy": ["philosophy"],
"prehistory": ["history"],
"professional_accounting": ["other"],
"professional_law": ["law"],
"professional_medicine": ["health"],
"professional_psychology": ["psychology"],
"public_relations": ["politics"],
"security_studies": ["politics"],
"sociology": ["culture"],
"us_foreign_policy": ["politics"],
"virology": ["health"],
"world_religions": ["philosophy"],
}
def get_categories():
return {
"STEM": [
"physics",
"chemistry",
"biology",
"computer science",
"math",
"engineering",
],
"humanities": ["history", "philosophy", "law"],
"social sciences": [
"politics",
"culture",
"economics",
"geography",
"psychology",
],
"other (business, health, misc.)": ["other", "business", "health"],
}
def format_subject(subject):
    # e.g. "college_computer_science" -> "college computer science"
    return " ".join(subject.split("_"))
def format_example(df, idx, include_answer=True):
prompt = df.iloc[idx, 0]
k = df.shape[1] - 2
for j in range(k):
prompt += "\n{}. {}".format(get_choices()[j], df.iloc[idx, j + 1])
prompt += "\nAnswer:"
if include_answer:
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
return prompt
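# Illustrative output of format_example with include_answer=True, on a
# hypothetical row (not from the MMLU dataset):
#   What is the capital of France?
#   A. Berlin
#   B. Madrid
#   C. Paris
#   D. Rome
#   Answer: C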
def gen_prompt(train_df, subject, k=-1):
prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
format_subject(subject))
if k == -1:
k = train_df.shape[0]
for i in range(k):
prompt += format_example(train_df, i)
return prompt
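# The final prompt is this few-shot prefix (subject header plus k answered
# dev-set examples) with one unanswered test question appended; see evaluate().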
def evaluate(args, subject, pipeline, dev_df, test_df):
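    """Run few-shot evaluation of `pipeline` on one MMLU subject.

    The number of in-context examples starts at args.ntrain and is reduced
    until the prompt fits the model's length limits. Only rank 0 scores
    predictions; other ranks return (None, 0, None).
    """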
rank = tensorrt_llm.mpi_rank()
cors = []
all_probs = []
for i in range(test_df.shape[0]):
if i >= args.max_ite:
break
# get prompt and make sure it fits
k = args.ntrain
prompt_end = format_example(test_df, i, include_answer=False)
train_prompt = gen_prompt(dev_df, subject, k)
prompt = train_prompt + prompt_end
while not pipeline.check_valid_length(prompt) and k > 0:
k -= 1
train_prompt = gen_prompt(dev_df, subject, k)
prompt = train_prompt + prompt_end
label = test_df.iloc[i, test_df.shape[1] - 1]
pred = pipeline(prompt)
if rank == 0:
probs = [0 for _ in get_choices()]
cor = pred.strip().startswith(label)
cors.append(cor)
all_probs.append(probs)
if rank == 0:
acc = np.mean(cors)
cors = np.array(cors)
all_probs = np.array(all_probs)
print("Average accuracy {:.3f} - {}".format(acc, subject))
return cors, acc, all_probs
else:
return None, 0, None
def get_tokenizer(ckpt_path, max_seq_len):
print(f"Initializing tokenizer from {ckpt_path}")
tokenizer = AutoTokenizer.from_pretrained(
ckpt_path,
model_max_length=max_seq_len,
padding_side="left",
trust_remote_code=True,
)
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
class Pipeline:
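    """Wraps either an HF model or a TRT-LLM runner behind a single
    prompt-in, text-out interface used by evaluate()."""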
def __init__(self, tokenizer, model, model_name, pad_id, end_id,
max_attention_window_size, is_enc_dec, hf_model_dir,
engine_dir):
self.tokenizer = tokenizer
self.model = model
self.model_name = model_name
self.pad_id = pad_id
self.end_id = end_id
self.max_attention_window_size = max_attention_window_size
self.output_len = 2
self.is_enc_dec = is_enc_dec
self.decoder_start_token_id = None
self.engine_dir = engine_dir
if self.is_enc_dec:
self.decoder_start_token_id = AutoConfig.from_pretrained(
hf_model_dir).decoder_start_token_id
def __call__(self, prompt):
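        """Generate up to self.output_len tokens greedily for one prompt and
        return the decoded continuation (None on non-zero ranks)."""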
rank = tensorrt_llm.mpi_rank()
        # Run the model with batch size 1 and beam width 1
inputs = self.tokenizer.encode(prompt, return_tensors="pt").squeeze(0)
batch_input_ids = [inputs]
        # For multiple-choice tasks like MMLU, the following parameters need no tuning
output_len = self.output_len
top_k = 1
top_p = 0.0
input_lengths = [x.size(0) for x in batch_input_ids]
with torch.no_grad():
if isinstance(self.model, nn.Module):
# Left padding for HF
max_length = max(input_lengths)
paddings = [
torch.ones(max_length - l, dtype=torch.int32) * self.pad_id
for l in input_lengths
]
batch_input_ids = [
torch.cat([pad, x])
for x, pad in zip(batch_input_ids, paddings)
]
batch_input_ids = torch.stack(batch_input_ids)
batch_input_ids = batch_input_ids.cuda()
if self.is_enc_dec:
batch_decoder_input_ids = torch.IntTensor(
[[self.decoder_start_token_id]]).to('cuda')
batch_decoder_input_ids = batch_decoder_input_ids.repeat(
(batch_input_ids.shape[0], 1))
with torch.no_grad():
                    # Greedy decoding: top_k=1; temperature is left at its default
outputs = self.model.generate(
batch_input_ids,
max_new_tokens=output_len,
top_k=top_k,
decoder_input_ids=batch_decoder_input_ids
if self.is_enc_dec else None)
if not self.is_enc_dec:
output_ids = outputs[0, input_lengths[0]:]
else:
output_ids = outputs[0]
            # ModelRunnerCpp is only imported when PYTHON_BINDINGS is available.
            elif isinstance(self.model, ModelRunner) or (
                    PYTHON_BINDINGS
                    and isinstance(self.model, ModelRunnerCpp)):
if self.is_enc_dec:
encoder_input_ids, encoder_input_features, encoder_output_lengths, decoder_input_ids = prepare_enc_dec_inputs(
batch_input_ids, self.model_name, self.engine_dir, None)
outputs = self.model.generate(
batch_input_ids=decoder_input_ids
if self.is_enc_dec else batch_input_ids,
encoder_input_ids=encoder_input_ids
if self.is_enc_dec else None,
encoder_input_features=encoder_input_features
if self.is_enc_dec else None,
encoder_output_lengths=encoder_output_lengths
if self.is_enc_dec else None,
max_new_tokens=output_len,
max_attention_window_size=self.max_attention_window_size,
end_id=self.end_id,
pad_id=self.pad_id,
top_k=top_k,
top_p=top_p,
)
torch.cuda.synchronize()
if rank == 0:
if not self.is_enc_dec:
output_ids = outputs[0, 0, input_lengths[0]:]
else:
output_ids = outputs[0, 0]
if rank == 0:
return self.tokenizer.decode(output_ids, skip_special_tokens=True)
else:
return None
def check_valid_length(self, prompt):
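        """Return True if the prompt plus self.output_len generated tokens
        fits the engine's max_input_len/max_seq_len; HF models always pass."""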
if isinstance(self.model, nn.Module):
return True
input_len = len(self.tokenizer.encode(prompt))
return input_len <= self.model.max_input_len and input_len + self.output_len <= self.model.max_seq_len
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir",
type=str,
default="data/mmlu",
help=("Path to the data directory. If not available, "
"download https://people.eecs.berkeley.edu/~hendrycks/data.tar"),
)
parser.add_argument("--ntrain", type=int, default=5)
parser.add_argument("--max_input_length", type=int, default=2048)
parser.add_argument("--test_trt_llm", action="store_true")
parser.add_argument("--test_hf", action="store_true")
parser.add_argument('--check_accuracy', action='store_true')
parser.add_argument('--accuracy_threshold', type=float, default=30)
parser.add_argument('--max_ite', type=int, default=10000000)
parser = add_common_args(parser)
args = parser.parse_args()
return args
def main():
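    """Build the requested backend (TRT-LLM engine or HF model), evaluate
    every MMLU subject, and report per-subcategory, per-category, and
    weighted-average accuracy."""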
args = parse_args()
if args.tokenizer_dir is None:
args.tokenizer_dir = args.hf_model_dir
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
runtime_rank = tensorrt_llm.mpi_rank()
data_fullpath = os.path.join(args.data_dir, "test")
subjects = sorted([
f.split("_test.csv")[0] for f in os.listdir(data_fullpath)
if "_test.csv" in f
])
all_cors = []
subcat_cors = {
subcat: []
for subcat_lists in get_subcategories().values()
for subcat in subcat_lists
}
cat_cors = {cat: [] for cat in get_categories()}
    # Encoder-decoder models require different input handling
is_enc_dec = read_is_enc_dec(
args.engine_dir if not args.test_hf else args.hf_model_dir,
args.test_hf)
model_name, model_version = read_model_name(
(args.engine_dir if not is_enc_dec else os.path.join(
args.engine_dir, 'encoder'))
if not args.test_hf else args.hf_model_dir, args.test_hf)
tokenizer, pad_id, end_id = load_tokenizer(
tokenizer_dir=args.tokenizer_dir,
vocab_file=args.vocab_file,
model_name=model_name,
model_version=model_version,
)
if args.test_trt_llm:
assert not args.test_hf, "Cannot test both TRT-LLM and HF"
runner_cls = ModelRunner if not PYTHON_BINDINGS else ModelRunnerCpp
runner_kwargs = {}
if PYTHON_BINDINGS:
runner_kwargs.update(max_beam_width=1)
runner_kwargs.update(
is_enc_dec=is_enc_dec,
max_tokens_in_paged_kv_cache=args.max_tokens_in_paged_kv_cache,
kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse,
            kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction,
cross_kv_cache_fraction=args.cross_kv_cache_fraction
if is_enc_dec else None,
enable_chunked_context=args.enable_chunked_context,
multi_block_mode=args.multi_block_mode)
model = runner_cls.from_dir(engine_dir=args.engine_dir,
rank=runtime_rank,
**runner_kwargs)
else:
assert args.test_hf, "Must test either TRT-LLM or HF"
if 'GLM' in model_name and model_version == 'glm':
auto_model_cls = AutoModelForSeq2SeqLM
elif 'GLM' in model_name and model_version == 'chatglm':
auto_model_cls = AutoModel
elif is_enc_dec:
auto_model_cls = AutoModelForSeq2SeqLM
else:
auto_model_cls = AutoModelForCausalLM
model = auto_model_cls.from_pretrained(
args.hf_model_dir,
trust_remote_code=True,
dtype=DTYPE_STR_MAPPING[args.hf_data_type],
device_map="auto" if args.hf_device_map_auto else None,
)
if not args.hf_device_map_auto:
model.cuda()
if model_name == "qwen":
model.generation_config = GenerationConfig.from_pretrained(
args.hf_model_dir, trust_remote_code=True)
pipeline = Pipeline(tokenizer, model, model_name, pad_id, end_id,
args.max_attention_window_size, is_enc_dec,
args.hf_model_dir, args.engine_dir)
for subject in tqdm(subjects):
dev_df = pd.read_csv(os.path.join(args.data_dir, "dev",
subject + "_dev.csv"),
header=None)[:args.ntrain]
test_df = pd.read_csv(os.path.join(args.data_dir, "test",
subject + "_test.csv"),
header=None)
cors, acc, probs = evaluate(args, subject, pipeline, dev_df, test_df)
subcats = get_subcategories()[subject]
for subcat in subcats:
subcat_cors[subcat].append(cors)
for key in get_categories().keys():
if subcat in get_categories()[key]:
cat_cors[key].append(cors)
all_cors.append(cors)
if runtime_rank == 0:
for subcat in subcat_cors:
acc = np.mean(np.concatenate(subcat_cors[subcat])) * 100
print(f"Average accuracy {acc:.2f} - {subcat}")
for cat in cat_cors:
acc = np.mean(np.concatenate(cat_cors[cat])) * 100
print(f"Average accuracy {acc:.2f} - {cat}")
weighted_acc = np.mean(np.concatenate(all_cors)) * 100
print(f"MMLU weighted average accuracy: {weighted_acc:.2f}")
if args.check_accuracy:
            assert weighted_acc >= args.accuracy_threshold, f"Expected accuracy >= {args.accuracy_threshold} but got {weighted_acc}"
return weighted_acc
if __name__ == "__main__":
main()