Use output_dir and save both prompt IDs and text

Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
yuhangh 2025-12-17 08:56:44 +00:00
parent cbc67b7c76
commit c9e518cd24
6 changed files with 116 additions and 174 deletions


@@ -35,13 +35,11 @@ class CnnDailymail(Evaluator):
rouge_path: Optional[str] = None,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
output_dir: Optional[str] = None):
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
dump_path=dump_path,
dump_as_text=dump_as_text)
output_dir=output_dir)
if dataset_path is None:
dataset_path = "ccdv/cnn_dailymail"
self.data = datasets.load_dataset(dataset_path,
@@ -115,21 +113,17 @@ class CnnDailymail(Evaluator):
type=int,
default=100,
help="Maximum generation length.")
@click.option("--dump_path",
@click.option("--output_dir",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
help="Directory to save results.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: Optional[str], num_samples: int,
random_seed: int, rouge_path: Optional[str],
apply_chat_template: bool, system_prompt: Optional[str],
max_input_length: int, max_output_length: int,
dump_path: Optional[str], dump_as_text: bool) -> None:
output_dir: Optional[str]) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(
max_tokens=max_output_length,
@@ -140,7 +134,6 @@ class CnnDailymail(Evaluator):
rouge_path=rouge_path,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
dump_path=dump_path,
dump_as_text=dump_as_text)
output_dir=output_dir)
evaluator.evaluate(llm, sampling_params)
llm.shutdown()
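
For context, a minimal sketch of the refactored flow on the Python API side. The import paths, checkpoint path, and directory name below are assumptions for illustration, not taken from this diff:

    from tensorrt_llm import LLM, SamplingParams    # public LLM API (assumed available)
    from tensorrt_llm.evaluate import CnnDailymail  # assumed import path for the evaluator

    llm = LLM(model="/path/to/checkpoint")            # hypothetical model path
    sampling_params = SamplingParams(max_tokens=100)  # matches the CLI default max_output_length

    # With output_dir set, evaluate() collects per-request token ids during the fetch loop
    # and, once finished, writes dumped_ids.json plus dumped_text.json (when a tokenizer
    # is available) into that directory.
    evaluator = CnnDailymail(random_seed=0, output_dir="./eval_dumps")
    evaluator.evaluate(llm, sampling_params)
    llm.shutdown()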


@@ -38,8 +38,7 @@ class Evaluator(ABC):
fewshot_as_multiturn: bool = False,
system_prompt: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
output_dir: Optional[str] = None):
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
@@ -47,8 +46,8 @@ class Evaluator(ABC):
self.fewshot_as_multiturn = fewshot_as_multiturn
self.system_prompt = system_prompt
self.chat_template_kwargs = chat_template_kwargs
self.dump_path = dump_path
self.dump_as_text = dump_as_text
self.output_dir = output_dir
self.inference_results = []
@abstractmethod
def generate_samples(self) -> Iterable[tuple]:
@@ -110,17 +109,17 @@ class Evaluator(ABC):
auxiliaries.append(aux)
results = []
task_id = 0
if self.dump_path:
self.dump_path = prepare_dump_path(self.dump_path)
logger.info(f"Dumping data to {self.dump_path}")
self.inference_results = []
for output in tqdm(outputs, desc="Fetching responses"):
res = output.result()
results.append(res)
dump_inference_result(self.dump_path, res, task_id,
self.dump_as_text,
getattr(llm, 'tokenizer', None))
collect_inference_result(self.inference_results, res, task_id)
task_id += 1
if self.output_dir:
dump_inference_results(self.output_dir, self.inference_results,
getattr(llm, 'tokenizer', None))
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
@@ -134,52 +133,59 @@ class Evaluator(ABC):
raise NotImplementedError()
def prepare_dump_path(dump_path: str) -> str:
if dump_path:
if os.path.isdir(dump_path) or dump_path.endswith(os.sep):
dump_path = os.path.join(dump_path, "dumped_data.json")
os.makedirs(os.path.dirname(dump_path), exist_ok=True)
if os.path.exists(dump_path):
os.remove(dump_path)
return dump_path
def collect_inference_result(results_list: List[dict], result: RequestOutput,
task_id: int):
input_ids = result.prompt_token_ids
output_ids = result.outputs[0].token_ids
results_list.append({
"task_id": task_id,
"input_ids": input_ids,
"output_ids": output_ids
})
def dump_inference_result(dump_path: str, result: RequestOutput, task_id: int,
dump_as_text: bool, tokenizer: Any):
if not dump_path:
def dump_inference_results(output_dir: str, results_list: List[dict],
tokenizer: Any):
if not output_dir:
return
os.makedirs(output_dir, exist_ok=True)
# Dump token ids
ids_path = os.path.join(output_dir, "dumped_ids.json")
try:
with open(dump_path, "a") as f:
input_ids = result.prompt_token_ids
output_ids = result.outputs[0].token_ids
if tokenizer is None:
logger.warning("Tokenizer not found, dumping raw token ids")
dump_as_text = False
if dump_as_text:
input_content = tokenizer.decode(input_ids)
output_content = tokenizer.decode(output_ids)
else:
input_content = input_ids
output_content = output_ids
if dump_as_text:
with open(ids_path, "w") as f:
for item in results_list:
data = {
"task_id": task_id,
"input_text": input_content,
"output_text": output_content,
"input_lens": len(input_content),
"output_lens": len(output_content)
"task_id": item["task_id"],
"input_ids": item["input_ids"],
"output_ids": item["output_ids"],
"input_tokens": len(item["input_ids"]),
"output_tokens": len(item["output_ids"])
}
else:
data = {
"task_id": task_id,
"input_ids": input_ids,
"output_ids": output_ids,
"input_tokens": len(input_content),
"output_tokens": len(output_content)
}
f.write(json.dumps(data) + "\n")
f.write(json.dumps(data) + "\n")
logger.info(f"Dumped IDs to {ids_path}")
except Exception as e:
logger.warning(f"Failed to dump data to {dump_path}: {e}")
logger.warning(f"Failed to dump IDs to {ids_path}: {e}")
# Dump text if tokenizer available
if tokenizer is not None:
text_path = os.path.join(output_dir, "dumped_text.json")
try:
with open(text_path, "w") as f:
for item in results_list:
input_text = tokenizer.decode(item["input_ids"])
output_text = tokenizer.decode(item["output_ids"])
data = {
"task_id": item["task_id"],
"input_text": input_text,
"output_text": output_text,
"input_len": len(input_text),
"output_len": len(output_text)
}
f.write(json.dumps(data) + "\n")
logger.info(f"Dumped text to {text_path}")
except Exception as e:
logger.warning(f"Failed to dump text to {text_path}: {e}")
else:
logger.warning("Tokenizer not found, skipping text dump")


@@ -37,16 +37,14 @@ class JsonModeEval(Evaluator):
random_seed: int = 0,
apply_chat_template: bool = True,
system_prompt: Optional[str] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
output_dir: Optional[str] = None):
if not apply_chat_template:
raise ValueError(
f"{self.__class__.__name__} requires apply_chat_template=True.")
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
dump_path=dump_path,
dump_as_text=dump_as_text)
output_dir=output_dir)
if dataset_path is None:
dataset_path = "NousResearch/json-mode-eval"
self.data = datasets.load_dataset(dataset_path,
@@ -124,20 +122,16 @@ class JsonModeEval(Evaluator):
type=int,
default=512,
help="Maximum generation length.")
@click.option("--dump_path",
@click.option("--output_dir",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
help="Directory to save results.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: Optional[str], num_samples: int,
random_seed: int, system_prompt: Optional[str],
max_input_length: int, max_output_length: int,
dump_path: Optional[str], dump_as_text: bool) -> None:
output_dir: Optional[str]) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(
max_tokens=max_output_length,
@@ -147,7 +141,6 @@ class JsonModeEval(Evaluator):
random_seed=random_seed,
apply_chat_template=True,
system_prompt=system_prompt,
dump_path=dump_path,
dump_as_text=dump_as_text)
output_dir=output_dir)
evaluator.evaluate(llm, sampling_params)
llm.shutdown()


@@ -39,7 +39,8 @@ from ..inputs.utils import apply_chat_template as trtllm_apply_chat_template
from ..llmapi import RequestOutput
from ..logger import logger
from ..sampling_params import SamplingParams
from .interface import Evaluator, dump_inference_result, prepare_dump_path
from .interface import (Evaluator, collect_inference_result,
dump_inference_results)
# NOTE: lm_eval uses "<image>" as the default image placeholder
# https://github.com/EleutherAI/lm-evaluation-harness/blob/7f04db12d2f8e7a99a0830d99eb78130e1ba2122/lm_eval/models/hf_vlms.py#L25
@@ -55,15 +56,14 @@ class LmEvalWrapper(TemplateLM):
chat_template_kwargs: Optional[dict[str, Any]] = None,
model_type: str | None = None,
is_force_single_image: bool = False,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
output_dir: Optional[str] = None):
super().__init__()
self.llm = llm
self.sampling_params = sampling_params
self.streaming = streaming
self.chat_template_kwargs = chat_template_kwargs
self.dump_path = dump_path
self.dump_as_text = dump_as_text
self.output_dir = output_dir
self.inference_results = []
@property
def eot_token_id(self) -> int:
@@ -144,16 +144,19 @@ class LmEvalWrapper(TemplateLM):
outputs = []
task_id = 0
self.inference_results = []
for output in tqdm(results,
desc="Fetching responses",
disable=disable_tqdm):
res = output.result()
outputs.append(res)
dump_inference_result(self.dump_path, res, task_id,
self.dump_as_text,
getattr(self.llm, 'tokenizer', None))
collect_inference_result(self.inference_results, res, task_id)
task_id += 1
if self.output_dir:
dump_inference_results(self.output_dir, self.inference_results,
getattr(self.llm, 'tokenizer', None))
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
@@ -178,8 +181,7 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
chat_template_kwargs: Optional[dict[str, Any]] = None,
model_type: str | None = None,
is_force_single_image: bool = False,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
output_dir: Optional[str] = None):
"""
Initialize the multimodal wrapper.
@@ -189,14 +191,9 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
streaming: Whether to use streaming generation
max_images: Maximum number of images per prompt (currently unlimited in TRT-LLM), set to 999 from lm_eval's default value.
chat_template_kwargs: Chat template kwargs as JSON string
dump_path: Path to dump data to ids for trtllm-bench usage
dump_as_text: Whether to dump data to text
output_dir: Directory to save results
"""
super().__init__(llm,
sampling_params,
streaming,
dump_path=dump_path,
dump_as_text=dump_as_text)
super().__init__(llm, sampling_params, streaming, output_dir=output_dir)
# NOTE: Required by lm_eval to identify this as a multimodal model
self.MULTIMODAL = True
@@ -329,14 +326,19 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
results.append(output)
outputs = []
self.inference_results = []
task_id = 0
for output in tqdm(results,
desc="Fetching responses",
disable=disable_tqdm):
res = output.result()
outputs.append(res)
dump_inference_result(self.dump_path, res, task_id,
self.dump_as_text,
getattr(self.llm, 'tokenizer', None))
collect_inference_result(self.inference_results, res, task_id)
task_id += 1
if self.output_dir:
dump_inference_results(self.output_dir, self.inference_results,
getattr(self.llm, 'tokenizer', None))
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
@@ -358,8 +360,7 @@ class LmEvalEvaluator(Evaluator):
system_prompt: Optional[str] = None,
is_multimodal: bool = False,
chat_template_kwargs: Optional[dict[str, Any]] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
output_dir: Optional[str] = None):
try:
import lm_eval
except ImportError as e:
@@ -379,8 +380,7 @@ class LmEvalEvaluator(Evaluator):
fewshot_as_multiturn=fewshot_as_multiturn,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs,
dump_path=dump_path,
dump_as_text=dump_as_text)
output_dir=output_dir)
self.task_name = task_name
self.dataset_path = dataset_path
self.num_samples = num_samples
@@ -473,10 +473,6 @@ class LmEvalEvaluator(Evaluator):
import lm_eval
lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
if self.dump_path:
self.dump_path = prepare_dump_path(self.dump_path)
logger.info(f"Dumping data to {self.dump_path}")
results = lm_eval.evaluate(
lm=lm_cls(llm,
sampling_params=sampling_params,
@@ -484,8 +480,7 @@ class LmEvalEvaluator(Evaluator):
chat_template_kwargs=self.chat_template_kwargs,
model_type=model_type,
is_force_single_image=is_force_single_image,
dump_path=self.dump_path,
dump_as_text=self.dump_as_text),
output_dir=self.output_dir),
task_dict=self.task_dict,
limit=self.num_samples,
apply_chat_template=self.apply_chat_template,
@@ -526,8 +521,7 @@ class LmEvalEvaluator(Evaluator):
is_multimodal=kwargs.pop("is_multimodal", False),
chat_template_kwargs=kwargs.pop("chat_template_kwargs",
None),
dump_path=kwargs.pop("dump_path", None),
dump_as_text=kwargs.pop("dump_as_text", False))
output_dir=kwargs.pop("output_dir", None))
sampling_params = SamplingParams(
max_tokens=kwargs.pop("max_output_length"),
truncate_prompt_tokens=kwargs.pop("max_input_length"),
@@ -584,14 +578,10 @@ class GSM8K(LmEvalEvaluator):
type=int,
default=256,
help="Maximum generation length.")
@click.option("--dump_path",
@click.option("--output_dir",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@@ -646,14 +636,10 @@ class GPQADiamond(LmEvalEvaluator):
type=int,
default=32768,
help="Maximum generation length.")
@click.option("--dump_path",
@click.option("--output_dir",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@@ -704,14 +690,10 @@ class GPQAMain(LmEvalEvaluator):
type=int,
default=32768,
help="Maximum generation length.")
@click.option("--dump_path",
@click.option("--output_dir",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@@ -762,14 +744,10 @@ class GPQAExtended(LmEvalEvaluator):
type=int,
default=32768,
help="Maximum generation length.")
@click.option("--dump_path",
@click.option("--output_dir",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@@ -821,14 +799,10 @@ class MMMU(LmEvalEvaluator):
default=
512, # NOTE: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/mmmu/_template_yaml#L13
help="Maximum generation length.")
@click.option("--dump_path",
@click.option("--output_dir",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:


@@ -69,9 +69,7 @@ class LongBenchV2(Evaluator):
random_seed: int = 0,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
chat_template_kwargs: Optional[dict[str, Any]] = None):
"""Initialize LongBench v2 evaluator.
Args:
@@ -93,15 +91,12 @@ class LongBenchV2(Evaluator):
apply_chat_template: Whether to apply model's chat template
system_prompt: System prompt to prepend
chat_template_kwargs: Chat template kwargs as JSON string
dump_path: Path to dump data to ids for trtllm-bench usage.
dump_as_text: Whether to dump data to text.
"""
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs,
dump_path=dump_path,
dump_as_text=dump_as_text)
output_dir=output_dir)
self.dataset_path = dataset_path
self.num_samples = num_samples
@@ -769,14 +764,6 @@ class LongBenchV2(Evaluator):
type=float,
default=0.95,
help="Top p for sampling.")
@click.option("--dump_path",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: str, prompts_dir: Optional[str],
@@ -787,8 +774,7 @@ class LongBenchV2(Evaluator):
apply_chat_template: bool, system_prompt: Optional[str],
max_len: int, max_input_length: int, max_output_length: int,
chat_template_kwargs: Optional[dict[str, Any]],
temperature: float, top_p: float,
dump_path: Optional[str], dump_as_text: bool) -> None:
temperature: float, top_p: float) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(max_tokens=max_output_length,
@@ -812,9 +798,7 @@ class LongBenchV2(Evaluator):
random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs,
dump_path=dump_path,
dump_as_text=dump_as_text)
chat_template_kwargs=chat_template_kwargs)
evaluator.evaluate(llm, sampling_params)
llm.shutdown()


@@ -122,14 +122,12 @@ class MMLU(Evaluator):
apply_chat_template: bool = False,
system_prompt: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None,
dump_path: Optional[str] = None,
dump_as_text: bool = False):
output_dir: Optional[str] = None):
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs,
dump_path=dump_path,
dump_as_text=dump_as_text)
output_dir=output_dir)
if dataset_path is None:
dataset_path = self.dowload_dataset()
self.dataset_path = dataset_path
@@ -306,14 +304,10 @@ class MMLU(Evaluator):
help="Maximum generation length.")
@click.option("--check_accuracy", is_flag=True, default=False)
@click.option("--accuracy_threshold", type=float, default=30)
@click.option("--dump_path",
@click.option("--output_dir",
type=str,
default=None,
help="Path to dump data to ids for trtllm-bench usage.")
@click.option("--dump_as_text",
is_flag=True,
default=False,
help="Whether to dump data to text.")
help="Directory to save results.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: Optional[str], num_samples: int,
@@ -321,8 +315,7 @@ class MMLU(Evaluator):
chat_template_kwargs: Optional[dict[str, Any]],
system_prompt: Optional[str], max_input_length: int,
max_output_length: int, check_accuracy: bool,
accuracy_threshold: float, dump_path: Optional[str],
dump_as_text: bool) -> None:
accuracy_threshold: float, output_dir: Optional[str]) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(
max_tokens=max_output_length,
@@ -334,8 +327,7 @@ class MMLU(Evaluator):
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs,
dump_path=dump_path,
dump_as_text=dump_as_text)
output_dir=output_dir)
accuracy = evaluator.evaluate(llm, sampling_params)
llm.shutdown()