mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Support to export data in trtllm-eval
Signed-off-by: yuhangh <58161490+heyuhhh@users.noreply.github.com>
This commit is contained in:
parent
6ab996d635
commit
cbc67b7c76
@ -34,10 +34,14 @@ class CnnDailymail(Evaluator):
|
||||
random_seed: int = 0,
|
||||
rouge_path: Optional[str] = None,
|
||||
apply_chat_template: bool = False,
|
||||
system_prompt: Optional[str] = None):
|
||||
system_prompt: Optional[str] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
super().__init__(random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt)
|
||||
system_prompt=system_prompt,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
if dataset_path is None:
|
||||
dataset_path = "ccdv/cnn_dailymail"
|
||||
self.data = datasets.load_dataset(dataset_path,
|
||||
@ -111,12 +115,21 @@ class CnnDailymail(Evaluator):
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, dataset_path: Optional[str], num_samples: int,
|
||||
random_seed: int, rouge_path: Optional[str],
|
||||
apply_chat_template: bool, system_prompt: Optional[str],
|
||||
max_input_length: int, max_output_length: int) -> None:
|
||||
max_input_length: int, max_output_length: int,
|
||||
dump_path: Optional[str], dump_as_text: bool) -> None:
|
||||
llm: Union[LLM, PyTorchLLM] = ctx.obj
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_output_length,
|
||||
@ -126,6 +139,8 @@ class CnnDailymail(Evaluator):
|
||||
random_seed=random_seed,
|
||||
rouge_path=rouge_path,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt)
|
||||
system_prompt=system_prompt,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
evaluator.evaluate(llm, sampling_params)
|
||||
llm.shutdown()
|
||||
|
||||
@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Iterable, List, Optional, Union
|
||||
@ -35,7 +37,9 @@ class Evaluator(ABC):
|
||||
apply_chat_template: bool = False,
|
||||
fewshot_as_multiturn: bool = False,
|
||||
system_prompt: Optional[str] = None,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None):
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
random.seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
torch.manual_seed(random_seed)
|
||||
@ -43,6 +47,8 @@ class Evaluator(ABC):
|
||||
self.fewshot_as_multiturn = fewshot_as_multiturn
|
||||
self.system_prompt = system_prompt
|
||||
self.chat_template_kwargs = chat_template_kwargs
|
||||
self.dump_path = dump_path
|
||||
self.dump_as_text = dump_as_text
|
||||
|
||||
@abstractmethod
|
||||
def generate_samples(self) -> Iterable[tuple]:
|
||||
@ -103,8 +109,18 @@ class Evaluator(ABC):
|
||||
references.append(reference)
|
||||
auxiliaries.append(aux)
|
||||
results = []
|
||||
task_id = 0
|
||||
if self.dump_path:
|
||||
self.dump_path = prepare_dump_path(self.dump_path)
|
||||
logger.info(f"Dumping data to {self.dump_path}")
|
||||
for output in tqdm(outputs, desc="Fetching responses"):
|
||||
results.append(output.result())
|
||||
res = output.result()
|
||||
results.append(res)
|
||||
dump_inference_result(self.dump_path, res, task_id,
|
||||
self.dump_as_text,
|
||||
getattr(llm, 'tokenizer', None))
|
||||
task_id += 1
|
||||
|
||||
profiler.stop("trtllm exec")
|
||||
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
|
||||
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
|
||||
@ -116,3 +132,54 @@ class Evaluator(ABC):
|
||||
@staticmethod
|
||||
def command(ctx, *args, **kwargs) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def prepare_dump_path(dump_path: str) -> str:
|
||||
if dump_path:
|
||||
if os.path.isdir(dump_path) or dump_path.endswith(os.sep):
|
||||
dump_path = os.path.join(dump_path, "dumped_data.json")
|
||||
os.makedirs(os.path.dirname(dump_path), exist_ok=True)
|
||||
if os.path.exists(dump_path):
|
||||
os.remove(dump_path)
|
||||
return dump_path
|
||||
|
||||
|
||||
def dump_inference_result(dump_path: str, result: RequestOutput, task_id: int,
|
||||
dump_as_text: bool, tokenizer: Any):
|
||||
if not dump_path:
|
||||
return
|
||||
try:
|
||||
with open(dump_path, "a") as f:
|
||||
input_ids = result.prompt_token_ids
|
||||
output_ids = result.outputs[0].token_ids
|
||||
|
||||
if tokenizer is None:
|
||||
logger.warning("Tokenizer not found, dumping raw token ids")
|
||||
dump_as_text = False
|
||||
|
||||
if dump_as_text:
|
||||
input_content = tokenizer.decode(input_ids)
|
||||
output_content = tokenizer.decode(output_ids)
|
||||
else:
|
||||
input_content = input_ids
|
||||
output_content = output_ids
|
||||
|
||||
if dump_as_text:
|
||||
data = {
|
||||
"task_id": task_id,
|
||||
"input_text": input_content,
|
||||
"output_text": output_content,
|
||||
"input_lens": len(input_content),
|
||||
"output_lens": len(output_content)
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
"task_id": task_id,
|
||||
"input_ids": input_ids,
|
||||
"output_ids": output_ids,
|
||||
"input_tokens": len(input_content),
|
||||
"output_tokens": len(output_content)
|
||||
}
|
||||
f.write(json.dumps(data) + "\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to dump data to {dump_path}: {e}")
|
||||
|
||||
@ -36,13 +36,17 @@ class JsonModeEval(Evaluator):
|
||||
num_samples: Optional[int] = None,
|
||||
random_seed: int = 0,
|
||||
apply_chat_template: bool = True,
|
||||
system_prompt: Optional[str] = None):
|
||||
system_prompt: Optional[str] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
if not apply_chat_template:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} requires apply_chat_template=True.")
|
||||
super().__init__(random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt)
|
||||
system_prompt=system_prompt,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
if dataset_path is None:
|
||||
dataset_path = "NousResearch/json-mode-eval"
|
||||
self.data = datasets.load_dataset(dataset_path,
|
||||
@ -120,11 +124,20 @@ class JsonModeEval(Evaluator):
|
||||
type=int,
|
||||
default=512,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, dataset_path: Optional[str], num_samples: int,
|
||||
random_seed: int, system_prompt: Optional[str],
|
||||
max_input_length: int, max_output_length: int) -> None:
|
||||
max_input_length: int, max_output_length: int,
|
||||
dump_path: Optional[str], dump_as_text: bool) -> None:
|
||||
llm: Union[LLM, PyTorchLLM] = ctx.obj
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_output_length,
|
||||
@ -133,6 +146,8 @@ class JsonModeEval(Evaluator):
|
||||
num_samples=num_samples,
|
||||
random_seed=random_seed,
|
||||
apply_chat_template=True,
|
||||
system_prompt=system_prompt)
|
||||
system_prompt=system_prompt,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
evaluator.evaluate(llm, sampling_params)
|
||||
llm.shutdown()
|
||||
|
||||
@ -39,7 +39,7 @@ from ..inputs.utils import apply_chat_template as trtllm_apply_chat_template
|
||||
from ..llmapi import RequestOutput
|
||||
from ..logger import logger
|
||||
from ..sampling_params import SamplingParams
|
||||
from .interface import Evaluator
|
||||
from .interface import Evaluator, dump_inference_result, prepare_dump_path
|
||||
|
||||
# NOTE: lm_eval uses "<image>" as the default image placeholder
|
||||
# https://github.com/EleutherAI/lm-evaluation-harness/blob/7f04db12d2f8e7a99a0830d99eb78130e1ba2122/lm_eval/models/hf_vlms.py#L25
|
||||
@ -54,12 +54,16 @@ class LmEvalWrapper(TemplateLM):
|
||||
streaming: bool = False,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
model_type: str | None = None,
|
||||
is_force_single_image: bool = False):
|
||||
is_force_single_image: bool = False,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
super().__init__()
|
||||
self.llm = llm
|
||||
self.sampling_params = sampling_params
|
||||
self.streaming = streaming
|
||||
self.chat_template_kwargs = chat_template_kwargs
|
||||
self.dump_path = dump_path
|
||||
self.dump_as_text = dump_as_text
|
||||
|
||||
@property
|
||||
def eot_token_id(self) -> int:
|
||||
@ -139,10 +143,16 @@ class LmEvalWrapper(TemplateLM):
|
||||
results.append(output)
|
||||
|
||||
outputs = []
|
||||
task_id = 0
|
||||
for output in tqdm(results,
|
||||
desc="Fetching responses",
|
||||
disable=disable_tqdm):
|
||||
outputs.append(output.result())
|
||||
res = output.result()
|
||||
outputs.append(res)
|
||||
dump_inference_result(self.dump_path, res, task_id,
|
||||
self.dump_as_text,
|
||||
getattr(self.llm, 'tokenizer', None))
|
||||
task_id += 1
|
||||
|
||||
profiler.stop("trtllm exec")
|
||||
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
|
||||
@ -167,7 +177,9 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
|
||||
max_images: int = 999,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
model_type: str | None = None,
|
||||
is_force_single_image: bool = False):
|
||||
is_force_single_image: bool = False,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
"""
|
||||
Initialize the multimodal wrapper.
|
||||
|
||||
@ -176,8 +188,15 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
|
||||
sampling_params: Parameters for text generation
|
||||
streaming: Whether to use streaming generation
|
||||
max_images: Maximum number of images per prompt (currently unlimited in TRT-LLM), set to 999 from lm_eval's default value.
|
||||
chat_template_kwargs: Chat template kwargs as JSON string
|
||||
dump_path: Path to dump data to ids for trtllm-bench usage
|
||||
dump_as_text: Whether to dump data to text
|
||||
"""
|
||||
super().__init__(llm, sampling_params, streaming)
|
||||
super().__init__(llm,
|
||||
sampling_params,
|
||||
streaming,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
|
||||
# NOTE: Required by lm_eval to identify this as a multimodal model
|
||||
self.MULTIMODAL = True
|
||||
@ -313,7 +332,11 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
|
||||
for output in tqdm(results,
|
||||
desc="Fetching responses",
|
||||
disable=disable_tqdm):
|
||||
outputs.append(output.result())
|
||||
res = output.result()
|
||||
outputs.append(res)
|
||||
dump_inference_result(self.dump_path, res, task_id,
|
||||
self.dump_as_text,
|
||||
getattr(self.llm, 'tokenizer', None))
|
||||
|
||||
profiler.stop("trtllm exec")
|
||||
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
|
||||
@ -334,7 +357,9 @@ class LmEvalEvaluator(Evaluator):
|
||||
fewshot_as_multiturn: bool = False,
|
||||
system_prompt: Optional[str] = None,
|
||||
is_multimodal: bool = False,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None):
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
try:
|
||||
import lm_eval
|
||||
except ImportError as e:
|
||||
@ -353,7 +378,9 @@ class LmEvalEvaluator(Evaluator):
|
||||
apply_chat_template=apply_chat_template,
|
||||
fewshot_as_multiturn=fewshot_as_multiturn,
|
||||
system_prompt=system_prompt,
|
||||
chat_template_kwargs=chat_template_kwargs)
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
self.task_name = task_name
|
||||
self.dataset_path = dataset_path
|
||||
self.num_samples = num_samples
|
||||
@ -445,13 +472,20 @@ class LmEvalEvaluator(Evaluator):
|
||||
is_force_single_image: bool = False) -> float:
|
||||
import lm_eval
|
||||
lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
|
||||
|
||||
if self.dump_path:
|
||||
self.dump_path = prepare_dump_path(self.dump_path)
|
||||
logger.info(f"Dumping data to {self.dump_path}")
|
||||
|
||||
results = lm_eval.evaluate(
|
||||
lm=lm_cls(llm,
|
||||
sampling_params=sampling_params,
|
||||
streaming=streaming,
|
||||
chat_template_kwargs=self.chat_template_kwargs,
|
||||
model_type=model_type,
|
||||
is_force_single_image=is_force_single_image),
|
||||
is_force_single_image=is_force_single_image,
|
||||
dump_path=self.dump_path,
|
||||
dump_as_text=self.dump_as_text),
|
||||
task_dict=self.task_dict,
|
||||
limit=self.num_samples,
|
||||
apply_chat_template=self.apply_chat_template,
|
||||
@ -491,7 +525,9 @@ class LmEvalEvaluator(Evaluator):
|
||||
system_prompt=kwargs.pop("system_prompt", None),
|
||||
is_multimodal=kwargs.pop("is_multimodal", False),
|
||||
chat_template_kwargs=kwargs.pop("chat_template_kwargs",
|
||||
None))
|
||||
None),
|
||||
dump_path=kwargs.pop("dump_path", None),
|
||||
dump_as_text=kwargs.pop("dump_as_text", False))
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=kwargs.pop("max_output_length"),
|
||||
truncate_prompt_tokens=kwargs.pop("max_input_length"),
|
||||
@ -548,6 +584,14 @@ class GSM8K(LmEvalEvaluator):
|
||||
type=int,
|
||||
default=256,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, **kwargs) -> None:
|
||||
@ -602,6 +646,14 @@ class GPQADiamond(LmEvalEvaluator):
|
||||
type=int,
|
||||
default=32768,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, **kwargs) -> None:
|
||||
@ -652,6 +704,14 @@ class GPQAMain(LmEvalEvaluator):
|
||||
type=int,
|
||||
default=32768,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, **kwargs) -> None:
|
||||
@ -702,6 +762,14 @@ class GPQAExtended(LmEvalEvaluator):
|
||||
type=int,
|
||||
default=32768,
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, **kwargs) -> None:
|
||||
@ -753,6 +821,14 @@ class MMMU(LmEvalEvaluator):
|
||||
default=
|
||||
512, # NOTE: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/mmmu/_template_yaml#L13
|
||||
help="Maximum generation length.")
|
||||
@click.option("--dump_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, **kwargs) -> None:
|
||||
|
||||
@ -62,12 +62,16 @@ class LongBenchV2(Evaluator):
|
||||
cot: bool = False,
|
||||
no_context: bool = False,
|
||||
rag: int = 0,
|
||||
max_len: int = 128000,
|
||||
max_input_length: int = 128000,
|
||||
max_output_length: int = 32000,
|
||||
output_dir: Optional[str] = None,
|
||||
random_seed: int = 0,
|
||||
apply_chat_template: bool = False,
|
||||
system_prompt: Optional[str] = None,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None):
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
"""Initialize LongBench v2 evaluator.
|
||||
|
||||
Args:
|
||||
@ -81,17 +85,23 @@ class LongBenchV2(Evaluator):
|
||||
cot: Enable Chain-of-Thought reasoning
|
||||
no_context: Test without context (memorization test)
|
||||
rag: Number of top retrieved contexts to use (0 to disable)
|
||||
max_input_length: Maximum prompt length in tokens for truncation
|
||||
max_len: Maximum length (input + output) in tokens
|
||||
max_input_length: Maximum context length in tokens. If exceeds, the prompt will be truncated in the middle.
|
||||
max_output_length: Maximum output length in tokens for truncation
|
||||
output_dir: Directory to save evaluation results
|
||||
random_seed: Random seed for reproducibility
|
||||
apply_chat_template: Whether to apply model's chat template
|
||||
system_prompt: System prompt to prepend
|
||||
chat_template_kwargs: Chat template kwargs as JSON string
|
||||
dump_path: Path to dump data to ids for trtllm-bench usage.
|
||||
dump_as_text: Whether to dump data to text.
|
||||
"""
|
||||
super().__init__(random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
chat_template_kwargs=chat_template_kwargs)
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
|
||||
self.dataset_path = dataset_path
|
||||
self.num_samples = num_samples
|
||||
@ -103,7 +113,8 @@ class LongBenchV2(Evaluator):
|
||||
self.no_context = no_context
|
||||
self.rag = rag
|
||||
self.output_dir = output_dir
|
||||
self.max_input_length = max_input_length
|
||||
# We need to minus max_output_length from max_len to reserve budget for output tokens.
|
||||
self.max_input_length = min(max_input_length, max_len - max_output_length)
|
||||
|
||||
# Will be set during evaluation
|
||||
self.tokenizer = None
|
||||
@ -305,7 +316,6 @@ class LongBenchV2(Evaluator):
|
||||
|
||||
If the prompt exceeds max_input_length, it takes the first half and last half
|
||||
to preserve both context beginning and end.
|
||||
We need to minus max_output_length from max_len to reserve budget for output tokens.
|
||||
|
||||
Args:
|
||||
prompt: The prompt string to truncate
|
||||
@ -728,12 +738,16 @@ class LongBenchV2(Evaluator):
|
||||
default=None,
|
||||
help="System prompt.")
|
||||
@click.option(
|
||||
"--max_input_length",
|
||||
"--max_len",
|
||||
type=int,
|
||||
default=128000,
|
||||
default=1024000,
|
||||
help=
|
||||
"Maximum prompt length before apply chat template. If exceeds, the prompt will be truncated in the middle."
|
||||
"Maximum length (input + output) in tokens which can be supported by the model."
|
||||
)
|
||||
@click.option("--max_input_length",
|
||||
type=int,
|
||||
default=128000,
|
||||
help="Maximum context length in tokens. If exceeds, the prompt will be truncated in the middle.")
|
||||
@click.option("--max_output_length",
|
||||
type=int,
|
||||
default=32000,
|
||||
@ -755,6 +769,14 @@ class LongBenchV2(Evaluator):
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="Top p for sampling.")
|
||||
@click.option("--dump_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, dataset_path: str, prompts_dir: Optional[str],
|
||||
@ -763,9 +785,10 @@ class LongBenchV2(Evaluator):
|
||||
cot: bool, no_context: bool, rag: int,
|
||||
output_dir: Optional[str], random_seed: int,
|
||||
apply_chat_template: bool, system_prompt: Optional[str],
|
||||
max_input_length: int, max_output_length: int,
|
||||
max_len: int, max_input_length: int, max_output_length: int,
|
||||
chat_template_kwargs: Optional[dict[str, Any]],
|
||||
temperature: float, top_p: float) -> None:
|
||||
temperature: float, top_p: float,
|
||||
dump_path: Optional[str], dump_as_text: bool) -> None:
|
||||
llm: Union[LLM, PyTorchLLM] = ctx.obj
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=max_output_length,
|
||||
@ -782,12 +805,16 @@ class LongBenchV2(Evaluator):
|
||||
cot=cot,
|
||||
no_context=no_context,
|
||||
rag=rag,
|
||||
max_len=max_len,
|
||||
max_input_length=max_input_length,
|
||||
max_output_length=max_output_length,
|
||||
output_dir=output_dir,
|
||||
random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
chat_template_kwargs=chat_template_kwargs)
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
|
||||
evaluator.evaluate(llm, sampling_params)
|
||||
llm.shutdown()
|
||||
|
||||
@ -121,11 +121,15 @@ class MMLU(Evaluator):
|
||||
random_seed: int = 0,
|
||||
apply_chat_template: bool = False,
|
||||
system_prompt: Optional[str] = None,
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None):
|
||||
chat_template_kwargs: Optional[dict[str, Any]] = None,
|
||||
dump_path: Optional[str] = None,
|
||||
dump_as_text: bool = False):
|
||||
super().__init__(random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
chat_template_kwargs=chat_template_kwargs)
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
if dataset_path is None:
|
||||
dataset_path = self.dowload_dataset()
|
||||
self.dataset_path = dataset_path
|
||||
@ -302,6 +306,14 @@ class MMLU(Evaluator):
|
||||
help="Maximum generation length.")
|
||||
@click.option("--check_accuracy", is_flag=True, default=False)
|
||||
@click.option("--accuracy_threshold", type=float, default=30)
|
||||
@click.option("--dump_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to dump data to ids for trtllm-bench usage.")
|
||||
@click.option("--dump_as_text",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Whether to dump data to text.")
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, dataset_path: Optional[str], num_samples: int,
|
||||
@ -309,7 +321,8 @@ class MMLU(Evaluator):
|
||||
chat_template_kwargs: Optional[dict[str, Any]],
|
||||
system_prompt: Optional[str], max_input_length: int,
|
||||
max_output_length: int, check_accuracy: bool,
|
||||
accuracy_threshold: float) -> None:
|
||||
accuracy_threshold: float, dump_path: Optional[str],
|
||||
dump_as_text: bool) -> None:
|
||||
llm: Union[LLM, PyTorchLLM] = ctx.obj
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_output_length,
|
||||
@ -320,7 +333,9 @@ class MMLU(Evaluator):
|
||||
random_seed=random_seed,
|
||||
apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt,
|
||||
chat_template_kwargs=chat_template_kwargs)
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
dump_path=dump_path,
|
||||
dump_as_text=dump_as_text)
|
||||
accuracy = evaluator.evaluate(llm, sampling_params)
|
||||
llm.shutdown()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user