Support exporting data in trtllm-eval

Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
yuhangh 2025-12-17 07:24:47 +00:00
parent 6ab996d635
commit cbc67b7c76
6 changed files with 250 additions and 35 deletions
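For orientation, here is a minimal sketch of exercising the new dump options through the Python API touched by this commit; the import path, checkpoint path, and sample count are illustrative assumptions, not part of the commit.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.evaluate import CnnDailymail  # evaluator import path assumed

llm = LLM(model="/path/to/model")  # placeholder checkpoint path
evaluator = CnnDailymail(num_samples=8,          # illustrative subset size
                         dump_path="./dump/",    # new: export per-request data
                         dump_as_text=False)     # new: keep raw token ids
evaluator.evaluate(llm, SamplingParams(max_tokens=100))
llm.shutdown()

Passing dump_as_text=True (or the --dump_as_text flag on the command line) would decode the ids with the model tokenizer before writing.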


@ -34,10 +34,14 @@ class CnnDailymail(Evaluator):
random_seed: int = 0,
rouge_path: Optional[str] = None,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None):
system_prompt: Optional[str] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
system_prompt=system_prompt,
dump_path=dump_path,
dump_as_text=dump_as_text)
if dataset_path is None:
dataset_path = "ccdv/cnn_dailymail"
self.data = datasets.load_dataset(dataset_path,
@ -111,12 +115,21 @@ class CnnDailymail(Evaluator):
type=int,
default=100,
help="Maximum generation length.")
@click.option("--dump_path",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: Optional[str], num_samples: int,
random_seed: int, rouge_path: Optional[str],
apply_chat_template: bool, system_prompt: Optional[str],
max_input_length: int, max_output_length: int) -> None:
max_input_length: int, max_output_length: int,
dump_path: Optional[str], dump_as_text: bool) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(
max_tokens=max_output_length,
@ -126,6 +139,8 @@ class CnnDailymail(Evaluator):
random_seed=random_seed,
rouge_path=rouge_path,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
system_prompt=system_prompt,
dump_path=dump_path,
dump_as_text=dump_as_text)
evaluator.evaluate(llm, sampling_params)
llm.shutdown()


@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import random
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Union
@ -35,7 +37,9 @@ class Evaluator(ABC):
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
system_prompt: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None):
chat_template_kwargs: Optional[dict[str, Any]] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
@ -43,6 +47,8 @@ class Evaluator(ABC):
self.fewshot_as_multiturn = fewshot_as_multiturn
self.system_prompt = system_prompt
self.chat_template_kwargs = chat_template_kwargs
self.dump_path = dump_path
self.dump_as_text = dump_as_text
@abstractmethod
def generate_samples(self) -> Iterable[tuple]:
@ -103,8 +109,18 @@ class Evaluator(ABC):
references.append(reference)
auxiliaries.append(aux)
results = []
task_id = 0
if self.dump_path:
self.dump_path = prepare_dump_path(self.dump_path)
logger.info(f"Dumping data to {self.dump_path}")
for output in tqdm(outputs, desc="Fetching responses"):
results.append(output.result())
res = output.result()
results.append(res)
dump_inference_result(self.dump_path, res, task_id,
self.dump_as_text,
getattr(llm, 'tokenizer', None))
task_id += 1
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
@ -116,3 +132,54 @@ class Evaluator(ABC):
@staticmethod
def command(ctx, *args, **kwargs) -> None:
raise NotImplementedError()
def prepare_dump_path(dump_path: str) -> str:
if dump_path:
if os.path.isdir(dump_path) or dump_path.endswith(os.sep):
dump_path = os.path.join(dump_path, "dumped_data.json")
# dirname() is empty for bare file names; fall back to the current directory
os.makedirs(os.path.dirname(dump_path) or ".", exist_ok=True)
if os.path.exists(dump_path):
os.remove(dump_path)
return dump_path
def dump_inference_result(dump_path: str, result: RequestOutput, task_id: int,
dump_as_text: bool, tokenizer: Any):
if not dump_path:
return
try:
with open(dump_path, "a") as f:
input_ids = result.prompt_token_ids
output_ids = result.outputs[0].token_ids
if tokenizer is None:
logger.warning("Tokenizer not found, dumping raw token ids")
dump_as_text = False
if dump_as_text:
input_content = tokenizer.decode(input_ids)
output_content = tokenizer.decode(output_ids)
else:
input_content = input_ids
output_content = output_ids
if dump_as_text:
data = {
"task_id": task_id,
"input_text": input_content,
"output_text": output_content,
"input_lens": len(input_content),
"output_lens": len(output_content)
}
else:
data = {
"task_id": task_id,
"input_ids": input_ids,
"output_ids": output_ids,
"input_tokens": len(input_content),
"output_tokens": len(output_content)
}
f.write(json.dumps(data) + "\n")
except Exception as e:
logger.warning(f"Failed to dump data to {dump_path}: {e}")


@ -36,13 +36,17 @@ class JsonModeEval(Evaluator):
num_samples: Optional[int] = None,
random_seed: int = 0,
apply_chat_template: bool = True,
system_prompt: Optional[str] = None):
system_prompt: Optional[str] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
if not apply_chat_template:
raise ValueError(
f"{self.__class__.__name__} requires apply_chat_template=True.")
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
system_prompt=system_prompt,
dump_path=dump_path,
dump_as_text=dump_as_text)
if dataset_path is None:
dataset_path = "NousResearch/json-mode-eval"
self.data = datasets.load_dataset(dataset_path,
@ -120,11 +124,20 @@ class JsonModeEval(Evaluator):
type=int,
default=512,
help="Maximum generation length.")
@click.option("--dump_path",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: Optional[str], num_samples: int,
random_seed: int, system_prompt: Optional[str],
max_input_length: int, max_output_length: int) -> None:
max_input_length: int, max_output_length: int,
dump_path: Optional[str], dump_as_text: bool) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(
max_tokens=max_output_length,
@ -133,6 +146,8 @@ class JsonModeEval(Evaluator):
num_samples=num_samples,
random_seed=random_seed,
apply_chat_template=True,
system_prompt=system_prompt)
system_prompt=system_prompt,
dump_path=dump_path,
dump_as_text=dump_as_text)
evaluator.evaluate(llm, sampling_params)
llm.shutdown()


@ -39,7 +39,7 @@ from ..inputs.utils import apply_chat_template as trtllm_apply_chat_template
from ..llmapi import RequestOutput
from ..logger import logger
from ..sampling_params import SamplingParams
from .interface import Evaluator
from .interface import Evaluator, dump_inference_result, prepare_dump_path
# NOTE: lm_eval uses "<image>" as the default image placeholder
# https://github.com/EleutherAI/lm-evaluation-harness/blob/7f04db12d2f8e7a99a0830d99eb78130e1ba2122/lm_eval/models/hf_vlms.py#L25
@ -54,12 +54,16 @@ class LmEvalWrapper(TemplateLM):
streaming: bool = False,
chat_template_kwargs: Optional[dict[str, Any]] = None,
model_type: str | None = None,
is_force_single_image: bool = False):
is_force_single_image: bool = False,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
super().__init__()
self.llm = llm
self.sampling_params = sampling_params
self.streaming = streaming
self.chat_template_kwargs = chat_template_kwargs
self.dump_path = dump_path
self.dump_as_text = dump_as_text
@property
def eot_token_id(self) -> int:
@ -139,10 +143,16 @@ class LmEvalWrapper(TemplateLM):
results.append(output)
outputs = []
task_id = 0
for output in tqdm(results,
desc="Fetching responses",
disable=disable_tqdm):
outputs.append(output.result())
res = output.result()
outputs.append(res)
dump_inference_result(self.dump_path, res, task_id,
self.dump_as_text,
getattr(self.llm, 'tokenizer', None))
task_id += 1
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
@ -167,7 +177,9 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
max_images: int = 999,
chat_template_kwargs: Optional[dict[str, Any]] = None,
model_type: str | None = None,
is_force_single_image: bool = False):
is_force_single_image: bool = False,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
"""
Initialize the multimodal wrapper.
@ -176,8 +188,15 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
sampling_params: Parameters for text generation
streaming: Whether to use streaming generation
max_images: Maximum number of images per prompt (currently unlimited in TRT-LLM); set to 999, matching lm_eval's default value.
chat_template_kwargs: Chat template kwargs as JSON string
dump_path: Path to dump input/output data for trtllm-bench usage
dump_as_text: Whether to dump data as text instead of token ids
"""
super().__init__(llm, sampling_params, streaming)
super().__init__(llm,
sampling_params,
streaming,
dump_path=dump_path,
dump_as_text=dump_as_text)
# NOTE: Required by lm_eval to identify this as a multimodal model
self.MULTIMODAL = True
@ -313,7 +332,11 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
task_id = 0
for output in tqdm(results,
desc="Fetching responses",
disable=disable_tqdm):
outputs.append(output.result())
res = output.result()
outputs.append(res)
dump_inference_result(self.dump_path, res, task_id,
self.dump_as_text,
getattr(self.llm, 'tokenizer', None))
task_id += 1
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
@ -334,7 +357,9 @@ class LmEvalEvaluator(Evaluator):
fewshot_as_multiturn: bool = False,
system_prompt: Optional[str] = None,
is_multimodal: bool = False,
chat_template_kwargs: Optional[dict[str, Any]] = None):
chat_template_kwargs: Optional[dict[str, Any]] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
try:
import lm_eval
except ImportError as e:
@ -353,7 +378,9 @@ class LmEvalEvaluator(Evaluator):
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
chat_template_kwargs=chat_template_kwargs,
dump_path=dump_path,
dump_as_text=dump_as_text)
self.task_name = task_name
self.dataset_path = dataset_path
self.num_samples = num_samples
@ -445,13 +472,20 @@ class LmEvalEvaluator(Evaluator):
is_force_single_image: bool = False) -> float:
import lm_eval
lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
if self.dump_path:
self.dump_path = prepare_dump_path(self.dump_path)
logger.info(f"Dumping data to {self.dump_path}")
results = lm_eval.evaluate(
lm=lm_cls(llm,
sampling_params=sampling_params,
streaming=streaming,
chat_template_kwargs=self.chat_template_kwargs,
model_type=model_type,
is_force_single_image=is_force_single_image),
is_force_single_image=is_force_single_image,
dump_path=self.dump_path,
dump_as_text=self.dump_as_text),
task_dict=self.task_dict,
limit=self.num_samples,
apply_chat_template=self.apply_chat_template,
@ -491,7 +525,9 @@ class LmEvalEvaluator(Evaluator):
system_prompt=kwargs.pop("system_prompt", None),
is_multimodal=kwargs.pop("is_multimodal", False),
chat_template_kwargs=kwargs.pop("chat_template_kwargs",
None))
None),
dump_path=kwargs.pop("dump_path", None),
dump_as_text=kwargs.pop("dump_as_text", False))
sampling_params = SamplingParams(
max_tokens=kwargs.pop("max_output_length"),
truncate_prompt_tokens=kwargs.pop("max_input_length"),
@ -548,6 +584,14 @@ class GSM8K(LmEvalEvaluator):
type=int,
default=256,
help="Maximum generation length.")
@click.option("--dump_path",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@ -602,6 +646,14 @@ class GPQADiamond(LmEvalEvaluator):
type=int,
default=32768,
help="Maximum generation length.")
@click.option("--dump_path",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@ -652,6 +704,14 @@ class GPQAMain(LmEvalEvaluator):
type=int,
default=32768,
help="Maximum generation length.")
@click.option("--dump_path",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@ -702,6 +762,14 @@ class GPQAExtended(LmEvalEvaluator):
type=int,
default=32768,
help="Maximum generation length.")
@click.option("--dump_path",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@ -753,6 +821,14 @@ class MMMU(LmEvalEvaluator):
default=
512, # NOTE: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/mmmu/_template_yaml#L13
help="Maximum generation length.")
@click.option("--dump_path",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:


@ -62,12 +62,16 @@ class LongBenchV2(Evaluator):
cot: bool = False,
no_context: bool = False,
rag: int = 0,
max_len: int = 128000,
max_input_length: int = 128000,
max_output_length: int = 32000,
output_dir: Optional[str] = None,
random_seed: int = 0,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None):
chat_template_kwargs: Optional[dict[str, Any]] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
"""Initialize LongBench v2 evaluator.
Args:
@ -81,17 +85,23 @@ class LongBenchV2(Evaluator):
cot: Enable Chain-of-Thought reasoning
no_context: Test without context (memorization test)
rag: Number of top retrieved contexts to use (0 to disable)
max_input_length: Maximum prompt length in tokens for truncation
max_len: Maximum length (input + output) in tokens
max_input_length: Maximum context length in tokens. If exceeded, the prompt will be truncated in the middle.
max_output_length: Maximum output length in tokens for truncation
output_dir: Directory to save evaluation results
random_seed: Random seed for reproducibility
apply_chat_template: Whether to apply model's chat template
system_prompt: System prompt to prepend
chat_template_kwargs: Chat template kwargs as JSON string
dump_path: Path to dump input/output data for trtllm-bench usage.
dump_as_text: Whether to dump data as text instead of token ids.
"""
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
chat_template_kwargs=chat_template_kwargs,
dump_path=dump_path,
dump_as_text=dump_as_text)
self.dataset_path = dataset_path
self.num_samples = num_samples
@ -103,7 +113,8 @@ class LongBenchV2(Evaluator):
self.no_context = no_context
self.rag = rag
self.output_dir = output_dir
self.max_input_length = max_input_length
# Subtract max_output_length from max_len to reserve budget for output tokens.
self.max_input_length = min(max_input_length, max_len - max_output_length)
# Will be set during evaluation
self.tokenizer = None
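For reference, a small worked example of the budget computed above, using the CLI defaults introduced below (max_len=1024000, max_output_length=32000, max_input_length=128000) and one hypothetical tighter model limit:

max_len, max_output_length, max_input_length = 1024000, 32000, 128000  # CLI defaults in this commit
print(min(max_input_length, max_len - max_output_length))  # 128000: the explicit input cap is binding
print(min(max_input_length, 131072 - max_output_length))   # 99072 for a hypothetical max_len of 131072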
@ -305,7 +316,6 @@ class LongBenchV2(Evaluator):
If the prompt exceeds max_input_length, it takes the first half and last half
to preserve both context beginning and end.
We need to subtract max_output_length from max_len to reserve budget for output tokens.
Args:
prompt: The prompt string to truncate
@ -728,12 +738,16 @@ class LongBenchV2(Evaluator):
default=None,
help="System prompt.")
@click.option(
"--max_input_length",
"--max_len",
type=int,
default=128000,
default=1024000,
help=
"Maximum prompt length before apply chat template. If exceeds, the prompt will be truncated in the middle."
"Maximum length (input + output) in tokens which can be supported by the model."
)
@click.option("--max_input_length",
type=int,
default=128000,
help="Maximum context length in tokens. If exceeds, the prompt will be truncated in the middle.")
@click.option("--max_output_length",
type=int,
default=32000,
@ -755,6 +769,14 @@ class LongBenchV2(Evaluator):
type=float,
default=0.95,
help="Top p for sampling.")
@click.option("--dump_path",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: str, prompts_dir: Optional[str],
@ -763,9 +785,10 @@ class LongBenchV2(Evaluator):
cot: bool, no_context: bool, rag: int,
output_dir: Optional[str], random_seed: int,
apply_chat_template: bool, system_prompt: Optional[str],
max_input_length: int, max_output_length: int,
max_len: int, max_input_length: int, max_output_length: int,
chat_template_kwargs: Optional[dict[str, Any]],
temperature: float, top_p: float) -> None:
temperature: float, top_p: float,
dump_path: Optional[str], dump_as_text: bool) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(max_tokens=max_output_length,
@ -782,12 +805,16 @@ class LongBenchV2(Evaluator):
cot=cot,
no_context=no_context,
rag=rag,
max_len=max_len,
max_input_length=max_input_length,
max_output_length=max_output_length,
output_dir=output_dir,
random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
chat_template_kwargs=chat_template_kwargs,
dump_path=dump_path,
dump_as_text=dump_as_text)
evaluator.evaluate(llm, sampling_params)
llm.shutdown()


@ -121,11 +121,15 @@ class MMLU(Evaluator):
random_seed: int = 0,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None):
chat_template_kwargs: Optional[dict[str, Any]] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
chat_template_kwargs=chat_template_kwargs,
dump_path=dump_path,
dump_as_text=dump_as_text)
if dataset_path is None:
dataset_path = self.dowload_dataset()
self.dataset_path = dataset_path
@ -302,6 +306,14 @@ class MMLU(Evaluator):
help="Maximum generation length.")
@click.option("--check_accuracy", is_flag=True, default=False)
@click.option("--accuracy_threshold", type=float, default=30)
@click.option("--dump_path",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: Optional[str], num_samples: int,
@ -309,7 +321,8 @@ class MMLU(Evaluator):
chat_template_kwargs: Optional[dict[str, Any]],
system_prompt: Optional[str], max_input_length: int,
max_output_length: int, check_accuracy: bool,
accuracy_threshold: float) -> None:
accuracy_threshold: float, dump_path: Optional[str],
dump_as_text: bool) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(
max_tokens=max_output_length,
@ -320,7 +333,9 @@ class MMLU(Evaluator):
random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
chat_template_kwargs=chat_template_kwargs,
dump_path=dump_path,
dump_as_text=dump_as_text)
accuracy = evaluator.evaluate(llm, sampling_params)
llm.shutdown()