commit 5f07c4e5e1
Author: heyuhhh
Date: 2026-01-13 21:15:46 +08:00 (committed by GitHub)
6 changed files with 175 additions and 28 deletions

View File

@@ -34,10 +34,12 @@ class CnnDailymail(Evaluator):
random_seed: int = 0,
rouge_path: Optional[str] = None,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None):
system_prompt: Optional[str] = None,
output_dir: Optional[str] = None):
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
system_prompt=system_prompt,
output_dir=output_dir)
if dataset_path is None:
dataset_path = "ccdv/cnn_dailymail"
self.data = datasets.load_dataset(dataset_path,
@@ -111,12 +113,17 @@ class CnnDailymail(Evaluator):
type=int,
default=100,
help="Maximum generation length.")
@click.option("--output_dir",
type=str,
default=None,
help="Directory to save results.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: Optional[str], num_samples: int,
random_seed: int, rouge_path: Optional[str],
apply_chat_template: bool, system_prompt: Optional[str],
max_input_length: int, max_output_length: int) -> None:
max_input_length: int, max_output_length: int,
output_dir: Optional[str]) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(
max_tokens=max_output_length,
@@ -126,6 +133,7 @@ class CnnDailymail(Evaluator):
random_seed=random_seed,
rouge_path=rouge_path,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
system_prompt=system_prompt,
output_dir=output_dir)
evaluator.evaluate(llm, sampling_params)
llm.shutdown()
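
For reference, a minimal sketch of driving the updated CnnDailymail evaluator with the new output_dir parameter through the Python API. The import path and the model path are assumptions; the output_dir argument and the evaluate()/shutdown() calls come from the diff above.

    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.evaluate import CnnDailymail  # assumed import path

    llm = LLM(model="/path/to/model")                # placeholder model path
    sampling_params = SamplingParams(max_tokens=100)
    # Passing output_dir makes the evaluator dump per-request inputs/outputs
    # (see dump_inference_results below) in addition to the ROUGE evaluation.
    evaluator = CnnDailymail(output_dir="./cnn_dailymail_dump")
    evaluator.evaluate(llm, sampling_params)
    llm.shutdown()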

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import random
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional, Union
@@ -35,7 +37,8 @@ class Evaluator(ABC):
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
system_prompt: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None):
chat_template_kwargs: Optional[dict[str, Any]] = None,
output_dir: Optional[str] = None):
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
@@ -43,6 +46,7 @@ class Evaluator(ABC):
self.fewshot_as_multiturn = fewshot_as_multiturn
self.system_prompt = system_prompt
self.chat_template_kwargs = chat_template_kwargs
self.output_dir = output_dir
@abstractmethod
def generate_samples(self) -> Iterable[tuple]:
@@ -105,6 +109,11 @@ class Evaluator(ABC):
results = []
for output in tqdm(outputs, desc="Fetching responses"):
results.append(output.result())
if self.output_dir:
dump_inference_results(self.output_dir, results,
getattr(llm, 'tokenizer', None))
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
@@ -116,3 +125,60 @@ class Evaluator(ABC):
@staticmethod
def command(ctx, *args, **kwargs) -> None:
raise NotImplementedError()
def dump_inference_results(output_dir: str, results: List[dict],
tokenizer: Any):
if not output_dir:
return
os.makedirs(output_dir, exist_ok=True)
# Collect results
results_list = []
for task_id, result in enumerate(results):
output_ids = result.outputs[0].token_ids
output_text = result.outputs[0].text.strip()
input_text = result.prompt.strip()
input_ids = tokenizer.encode(input_text)
results_list.append({
"task_id": task_id,
"input_ids": input_ids,
"output_ids": output_ids,
"input_text": input_text,
"output_text": output_text
})
# Dump token ids
ids_path = os.path.join(output_dir, "dumped_ids.json")
try:
with open(ids_path, "w") as f:
for item in results_list:
data = {
"task_id": item["task_id"],
"input_ids": item["input_ids"],
"output_ids": item["output_ids"],
"input_tokens": len(item["input_ids"]),
"output_tokens": len(item["output_ids"])
}
f.write(json.dumps(data) + "\n")
logger.info(f"Dumped IDs to {ids_path}")
except Exception as e:
logger.warning(f"Failed to dump IDs to {ids_path}: {e}")
# Dump text
text_path = os.path.join(output_dir, "dumped_text.json")
try:
with open(text_path, "w") as f:
for item in results_list:
data = {
"task_id": item["task_id"],
"input_text": item["input_text"],
"output_text": item["output_text"],
"input_len": len(item["input_text"]),
"output_len": len(item["output_text"])
}
f.write(json.dumps(data) + "\n")
logger.info(f"Dumped text to {text_path}")
except Exception as e:
logger.warning(f"Failed to dump text to {text_path}: {e}")

View File

@@ -36,13 +36,15 @@ class JsonModeEval(Evaluator):
num_samples: Optional[int] = None,
random_seed: int = 0,
apply_chat_template: bool = True,
system_prompt: Optional[str] = None):
system_prompt: Optional[str] = None,
output_dir: Optional[str] = None):
if not apply_chat_template:
raise ValueError(
f"{self.__class__.__name__} requires apply_chat_template=True.")
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
system_prompt=system_prompt,
output_dir=output_dir)
if dataset_path is None:
dataset_path = "NousResearch/json-mode-eval"
self.data = datasets.load_dataset(dataset_path,
@@ -120,11 +122,16 @@ class JsonModeEval(Evaluator):
type=int,
default=512,
help="Maximum generation length.")
@click.option("--output_dir",
type=str,
default=None,
help="Directory to save results.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: Optional[str], num_samples: int,
random_seed: int, system_prompt: Optional[str],
max_input_length: int, max_output_length: int) -> None:
max_input_length: int, max_output_length: int,
output_dir: Optional[str]) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(
max_tokens=max_output_length,
@@ -133,6 +140,7 @@ class JsonModeEval(Evaluator):
num_samples=num_samples,
random_seed=random_seed,
apply_chat_template=True,
system_prompt=system_prompt)
system_prompt=system_prompt,
output_dir=output_dir)
evaluator.evaluate(llm, sampling_params)
llm.shutdown()

View File

@@ -39,7 +39,7 @@ from ..inputs.utils import apply_chat_template as trtllm_apply_chat_template
from ..llmapi import RequestOutput
from ..logger import logger
from ..sampling_params import SamplingParams
from .interface import Evaluator
from .interface import Evaluator, dump_inference_results
# NOTE: lm_eval uses "<image>" as the default image placeholder
# https://github.com/EleutherAI/lm-evaluation-harness/blob/7f04db12d2f8e7a99a0830d99eb78130e1ba2122/lm_eval/models/hf_vlms.py#L25
@@ -54,12 +54,14 @@ class LmEvalWrapper(TemplateLM):
streaming: bool = False,
chat_template_kwargs: Optional[dict[str, Any]] = None,
model_type: str | None = None,
is_force_single_image: bool = False):
is_force_single_image: bool = False,
output_dir: Optional[str] = None):
super().__init__()
self.llm = llm
self.sampling_params = sampling_params
self.streaming = streaming
self.chat_template_kwargs = chat_template_kwargs
self.output_dir = output_dir
@property
def eot_token_id(self) -> int:
@@ -144,6 +146,10 @@ class LmEvalWrapper(TemplateLM):
disable=disable_tqdm):
outputs.append(output.result())
if self.output_dir:
dump_inference_results(self.output_dir, outputs,
getattr(self.llm, 'tokenizer', None))
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
@@ -167,7 +173,8 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
max_images: int = 999,
chat_template_kwargs: Optional[dict[str, Any]] = None,
model_type: str | None = None,
is_force_single_image: bool = False):
is_force_single_image: bool = False,
output_dir: Optional[str] = None):
"""
Initialize the multimodal wrapper.
@@ -176,8 +183,10 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
sampling_params: Parameters for text generation
streaming: Whether to use streaming generation
max_images: Maximum number of images per prompt (currently unlimited in TRT-LLM), set to 999 from lm_eval's default value.
chat_template_kwargs: Chat template kwargs as JSON string
output_dir: Directory to save results
"""
super().__init__(llm, sampling_params, streaming)
super().__init__(llm, sampling_params, streaming, output_dir=output_dir)
# NOTE: Required by lm_eval to identify this as a multimodal model
self.MULTIMODAL = True
@@ -315,6 +324,10 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
disable=disable_tqdm):
outputs.append(output.result())
if self.output_dir:
dump_inference_results(self.output_dir, outputs,
getattr(self.llm, 'tokenizer', None))
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
@@ -334,7 +347,8 @@ class LmEvalEvaluator(Evaluator):
fewshot_as_multiturn: bool = False,
system_prompt: Optional[str] = None,
is_multimodal: bool = False,
chat_template_kwargs: Optional[dict[str, Any]] = None):
chat_template_kwargs: Optional[dict[str, Any]] = None,
output_dir: Optional[str] = None):
try:
import lm_eval
except ImportError as e:
@@ -353,7 +367,8 @@ class LmEvalEvaluator(Evaluator):
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
chat_template_kwargs=chat_template_kwargs,
output_dir=output_dir)
self.task_name = task_name
self.dataset_path = dataset_path
self.num_samples = num_samples
@@ -445,13 +460,15 @@ class LmEvalEvaluator(Evaluator):
is_force_single_image: bool = False) -> float:
import lm_eval
lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
results = lm_eval.evaluate(
lm=lm_cls(llm,
sampling_params=sampling_params,
streaming=streaming,
chat_template_kwargs=self.chat_template_kwargs,
model_type=model_type,
is_force_single_image=is_force_single_image),
is_force_single_image=is_force_single_image,
output_dir=self.output_dir),
task_dict=self.task_dict,
limit=self.num_samples,
apply_chat_template=self.apply_chat_template,
@@ -491,7 +508,8 @@ class LmEvalEvaluator(Evaluator):
system_prompt=kwargs.pop("system_prompt", None),
is_multimodal=kwargs.pop("is_multimodal", False),
chat_template_kwargs=kwargs.pop("chat_template_kwargs",
None))
None),
output_dir=kwargs.pop("output_dir", None))
sampling_params = SamplingParams(
max_tokens=kwargs.pop("max_output_length"),
truncate_prompt_tokens=kwargs.pop("max_input_length"),
@@ -548,6 +566,10 @@ class GSM8K(LmEvalEvaluator):
type=int,
default=256,
help="Maximum generation length.")
@click.option("--output_dir",
type=str,
default=None,
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@@ -602,6 +624,10 @@ class GPQADiamond(LmEvalEvaluator):
type=int,
default=32768,
help="Maximum generation length.")
@click.option("--output_dir",
type=str,
default=None,
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@@ -652,6 +678,10 @@ class GPQAMain(LmEvalEvaluator):
type=int,
default=32768,
help="Maximum generation length.")
@click.option("--output_dir",
type=str,
default=None,
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@@ -702,6 +732,10 @@ class GPQAExtended(LmEvalEvaluator):
type=int,
default=32768,
help="Maximum generation length.")
@click.option("--output_dir",
type=str,
default=None,
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@@ -753,6 +787,10 @@ class MMMU(LmEvalEvaluator):
default=
512, # NOTE: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/mmmu/_template_yaml#L13
help="Maximum generation length.")
@click.option("--output_dir",
type=str,
default=None,
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@@ -897,6 +935,10 @@ class LongBenchV1(LmEvalEvaluator):
type=str,
default=None,
help="System prompt.")
@click.option("--output_dir",
type=str,
default=None,
help="Directory to save the results.")
@click.pass_context
@staticmethod
def command(ctx, **kwargs) -> None:
@@ -908,7 +950,8 @@ class LongBenchV1(LmEvalEvaluator):
random_seed=kwargs.pop("random_seed", 0),
apply_chat_template=kwargs.pop("apply_chat_template", True),
system_prompt=kwargs.pop("system_prompt", None),
chat_template_kwargs=kwargs.pop("chat_template_kwargs", None))
chat_template_kwargs=kwargs.pop("chat_template_kwargs", None),
output_dir=kwargs.pop("output_dir", None))
# Let lm-eval task configs control sampling via gen_kwargs.
sampling_params = None
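
With these options in place, every lm-eval based subcommand touched here (GSM8K, the GPQA variants, MMMU, LongBenchV1) accepts the same flag on the command line, e.g. trtllm-eval --model <model_dir> gsm8k --output_dir <dump_dir> (assuming the trtllm-eval entry point; the flag itself is defined by the @click.option("--output_dir", ...) additions above).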

View File

@@ -62,7 +62,9 @@ class LongBenchV2(Evaluator):
cot: bool = False,
no_context: bool = False,
rag: int = 0,
max_len: int = 128000,
max_input_length: int = 128000,
max_output_length: int = 32000,
output_dir: Optional[str] = None,
random_seed: int = 0,
apply_chat_template: bool = False,
@@ -81,7 +83,9 @@ class LongBenchV2(Evaluator):
cot: Enable Chain-of-Thought reasoning
no_context: Test without context (memorization test)
rag: Number of top retrieved contexts to use (0 to disable)
max_input_length: Maximum prompt length in tokens for truncation
max_len: Maximum length (input + output) in tokens
max_input_length: Maximum context length in tokens. If exceeded, the prompt is truncated in the middle.
max_output_length: Maximum output length in tokens for truncation
output_dir: Directory to save evaluation results
random_seed: Random seed for reproducibility
apply_chat_template: Whether to apply model's chat template
@@ -91,7 +95,8 @@ class LongBenchV2(Evaluator):
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
chat_template_kwargs=chat_template_kwargs,
output_dir=output_dir)
self.dataset_path = dataset_path
self.num_samples = num_samples
@@ -103,7 +108,9 @@ class LongBenchV2(Evaluator):
self.no_context = no_context
self.rag = rag
self.output_dir = output_dir
self.max_input_length = max_input_length
# Subtract max_output_length from max_len to reserve budget for output tokens.
self.max_input_length = min(max_input_length,
max_len - max_output_length)
# Will be set during evaluation
self.tokenizer = None
@@ -305,7 +312,6 @@ class LongBenchV2(Evaluator):
If the prompt exceeds max_input_length, it takes the first half and last half
to preserve both context beginning and end.
We need to minus max_output_length from max_len to reserve budget for output tokens.
Args:
prompt: The prompt string to truncate
@@ -727,12 +733,19 @@ class LongBenchV2(Evaluator):
type=str,
default=None,
help="System prompt.")
@click.option(
"--max_len",
type=int,
default=1024000,
help=
"Maximum length (input + output) in tokens which can be supported by the model."
)
@click.option(
"--max_input_length",
type=int,
default=128000,
help=
"Maximum prompt length before apply chat template. If exceeds, the prompt will be truncated in the middle."
"Maximum context length in tokens. If exceeds, the prompt will be truncated in the middle."
)
@click.option("--max_output_length",
type=int,
@@ -763,7 +776,7 @@ class LongBenchV2(Evaluator):
cot: bool, no_context: bool, rag: int,
output_dir: Optional[str], random_seed: int,
apply_chat_template: bool, system_prompt: Optional[str],
max_input_length: int, max_output_length: int,
max_len: int, max_input_length: int, max_output_length: int,
chat_template_kwargs: Optional[dict[str, Any]],
temperature: float, top_p: float) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
@@ -782,7 +795,9 @@ class LongBenchV2(Evaluator):
cot=cot,
no_context=no_context,
rag=rag,
max_len=max_len,
max_input_length=max_input_length,
max_output_length=max_output_length,
output_dir=output_dir,
random_seed=random_seed,
apply_chat_template=apply_chat_template,
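
A short sketch of the prompt-budget arithmetic introduced above, using values visible in this diff (--max_len default 1024000, --max_input_length default 128000, and the constructor default max_output_length=32000); the 131072 figure is a hypothetical tighter limit added purely for illustration:

    # LongBenchV2 reserves max_output_length tokens out of max_len, then caps the prompt budget.
    max_len, max_input_length, max_output_length = 1024000, 128000, 32000
    effective_input = min(max_input_length, max_len - max_output_length)
    # min(128000, 992000) == 128000 -> the requested prompt budget is unchanged.

    # Hypothetical model with a 131072-token window: the reservation now binds.
    effective_input = min(128000, 131072 - 32000)  # == 99072 prompt tokens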

View File

@@ -121,11 +121,13 @@ class MMLU(Evaluator):
random_seed: int = 0,
apply_chat_template: bool = False,
system_prompt: Optional[str] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None):
chat_template_kwargs: Optional[dict[str, Any]] = None,
output_dir: Optional[str] = None):
super().__init__(random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
chat_template_kwargs=chat_template_kwargs,
output_dir=output_dir)
if dataset_path is None:
dataset_path = self.dowload_dataset()
self.dataset_path = dataset_path
@@ -302,6 +304,10 @@ class MMLU(Evaluator):
help="Maximum generation length.")
@click.option("--check_accuracy", is_flag=True, default=False)
@click.option("--accuracy_threshold", type=float, default=30)
@click.option("--output_dir",
type=str,
default=None,
help="Directory to save results.")
@click.pass_context
@staticmethod
def command(ctx, dataset_path: Optional[str], num_samples: int,
@@ -309,7 +315,7 @@ class MMLU(Evaluator):
chat_template_kwargs: Optional[dict[str, Any]],
system_prompt: Optional[str], max_input_length: int,
max_output_length: int, check_accuracy: bool,
accuracy_threshold: float) -> None:
accuracy_threshold: float, output_dir: Optional[str]) -> None:
llm: Union[LLM, PyTorchLLM] = ctx.obj
sampling_params = SamplingParams(
max_tokens=max_output_length,
@@ -320,7 +326,8 @@ class MMLU(Evaluator):
random_seed=random_seed,
apply_chat_template=apply_chat_template,
system_prompt=system_prompt,
chat_template_kwargs=chat_template_kwargs)
chat_template_kwargs=chat_template_kwargs,
output_dir=output_dir)
accuracy = evaluator.evaluate(llm, sampling_params)
llm.shutdown()