# TensorRT-LLM/examples/multimodal/eval.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
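"""Accuracy evaluation for multimodal models on the AI2D benchmark.

Runs the same multiple-choice prompts through a TensorRT-LLM engine and/or
the Hugging Face reference model and reports exact-match accuracy over the
first --max_ite samples of the test split.

A sketch of an invocation (paths are placeholders; the full set of model
and engine flags is defined by utils.add_common_args):

    python eval.py --hf_model_dir <hf_model_dir> \\
        --visual_engine_dir <visual_engine_dir> \\
        --llm_engine_dir <llm_engine_dir> \\
        --test_trtllm --eval_task lmms-lab/ai2d
"""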
import argparse
import os

import aiohttp
import datasets
import torch
from transformers import AutoProcessor, MllamaForConditionalGeneration
from utils import add_common_args

import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger
from tensorrt_llm.runtime import MultimodalModelRunner
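

# Build an AI2D multiple-choice prompt: the Mllama image and BOS tags, the
# question, the enumerated options, and an "answer: " cue so the model can
# reply with a bare option index.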
def prepare_prompts(task, data):
    if task == 'lmms-lab/ai2d':
        prompts = f"<|image|><|begin_of_text|> {data['question']}"
        if prompts[-1] != '?':
            prompts += '?'
        for j, option in enumerate(data['options']):
            prompts += f" ({j}) {option}"
        prompts += "; answer: "
    return prompts
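

# Silence the fork-safety warning from the Hugging Face tokenizers library.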
os.environ["TOKENIZERS_PARALLELISM"] = "false"
parser = argparse.ArgumentParser()
parser = add_common_args(parser)
parser.add_argument('--test_trtllm',
                    action='store_true',
                    default=None,
                    help="Evaluate the TensorRT-LLM model.")
parser.add_argument('--test_hf',
                    action='store_true',
                    default=None,
                    help="Evaluate the Hugging Face model.")
parser.add_argument('--max_ite', type=int, default=20)
parser.add_argument('--eval_task',
                    type=str,
                    default='lmms-lab/ai2d',
                    choices=[
                        'lmms-lab/ai2d',
                    ])
parser.add_argument(
    '--accuracy_threshold',
    type=float,
    default=None,
    help=
    'Used to check the accuracy of test_trtllm; should be between 0 and 100.')
parser.add_argument(
    '--dataset_dir',
    type=str,
    default=None,
    help="The local directory of the dataset for evaluation; "
    "will download the dataset from the Hugging Face Hub if not specified.")
parser.add_argument(
    '--dataset_cache_dir',
    type=str,
    default=None,
    help="The local cache directory for the dataset; "
    "will use `~/.cache/huggingface/datasets` if not specified.")
args = parser.parse_args()
logger.set_level(args.log_level)
runtime_rank = tensorrt_llm.mpi_rank()
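
# Load the AI2D test split; the one-hour aiohttp timeout guards against slow
# downloads of the image-heavy dataset. When --dataset_dir is not given, fall
# back to downloading --eval_task from the Hugging Face Hub, as the
# --dataset_dir help text promises.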
dataset = datasets.load_dataset(args.dataset_dir or args.eval_task,
                                storage_options={
                                    'client_kwargs': {
                                        'timeout':
                                        aiohttp.ClientTimeout(total=3600)
                                    }
                                },
                                cache_dir=args.dataset_cache_dir,
                                split='test')
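
# The processor bundles the image preprocessor and tokenizer for the Mllama
# checkpoint; it is needed for the Hugging Face path below.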
processor = AutoProcessor.from_pretrained(args.hf_model_dir)
hf_model = None
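
# Optionally load the Hugging Face reference model in bfloat16, sharded
# across the available GPUs with device_map="auto".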
if args.test_hf:
    profiler.start('load HF model')
    hf_model = MllamaForConditionalGeneration.from_pretrained(
        args.hf_model_dir,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    profiler.stop('load HF model')
    logger.info(
        f'Load HF model takes: {profiler.elapsed_time_in_sec("load HF model")} sec'
    )
trtllm_model = None
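
# Optionally load the TensorRT-LLM engines through the multimodal runner,
# which picks up the engine locations from the shared CLI arguments.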
if args.test_trtllm:
    profiler.start('load TensorRT-LLM model')
    trtllm_model = MultimodalModelRunner(args)
    profiler.stop('load TensorRT-LLM model')
    logger.info(
        f'Load TensorRT-LLM model takes: {profiler.elapsed_time_in_sec("load TensorRT-LLM model")} sec'
    )

if trtllm_model or hf_model:
    trtllm_correct = 0 if trtllm_model else None
    hf_correct = 0 if hf_model else None
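
    # Score by exact string match: with max_new_tokens=1, each model is
    # expected to emit just the option index that the prompt asks for.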
    for i in range(args.max_ite):
        logger.debug(f"Ite: {i:3d}")
        data = dataset[i]
        prompts = prepare_prompts(args.eval_task, data)
        answer = data['answer']
        image = data['image']

        hf_result = None
        if hf_model:
            profiler.start('hf')
            inputs = processor(
                image,
                prompts,
                return_tensors="pt",
            ).to(hf_model.device)
            input_length = inputs.input_ids.shape[-1]
            hf_output = hf_model.generate(**inputs, max_new_tokens=1)
            hf_result = processor.decode(hf_output[0][input_length:])
            if answer == hf_result:
                hf_correct += 1
            profiler.stop('hf')

        trtllm_result = None
        if trtllm_model:
            profiler.start('tensorrt_llm')
            input_text, output_text = trtllm_model.run(prompts,
                                                       image,
                                                       max_new_tokens=1)
            if runtime_rank == 0:
                trtllm_result = output_text[0][0]
                if answer == trtllm_result:
                    trtllm_correct += 1
            profiler.stop('tensorrt_llm')

        if runtime_rank == 0:
            logger.debug(f"prompts: {prompts}")
            logger.debug(f"reference answer: {answer}")
            if hf_result:
                logger.debug(f"HF's answer: {hf_result}")
            if trtllm_result:
                logger.debug(f"TRT-LLM's answer: {trtllm_result}")
    if runtime_rank == 0:
        logger.info(f"total iterations: {args.max_ite}")
        if hf_correct is not None:
            logger.info(
                f"HF's accuracy: {100 * hf_correct / args.max_ite:4.2f}%")
        if trtllm_correct is not None:
            logger.info(
                f"TRT-LLM's accuracy: {100 * trtllm_correct / args.max_ite:4.2f}%"
            )
else:
    logger.info("Neither test_trtllm nor test_hf is enabled; nothing to evaluate.")