mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Use output_dir and save both of prompt ids and text
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
This commit is contained in:
parent
cbc67b7c76
commit
c9e518cd24
@ -35,13 +35,11 @@ class CnnDailymail(Evaluator):
|
||||
rouge_path: Optional[str] = None,
|
||||
apply_chat_template: bool = False,
|
||||
system_prompt: Optional[str] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
output_dir: Optional[str] = None):
|
||||
super().__init__(random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
output_dir=output_dir)
|
||||
if dataset_path is None:
|
||||
dataset_path = "ccdv/cnn_dailymail"
|
||||
self.data = datasets.load_dataset(dataset_path,
|
||||
@ -115,21 +113,17 @@ class CnnDailymail(Evaluator):
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
@click.option("--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
help="Directory to save results.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, dataset_path: Optional[str], num_samples: int,
|
||||
random_seed: int, rouge_path: Optional[str],
|
||||
apply_chat_template: bool, system_prompt: Optional[str],
|
||||
max_input_length: int, max_output_length: int,
|
||||
dump_path: Optional[str], dump_as_text: bool) -> None:
|
||||
output_dir: Optional[str]) -> None:
|
||||
llm: Union[LLM, PyTorchLLM] = ctx.obj
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_output_length,
|
||||
@ -140,7 +134,6 @@ class CnnDailymail(Evaluator):
|
||||
rouge_path=rouge_path,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
output_dir=output_dir)
|
||||
evaluator.evaluate(llm, sampling_params)
|
||||
llm.shutdown()
|
||||
|
||||
@ -38,8 +38,7 @@ class Evaluator(ABC):
|
||||
fewshot_as_multiturn: bool = False,
|
||||
system_prompt: Optional[str] = None,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
output_dir: Optional[str] = None):
|
||||
random.seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
torch.manual_seed(random_seed)
|
||||
@ -47,8 +46,8 @@ class Evaluator(ABC):
|
||||
self.fewshot_as_multiturn = fewshot_as_multiturn
|
||||
self.system_prompt = system_prompt
|
||||
self.chat_template_kwargs = chat_template_kwargs
|
||||
self.dump_path = dump_path
|
||||
self.dump_as_text = dump_as_text
|
||||
self.output_dir = output_dir
|
||||
self.inference_results = []
|
||||
|
||||
@abstractmethod
|
||||
def generate_samples(self) -> Iterable[tuple]:
|
||||
@ -110,17 +109,17 @@ class Evaluator(ABC):
|
||||
auxiliaries.append(aux)
|
||||
results = []
|
||||
task_id = 0
|
||||
if self.dump_path:
|
||||
self.dump_path = prepare_dump_path(self.dump_path)
|
||||
logger.info(f"Dumping data to {self.dump_path}")
|
||||
self.inference_results = []
|
||||
for output in tqdm(outputs, desc="Fetching responses"):
|
||||
res = output.result()
|
||||
results.append(res)
|
||||
dump_inference_result(self.dump_path, res, task_id,
|
||||
self.dump_as_text,
|
||||
getattr(llm, 'tokenizer', None))
|
||||
collect_inference_result(self.inference_results, res, task_id)
|
||||
task_id += 1
|
||||
|
||||
if self.output_dir:
|
||||
dump_inference_results(self.output_dir, self.inference_results,
|
||||
getattr(llm, 'tokenizer', None))
|
||||
|
||||
profiler.stop("trtllm exec")
|
||||
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
|
||||
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
|
||||
@ -134,52 +133,59 @@ class Evaluator(ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def prepare_dump_path(dump_path: str) -> str:
|
||||
if dump_path:
|
||||
if os.path.isdir(dump_path) or dump_path.endswith(os.sep):
|
||||
dump_path = os.path.join(dump_path, "dumped_data.json")
|
||||
os.makedirs(os.path.dirname(dump_path), exist_ok=True)
|
||||
if os.path.exists(dump_path):
|
||||
os.remove(dump_path)
|
||||
return dump_path
|
||||
def collect_inference_result(results_list: List[dict], result: RequestOutput,
|
||||
task_id: int):
|
||||
input_ids = result.prompt_token_ids
|
||||
output_ids = result.outputs[0].token_ids
|
||||
results_list.append({
|
||||
"task_id": task_id,
|
||||
"input_ids": input_ids,
|
||||
"output_ids": output_ids
|
||||
})
|
||||
|
||||
|
||||
def dump_inference_result(dump_path: str, result: RequestOutput, task_id: int,
|
||||
dump_as_text: bool, tokenizer: Any):
|
||||
if not dump_path:
|
||||
def dump_inference_results(output_dir: str, results_list: List[dict],
|
||||
tokenizer: Any):
|
||||
if not output_dir:
|
||||
return
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Dump token ids
|
||||
ids_path = os.path.join(output_dir, "dumped_ids.json")
|
||||
try:
|
||||
with open(dump_path, "a") as f:
|
||||
input_ids = result.prompt_token_ids
|
||||
output_ids = result.outputs[0].token_ids
|
||||
|
||||
if tokenizer is None:
|
||||
logger.warning("Tokenizer not found, dumping raw token ids")
|
||||
dump_as_text = False
|
||||
|
||||
if dump_as_text:
|
||||
input_content = tokenizer.decode(input_ids)
|
||||
output_content = tokenizer.decode(output_ids)
|
||||
else:
|
||||
input_content = input_ids
|
||||
output_content = output_ids
|
||||
|
||||
if dump_as_text:
|
||||
with open(ids_path, "w") as f:
|
||||
for item in results_list:
|
||||
data = {
|
||||
"task_id": task_id,
|
||||
"input_text": input_content,
|
||||
"output_text": output_content,
|
||||
"input_lens": len(input_content),
|
||||
"output_lens": len(output_content)
|
||||
"task_id": item["task_id"],
|
||||
"input_ids": item["input_ids"],
|
||||
"output_ids": item["output_ids"],
|
||||
"input_tokens": len(item["input_ids"]),
|
||||
"output_tokens": len(item["output_ids"])
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"task_id": task_id,
|
||||
"input_ids": input_ids,
|
||||
"output_ids": output_ids,
|
||||
"input_tokens": len(input_content),
|
||||
"output_tokens": len(output_content)
|
||||
}
|
||||
f.write(json.dumps(data) + "\n")
|
||||
f.write(json.dumps(data) + "\n")
|
||||
logger.info(f"Dumped IDs to {ids_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to dump data to {dump_path}: {e}")
|
||||
logger.warning(f"Failed to dump IDs to {ids_path}: {e}")
|
||||
|
||||
# Dump text if tokenizer available
|
||||
if tokenizer is not None:
|
||||
text_path = os.path.join(output_dir, "dumped_text.json")
|
||||
try:
|
||||
with open(text_path, "w") as f:
|
||||
for item in results_list:
|
||||
input_text = tokenizer.decode(item["input_ids"])
|
||||
output_text = tokenizer.decode(item["output_ids"])
|
||||
data = {
|
||||
"task_id": item["task_id"],
|
||||
"input_text": input_text,
|
||||
"output_text": output_text,
|
||||
"input_len": len(input_text),
|
||||
"output_len": len(output_text)
|
||||
}
|
||||
f.write(json.dumps(data) + "\n")
|
||||
logger.info(f"Dumped text to {text_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to dump text to {text_path}: {e}")
|
||||
else:
|
||||
logger.warning("Tokenizer not found, skipping text dump")
|
||||
|
||||
@ -37,16 +37,14 @@ class JsonModeEval(Evaluator):
|
||||
random_seed: int = 0,
|
||||
apply_chat_template: bool = True,
|
||||
system_prompt: Optional[str] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
output_dir: Optional[str] = None):
|
||||
if not apply_chat_template:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} requires apply_chat_template=True.")
|
||||
super().__init__(random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
output_dir=output_dir)
|
||||
if dataset_path is None:
|
||||
dataset_path = "NousResearch/json-mode-eval"
|
||||
self.data = datasets.load_dataset(dataset_path,
|
||||
@ -124,20 +122,16 @@ class JsonModeEval(Evaluator):
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
@click.option("--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
help="Directory to save results.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, dataset_path: Optional[str], num_samples: int,
|
||||
random_seed: int, system_prompt: Optional[str],
|
||||
max_input_length: int, max_output_length: int,
|
||||
dump_path: Optional[str], dump_as_text: bool) -> None:
|
||||
output_dir: Optional[str]) -> None:
|
||||
llm: Union[LLM, PyTorchLLM] = ctx.obj
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_output_length,
|
||||
@ -147,7 +141,6 @@ class JsonModeEval(Evaluator):
|
||||
random_seed=random_seed,
|
||||
apply_chat_template=True,
|
||||
system_prompt=system_prompt,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
output_dir=output_dir)
|
||||
evaluator.evaluate(llm, sampling_params)
|
||||
llm.shutdown()
|
||||
|
||||
@ -39,7 +39,8 @@ from ..inputs.utils import apply_chat_template as trtllm_apply_chat_template
|
||||
from ..llmapi import RequestOutput
|
||||
from ..logger import logger
|
||||
from ..sampling_params import SamplingParams
|
||||
from .interface import Evaluator, dump_inference_result, prepare_dump_path
|
||||
from .interface import (Evaluator, collect_inference_result,
|
||||
dump_inference_results)
|
||||
|
||||
# NOTE: lm_eval uses "<image>" as the default image placeholder
|
||||
# https://github.com/EleutherAI/lm-evaluation-harness/blob/7f04db12d2f8e7a99a0830d99eb78130e1ba2122/lm_eval/models/hf_vlms.py#L25
|
||||
@ -55,15 +56,14 @@ class LmEvalWrapper(TemplateLM):
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
model_type: str | None = None,
|
||||
is_force_single_image: bool = False,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
output_dir: Optional[str] = None):
|
||||
super().__init__()
|
||||
self.llm = llm
|
||||
self.sampling_params = sampling_params
|
||||
self.streaming = streaming
|
||||
self.chat_template_kwargs = chat_template_kwargs
|
||||
self.dump_path = dump_path
|
||||
self.dump_as_text = dump_as_text
|
||||
self.output_dir = output_dir
|
||||
self.inference_results = []
|
||||
|
||||
@property
|
||||
def eot_token_id(self) -> int:
|
||||
@ -144,16 +144,19 @@ class LmEvalWrapper(TemplateLM):
|
||||
|
||||
outputs = []
|
||||
task_id = 0
|
||||
self.inference_results = []
|
||||
for output in tqdm(results,
|
||||
desc="Fetching responses",
|
||||
disable=disable_tqdm):
|
||||
res = output.result()
|
||||
outputs.append(res)
|
||||
dump_inference_result(self.dump_path, res, task_id,
|
||||
self.dump_as_text,
|
||||
getattr(self.llm, 'tokenizer', None))
|
||||
collect_inference_result(self.inference_results, res, task_id)
|
||||
task_id += 1
|
||||
|
||||
if self.output_dir:
|
||||
dump_inference_results(self.output_dir, self.inference_results,
|
||||
getattr(self.llm, 'tokenizer', None))
|
||||
|
||||
profiler.stop("trtllm exec")
|
||||
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
|
||||
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
|
||||
@ -178,8 +181,7 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
model_type: str | None = None,
|
||||
is_force_single_image: bool = False,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
output_dir: Optional[str] = None):
|
||||
"""
|
||||
Initialize the multimodal wrapper.
|
||||
|
||||
@ -189,14 +191,9 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
|
||||
streaming: Whether to use streaming generation
|
||||
max_images: Maximum number of images per prompt (currently unlimited in TRT-LLM), set to 999 from lm_eval's default value.
|
||||
chat_template_kwargs: Chat template kwargs as JSON string
|
||||
dump_path: Path to dump data to ids for trtllm-bench usage
|
||||
dump_as_text: Whether to dump data to text
|
||||
output_dir: Directory to save results
|
||||
"""
|
||||
super().__init__(llm,
|
||||
sampling_params,
|
||||
streaming,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
super().__init__(llm, sampling_params, streaming, output_dir=output_dir)
|
||||
|
||||
# NOTE: Required by lm_eval to identify this as a multimodal model
|
||||
self.MULTIMODAL = True
|
||||
@ -329,14 +326,19 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
|
||||
results.append(output)
|
||||
|
||||
outputs = []
|
||||
self.inference_results = []
|
||||
task_id = 0
|
||||
for output in tqdm(results,
|
||||
desc="Fetching responses",
|
||||
disable=disable_tqdm):
|
||||
res = output.result()
|
||||
outputs.append(res)
|
||||
dump_inference_result(self.dump_path, res, task_id,
|
||||
self.dump_as_text,
|
||||
getattr(self.llm, 'tokenizer', None))
|
||||
collect_inference_result(self.inference_results, res, task_id)
|
||||
task_id += 1
|
||||
|
||||
if self.output_dir:
|
||||
dump_inference_results(self.output_dir, self.inference_results,
|
||||
getattr(self.llm, 'tokenizer', None))
|
||||
|
||||
profiler.stop("trtllm exec")
|
||||
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
|
||||
@ -358,8 +360,7 @@ class LmEvalEvaluator(Evaluator):
|
||||
system_prompt: Optional[str] = None,
|
||||
is_multimodal: bool = False,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
output_dir: Optional[str] = None):
|
||||
try:
|
||||
import lm_eval
|
||||
except ImportError as e:
|
||||
@ -379,8 +380,7 @@ class LmEvalEvaluator(Evaluator):
|
||||
fewshot_as_multiturn=fewshot_as_multiturn,
|
||||
system_prompt=system_prompt,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
output_dir=output_dir)
|
||||
self.task_name = task_name
|
||||
self.dataset_path = dataset_path
|
||||
self.num_samples = num_samples
|
||||
@ -473,10 +473,6 @@ class LmEvalEvaluator(Evaluator):
|
||||
import lm_eval
|
||||
lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
|
||||
|
||||
if self.dump_path:
|
||||
self.dump_path = prepare_dump_path(self.dump_path)
|
||||
logger.info(f"Dumping data to {self.dump_path}")
|
||||
|
||||
results = lm_eval.evaluate(
|
||||
lm=lm_cls(llm,
|
||||
sampling_params=sampling_params,
|
||||
@ -484,8 +480,7 @@ class LmEvalEvaluator(Evaluator):
|
||||
chat_template_kwargs=self.chat_template_kwargs,
|
||||
model_type=model_type,
|
||||
is_force_single_image=is_force_single_image,
|
||||
dump_path=self.dump_path,
|
||||
dump_as_text=self.dump_as_text),
|
||||
output_dir=self.output_dir),
|
||||
task_dict=self.task_dict,
|
||||
limit=self.num_samples,
|
||||
apply_chat_template=self.apply_chat_template,
|
||||
@ -526,8 +521,7 @@ class LmEvalEvaluator(Evaluator):
|
||||
is_multimodal=kwargs.pop("is_multimodal", False),
|
||||
chat_template_kwargs=kwargs.pop("chat_template_kwargs",
|
||||
None),
|
||||
dump_path=kwargs.pop("dump_path", None),
|
||||
dump_as_text=kwargs.pop("dump_as_text", False))
|
||||
output_dir=kwargs.pop("output_dir", None))
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=kwargs.pop("max_output_length"),
|
||||
truncate_prompt_tokens=kwargs.pop("max_input_length"),
|
||||
@ -584,14 +578,10 @@ class GSM8K(LmEvalEvaluator):
|
||||
type=int,
|
||||
default=256,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
@click.option("--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
help="Directory to save the results.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, **kwargs) -> None:
|
||||
@ -646,14 +636,10 @@ class GPQADiamond(LmEvalEvaluator):
|
||||
type=int,
|
||||
default=32768,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
@click.option("--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
help="Directory to save the results.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, **kwargs) -> None:
|
||||
@ -704,14 +690,10 @@ class GPQAMain(LmEvalEvaluator):
|
||||
type=int,
|
||||
default=32768,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
@click.option("--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
help="Directory to save the results.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, **kwargs) -> None:
|
||||
@ -762,14 +744,10 @@ class GPQAExtended(LmEvalEvaluator):
|
||||
type=int,
|
||||
default=32768,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
@click.option("--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
help="Directory to save the results.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, **kwargs) -> None:
|
||||
@ -821,14 +799,10 @@ class MMMU(LmEvalEvaluator):
|
||||
default=
|
||||
512, # NOTE: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/mmmu/_template_yaml#L13
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
@click.option("--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
help="Directory to save the results.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, **kwargs) -> None:
|
||||
|
||||
@ -69,9 +69,7 @@ class LongBenchV2(Evaluator):
|
||||
random_seed: int = 0,
|
||||
apply_chat_template: bool = False,
|
||||
system_prompt: Optional[str] = None,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None):
|
||||
"""Initialize LongBench v2 evaluator.
|
||||
|
||||
Args:
|
||||
@ -93,15 +91,12 @@ class LongBenchV2(Evaluator):
|
||||
apply_chat_template: Whether to apply model's chat template
|
||||
system_prompt: System prompt to prepend
|
||||
chat_template_kwargs: Chat template kwargs as JSON string
|
||||
dump_path: Path to dump data to ids for trtllm-bench usage.
|
||||
dump_as_text: Whether to dump data to text.
|
||||
"""
|
||||
super().__init__(random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
output_dir=output_dir)
|
||||
|
||||
self.dataset_path = dataset_path
|
||||
self.num_samples = num_samples
|
||||
@ -769,14 +764,6 @@ class LongBenchV2(Evaluator):
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="Top p for sampling.")
|
||||
@click.option("--dump_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, dataset_path: str, prompts_dir: Optional[str],
|
||||
@ -787,8 +774,7 @@ class LongBenchV2(Evaluator):
|
||||
apply_chat_template: bool, system_prompt: Optional[str],
|
||||
max_len: int, max_input_length: int, max_output_length: int,
|
||||
chat_template_kwargs: Optional[dict[str, Any]],
|
||||
temperature: float, top_p: float,
|
||||
dump_path: Optional[str], dump_as_text: bool) -> None:
|
||||
temperature: float, top_p: float) -> None:
|
||||
llm: Union[LLM, PyTorchLLM] = ctx.obj
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=max_output_length,
|
||||
@ -812,9 +798,7 @@ class LongBenchV2(Evaluator):
|
||||
random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
chat_template_kwargs=chat_template_kwargs)
|
||||
|
||||
evaluator.evaluate(llm, sampling_params)
|
||||
llm.shutdown()
|
||||
|
||||
@ -122,14 +122,12 @@ class MMLU(Evaluator):
|
||||
apply_chat_template: bool = False,
|
||||
system_prompt: Optional[str] = None,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
output_dir: Optional[str] = None):
|
||||
super().__init__(random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
output_dir=output_dir)
|
||||
if dataset_path is None:
|
||||
dataset_path = self.dowload_dataset()
|
||||
self.dataset_path = dataset_path
|
||||
@ -306,14 +304,10 @@ class MMLU(Evaluator):
|
||||
help="Maximum generation length.")
|
||||
@click.option("--check_accuracy", is_flag=True, default=False)
|
||||
@click.option("--accuracy_threshold", type=float, default=30)
|
||||
@click.option("--dump_path",
|
||||
@click.option("--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
help="Directory to save results.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, dataset_path: Optional[str], num_samples: int,
|
||||
@ -321,8 +315,7 @@ class MMLU(Evaluator):
|
||||
chat_template_kwargs: Optional[dict[str, Any]],
|
||||
system_prompt: Optional[str], max_input_length: int,
|
||||
max_output_length: int, check_accuracy: bool,
|
||||
accuracy_threshold: float, dump_path: Optional[str],
|
||||
dump_as_text: bool) -> None:
|
||||
accuracy_threshold: float, output_dir: Optional[str]) -> None:
|
||||
llm: Union[LLM, PyTorchLLM] = ctx.obj
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_output_length,
|
||||
@ -334,8 +327,7 @@ class MMLU(Evaluator):
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
output_dir=output_dir)
|
||||
accuracy = evaluator.evaluate(llm, sampling_params)
|
||||
llm.shutdown()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user