TensorRT-LLM/tensorrt_llm/evaluate/lm_eval.py

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
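"""lm-eval harness integration for TensorRT-LLM evaluation.

Wraps a TensorRT-LLM ``LLM`` behind lm-eval's ``TemplateLM`` interface and
defines generation-based evaluators (GSM8K and the GPQA variants) on top of it.
"""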
import copy
import os
from contextlib import contextmanager
from typing import Dict, Iterable, List, Optional, Tuple, Union

import click
import numpy as np
from tqdm import tqdm

import tensorrt_llm.profiler as profiler

try:
    from lm_eval.api.model import TemplateLM
except ImportError:
    TemplateLM = object

from .._torch import LLM as PyTorchLLM
from ..llmapi import LLM, RequestOutput
from ..logger import logger
from ..sampling_params import SamplingParams
from .interface import Evaluator


class LmEvalWrapper(TemplateLM):
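    """Adapter exposing a TensorRT-LLM ``LLM`` through lm-eval's ``TemplateLM``
    interface. Only generation-based tasks are supported: the loglikelihood
    entry points are intentionally left unimplemented.
    """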
def __init__(self,
llm: Union[LLM, PyTorchLLM],
sampling_params: Optional[SamplingParams] = None):
super().__init__()
self.llm = llm
self.sampling_params = sampling_params

    @property
def eot_token_id(self) -> int:
return self.llm.tokenizer.eos_token_id

    def apply_chat_template(self,
chat_history: List[Dict[str, str]],
add_generation_prompt: bool = True) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
"""
return self.llm.tokenizer.apply_chat_template(
chat_history,
tokenize=False,
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)

    @property
def tokenizer_name(self) -> str:
return self.llm.tokenizer.name_or_path.replace("/", "__")

    def tok_encode(self, string: str, **kwargs) -> List[int]:
return self.llm.tokenizer.encode(string, **kwargs)

    def _loglikelihood_tokens(self, requests,
**kwargs) -> List[Tuple[float, bool]]:
raise NotImplementedError()

    def loglikelihood_rolling(self,
requests,
disable_tqdm: bool = False) -> List[float]:
raise NotImplementedError()

    def _get_sampling_params(self, gen_kwargs: dict) -> SamplingParams:
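        """Build a SamplingParams for one request from lm-eval's gen_kwargs.

        Starts from a copy of the user-provided sampling_params (or defaults)
        and overrides temperature, top_p, max_tokens (max_gen_toks) and stop
        strings (until) whenever lm-eval supplies them.
        """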
params_mapping = {
"temperature": "temperature",
"top_p": "top_p",
"max_gen_toks": "max_tokens",
"until": "stop",
}
if self.sampling_params is None:
sampling_params = SamplingParams()
else:
sampling_params = copy.deepcopy(self.sampling_params)
for lm_eval_key, trtllm_key in params_mapping.items():
value = gen_kwargs.pop(lm_eval_key, None)
if value is not None:
setattr(sampling_params, trtllm_key, value)
return sampling_params

    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
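        """Submit all prompts asynchronously, then gather the generated texts.

        Using generate_async lets the engine schedule requests concurrently;
        results are collected and returned in submission order.
        """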
profiler.start("trtllm exec")
outputs = []
for request in tqdm(requests,
desc="Submitting requests",
disable=disable_tqdm):
prompt, gen_kwargs = request.args
sampling_params = self._get_sampling_params(gen_kwargs)
output = self.llm.generate_async(prompt,
sampling_params=sampling_params)
outputs.append(output)
for output in tqdm(outputs,
desc="Fetching responses",
disable=disable_tqdm):
output.result()
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
profiler.reset("trtllm exec")
return [output.outputs[0].text for output in outputs]


class LmEvalEvaluator(Evaluator):
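    """Evaluator that runs an lm-eval task against a TensorRT-LLM model.

    Subclasses bind a concrete task name (e.g. ``gsm8k``) and expose a click
    command; prompting and scoring are delegated to the lm-eval harness.
    """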
def __init__(self,
task_name: str,
                 dataset_path: Optional[str] = None,
num_samples: Optional[int] = None,
random_seed: int = 0,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None):
try:
import lm_eval
except ImportError as e:
raise ImportError(
f"Evaluation task {self.__class__.__name__} requires `lm_eval`. "
"Please install the package first, e.g., `pip install lm_eval`."
) from e
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
self.task_name = task_name
self.dataset_path = dataset_path
self.num_samples = num_samples
task_manager = lm_eval.tasks.TaskManager(
include_path=f"{os.path.dirname(__file__)}/lm_eval_tasks")
with self._patch_lm_eval():
self.task_dict = lm_eval.tasks.get_task_dict(
task_name, task_manager=task_manager)
# Few-shot random seed
self.task_dict[self.task_name].set_fewshot_seed(random_seed)
# Shuffle dataset
data = self.task_dict[self.task_name].dataset
for split in data.keys():
data[split] = data[split].shuffle(random_seed)

    @contextmanager
def _patch_lm_eval(self):
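        """Temporarily override TaskConfig.__post_init__ so that lm-eval loads
        the dataset from the user-supplied dataset_path (e.g. a local copy)
        instead of the path baked into the task YAML. No-op when dataset_path
        is None.
        """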
if self.dataset_path is None:
yield
return
import lm_eval
self._task_config_post_init = lm_eval.api.task.TaskConfig.__post_init__
def _patched(task_config, *args, **kwargs):
task_config.dataset_path = self.dataset_path
self._task_config_post_init(task_config, *args, **kwargs)
lm_eval.api.task.TaskConfig.__post_init__ = _patched
try:
yield
finally:
lm_eval.api.task.TaskConfig.__post_init__ = self._task_config_post_init

    def generate_samples(self) -> Iterable[tuple]:
raise NotImplementedError()

    def compute_score(self, outputs: List[RequestOutput], references: List[str],
*auxiliaries) -> float:
raise NotImplementedError()

    def evaluate(self,
llm: Union[LLM, PyTorchLLM],
sampling_params: Optional[SamplingParams] = None) -> float:
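        """Run the lm-eval harness on the configured task and report results.

        Scores are rescaled to the 0~100 range; the returned value is the
        average of the reported numeric metrics (stderr entries excluded).
        """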
import lm_eval
results = lm_eval.evaluate(lm=LmEvalWrapper(llm, sampling_params),
task_dict=self.task_dict,
limit=self.num_samples,
apply_chat_template=self.apply_chat_template,
system_instruction=self.system_prompt)
# Normalize scores to range 0~100
scores = results["results"][self.task_name]
for metric in scores.keys():
if isinstance(scores[metric], (float, int)):
scores[metric] *= 100
logger.info(
f"lm-eval {self.task_name} results (scores normalized to range 0~100):\n{lm_eval.utils.make_table(results)}"
)
        average_acc = np.mean([
            acc for m, acc in scores.items()
            if "_stderr" not in m and isinstance(acc, (float, int))
        ])
logger.info(
f"lm-eval {self.task_name} average accuracy: {average_acc:.2f}")
return average_acc

    @classmethod
def command_harness(cls, ctx, **kwargs):
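        """Shared driver for the click commands: build the evaluator from the
        CLI options, evaluate the LLM stored in the click context, then shut
        the LLM down.
        """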
llm: Union[LLM, PyTorchLLM] = ctx.obj
evaluator = cls(dataset_path=kwargs.pop("dataset_path", None),
num_samples=kwargs.pop("num_samples", None),
random_seed=kwargs.pop("random_seed", 0),
apply_chat_template=kwargs.pop("apply_chat_template",
False),
system_prompt=kwargs.pop("system_prompt", None))
sampling_params = SamplingParams(
max_tokens=kwargs.pop("max_output_length"),
truncate_prompt_tokens=kwargs.pop("max_input_length"))
evaluator.evaluate(llm, sampling_params)
llm.shutdown()


class GSM8K(LmEvalEvaluator):

    def __init__(self, **kwargs):
        super().__init__("gsm8k", **kwargs)

    @click.command("gsm8k")
    @click.option("--dataset_path", type=str, default=None)
    @click.option("--num_samples", type=int, default=None)
    @click.option("--random_seed", type=int, default=0)
    @click.option("--apply_chat_template", is_flag=True, default=False)
    @click.option("--system_prompt", type=str, default=None)
@click.option("--max_input_length", type=int, default=4096)
@click.option("--max_output_length", type=int, default=256)
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
GSM8K.command_harness(ctx, **kwargs)


class GPQADiamond(LmEvalEvaluator):

    def __init__(self, **kwargs):
        super().__init__("gpqa_diamond_cot_zeroshot_aa", **kwargs)

    @click.command("gpqa_diamond")
    @click.option("--dataset_path", type=str, default=None)
    @click.option("--num_samples", type=int, default=None)
    @click.option("--random_seed", type=int, default=0)
    @click.option("--apply_chat_template", is_flag=True, default=False)
    @click.option("--system_prompt", type=str, default=None)
@click.option("--max_input_length", type=int, default=4096)
@click.option("--max_output_length", type=int, default=32768)
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
GPQADiamond.command_harness(ctx, **kwargs)


class GPQAMain(LmEvalEvaluator):

    def __init__(self, **kwargs):
        super().__init__("gpqa_main_cot_zeroshot_aa", **kwargs)

    @click.command("gpqa_main")
    @click.option("--dataset_path", type=str, default=None)
    @click.option("--num_samples", type=int, default=None)
    @click.option("--random_seed", type=int, default=0)
    @click.option("--apply_chat_template", is_flag=True, default=False)
    @click.option("--system_prompt", type=str, default=None)
@click.option("--max_input_length", type=int, default=4096)
@click.option("--max_output_length", type=int, default=32768)
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
GPQAMain.command_harness(ctx, **kwargs)


class GPQAExtended(LmEvalEvaluator):

    def __init__(self, **kwargs):
        super().__init__("gpqa_extended_cot_zeroshot_aa", **kwargs)

    @click.command("gpqa_extended")
    @click.option("--dataset_path", type=str, default=None)
    @click.option("--num_samples", type=int, default=None)
    @click.option("--random_seed", type=int, default=0)
    @click.option("--apply_chat_template", is_flag=True, default=False)
    @click.option("--system_prompt", type=str, default=None)
@click.option("--max_input_length", type=int, default=4096)
@click.option("--max_output_length", type=int, default=32768)
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
GPQAExtended.command_harness(ctx, **kwargs)
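

# Illustrative usage sketch (not part of this module). The click commands
# above are intended to be mounted under a parent CLI group that constructs
# the LLM and stores it in ctx.obj; the evaluators can also be driven
# programmatically. Model path, sample count, and token limits below are
# placeholders:
#
#   from tensorrt_llm import LLM, SamplingParams
#   from tensorrt_llm.evaluate.lm_eval import GSM8K
#
#   llm = LLM(model="<model_dir>")
#   sampling_params = SamplingParams(max_tokens=256,
#                                    truncate_prompt_tokens=4096)
#   score = GSM8K(num_samples=200,
#                 apply_chat_template=True).evaluate(llm, sampling_params)
#   llm.shutdown()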