Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Merge 8d858f912e into 6df2c8a074
This commit is contained in:
commit 5f07c4e5e1
@@ -34,10 +34,12 @@ class CnnDailymail(Evaluator):
                  random_seed: int = 0,
                  rouge_path: Optional[str] = None,
                  apply_chat_template: bool = False,
-                 system_prompt: Optional[str] = None):
+                 system_prompt: Optional[str] = None,
+                 output_dir: Optional[str] = None):
         super().__init__(random_seed=random_seed,
                          apply_chat_template=apply_chat_template,
-                         system_prompt=system_prompt)
+                         system_prompt=system_prompt,
+                         output_dir=output_dir)
         if dataset_path is None:
             dataset_path = "ccdv/cnn_dailymail"
         self.data = datasets.load_dataset(dataset_path,

@@ -111,12 +113,17 @@ class CnnDailymail(Evaluator):
                   type=int,
                   default=100,
                   help="Maximum generation length.")
+    @click.option("--output_dir",
+                  type=str,
+                  default=None,
+                  help="Directory to save results.")
     @click.pass_context
     @staticmethod
     def command(ctx, dataset_path: Optional[str], num_samples: int,
                 random_seed: int, rouge_path: Optional[str],
                 apply_chat_template: bool, system_prompt: Optional[str],
-                max_input_length: int, max_output_length: int) -> None:
+                max_input_length: int, max_output_length: int,
+                output_dir: Optional[str]) -> None:
         llm: Union[LLM, PyTorchLLM] = ctx.obj
         sampling_params = SamplingParams(
             max_tokens=max_output_length,

@@ -126,6 +133,7 @@ class CnnDailymail(Evaluator):
             random_seed=random_seed,
             rouge_path=rouge_path,
             apply_chat_template=apply_chat_template,
-            system_prompt=system_prompt)
+            system_prompt=system_prompt,
+            output_dir=output_dir)
         evaluator.evaluate(llm, sampling_params)
         llm.shutdown()
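
Taken together, the three hunks above thread a new optional output_dir from the cnn_dailymail click command into the CnnDailymail constructor and on to the base Evaluator. A minimal programmatic sketch of the same flow, assuming CnnDailymail is importable from tensorrt_llm.evaluate as in the current package layout; the model path and dump directory are placeholders:

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.evaluate import CnnDailymail

llm = LLM(model="/path/to/model")  # placeholder checkpoint
sampling_params = SamplingParams(max_tokens=100)  # mirrors the command's default

# output_dir is the parameter added in this diff; when set, the base
# Evaluator.evaluate() writes per-sample inputs/outputs into it via
# dump_inference_results().
evaluator = CnnDailymail(output_dir="./cnn_dailymail_dump")
score = evaluator.evaluate(llm, sampling_params)
llm.shutdown()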
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import json
+import os
 import random
 from abc import ABC, abstractmethod
 from typing import Any, Iterable, List, Optional, Union

@@ -35,7 +37,8 @@ class Evaluator(ABC):
                  apply_chat_template: bool = False,
                  fewshot_as_multiturn: bool = False,
                  system_prompt: Optional[str] = None,
-                 chat_template_kwargs: Optional[dict[str, Any]] = None):
+                 chat_template_kwargs: Optional[dict[str, Any]] = None,
+                 output_dir: Optional[str] = None):
         random.seed(random_seed)
         np.random.seed(random_seed)
         torch.manual_seed(random_seed)

@@ -43,6 +46,7 @@ class Evaluator(ABC):
         self.fewshot_as_multiturn = fewshot_as_multiturn
         self.system_prompt = system_prompt
         self.chat_template_kwargs = chat_template_kwargs
+        self.output_dir = output_dir

     @abstractmethod
     def generate_samples(self) -> Iterable[tuple]:

@@ -105,6 +109,11 @@ class Evaluator(ABC):
         results = []
         for output in tqdm(outputs, desc="Fetching responses"):
             results.append(output.result())
+
+        if self.output_dir:
+            dump_inference_results(self.output_dir, results,
+                                   getattr(llm, 'tokenizer', None))
+
         profiler.stop("trtllm exec")
         elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
         logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
@@ -116,3 +125,60 @@ class Evaluator(ABC):
     @staticmethod
     def command(ctx, *args, **kwargs) -> None:
         raise NotImplementedError()
+
+
+def dump_inference_results(output_dir: str, results: List[dict],
+                           tokenizer: Any):
+    if not output_dir:
+        return
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Collect results
+    results_list = []
+    for task_id, result in enumerate(results):
+        output_ids = result.outputs[0].token_ids
+        output_text = result.outputs[0].text.strip()
+        input_text = result.prompt.strip()
+        input_ids = tokenizer.encode(input_text)
+        results_list.append({
+            "task_id": task_id,
+            "input_ids": input_ids,
+            "output_ids": output_ids,
+            "input_text": input_text,
+            "output_text": output_text
+        })
+
+    # Dump token ids
+    ids_path = os.path.join(output_dir, "dumped_ids.json")
+    try:
+        with open(ids_path, "w") as f:
+            for item in results_list:
+                data = {
+                    "task_id": item["task_id"],
+                    "input_ids": item["input_ids"],
+                    "output_ids": item["output_ids"],
+                    "input_tokens": len(item["input_ids"]),
+                    "output_tokens": len(item["output_ids"])
+                }
+                f.write(json.dumps(data) + "\n")
+        logger.info(f"Dumped IDs to {ids_path}")
+    except Exception as e:
+        logger.warning(f"Failed to dump IDs to {ids_path}: {e}")
+
+    # Dump text
+    text_path = os.path.join(output_dir, "dumped_text.json")
+    try:
+        with open(text_path, "w") as f:
+            for item in results_list:
+                data = {
+                    "task_id": item["task_id"],
+                    "input_text": item["input_text"],
+                    "output_text": item["output_text"],
+                    "input_len": len(item["input_text"]),
+                    "output_len": len(item["output_text"])
+                }
+                f.write(json.dumps(data) + "\n")
+        logger.info(f"Dumped text to {text_path}")
+    except Exception as e:
+        logger.warning(f"Failed to dump text to {text_path}: {e}")
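
dump_inference_results() writes two sidecar files into output_dir: dumped_ids.json with token ids and token counts, and dumped_text.json with prompt/response text and character lengths. Both are written one JSON object per line (JSON Lines) despite the .json suffix. A small sketch of reading them back; the directory name is just an example:

import json
import os

output_dir = "./cnn_dailymail_dump"  # whatever was passed as output_dir

# Each line is an independent JSON record, so parse line by line.
with open(os.path.join(output_dir, "dumped_text.json")) as f:
    records = [json.loads(line) for line in f]

for rec in records[:3]:
    print(rec["task_id"], rec["output_len"], rec["output_text"][:80])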
@@ -36,13 +36,15 @@ class JsonModeEval(Evaluator):
                  num_samples: Optional[int] = None,
                  random_seed: int = 0,
                  apply_chat_template: bool = True,
-                 system_prompt: Optional[str] = None):
+                 system_prompt: Optional[str] = None,
+                 output_dir: Optional[str] = None):
         if not apply_chat_template:
             raise ValueError(
                 f"{self.__class__.__name__} requires apply_chat_template=True.")
         super().__init__(random_seed=random_seed,
                          apply_chat_template=apply_chat_template,
-                         system_prompt=system_prompt)
+                         system_prompt=system_prompt,
+                         output_dir=output_dir)
         if dataset_path is None:
             dataset_path = "NousResearch/json-mode-eval"
         self.data = datasets.load_dataset(dataset_path,

@@ -120,11 +122,16 @@ class JsonModeEval(Evaluator):
                   type=int,
                   default=512,
                   help="Maximum generation length.")
+    @click.option("--output_dir",
+                  type=str,
+                  default=None,
+                  help="Directory to save results.")
     @click.pass_context
     @staticmethod
     def command(ctx, dataset_path: Optional[str], num_samples: int,
                 random_seed: int, system_prompt: Optional[str],
-                max_input_length: int, max_output_length: int) -> None:
+                max_input_length: int, max_output_length: int,
+                output_dir: Optional[str]) -> None:
         llm: Union[LLM, PyTorchLLM] = ctx.obj
         sampling_params = SamplingParams(
             max_tokens=max_output_length,

@@ -133,6 +140,7 @@ class JsonModeEval(Evaluator):
             num_samples=num_samples,
             random_seed=random_seed,
             apply_chat_template=True,
-            system_prompt=system_prompt)
+            system_prompt=system_prompt,
+            output_dir=output_dir)
         evaluator.evaluate(llm, sampling_params)
         llm.shutdown()
@@ -39,7 +39,7 @@ from ..inputs.utils import apply_chat_template as trtllm_apply_chat_template
 from ..llmapi import RequestOutput
 from ..logger import logger
 from ..sampling_params import SamplingParams
-from .interface import Evaluator
+from .interface import Evaluator, dump_inference_results

 # NOTE: lm_eval uses "<image>" as the default image placeholder
 # https://github.com/EleutherAI/lm-evaluation-harness/blob/7f04db12d2f8e7a99a0830d99eb78130e1ba2122/lm_eval/models/hf_vlms.py#L25

@@ -54,12 +54,14 @@ class LmEvalWrapper(TemplateLM):
                  streaming: bool = False,
                  chat_template_kwargs: Optional[dict[str, Any]] = None,
                  model_type: str | None = None,
-                 is_force_single_image: bool = False):
+                 is_force_single_image: bool = False,
+                 output_dir: Optional[str] = None):
         super().__init__()
         self.llm = llm
         self.sampling_params = sampling_params
         self.streaming = streaming
         self.chat_template_kwargs = chat_template_kwargs
+        self.output_dir = output_dir

     @property
     def eot_token_id(self) -> int:

@@ -144,6 +146,10 @@ class LmEvalWrapper(TemplateLM):
                            disable=disable_tqdm):
             outputs.append(output.result())
+
+        if self.output_dir:
+            dump_inference_results(self.output_dir, outputs,
+                                   getattr(self.llm, 'tokenizer', None))

         profiler.stop("trtllm exec")
         elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
         logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")

@@ -167,7 +173,8 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
                  max_images: int = 999,
                  chat_template_kwargs: Optional[dict[str, Any]] = None,
                  model_type: str | None = None,
-                 is_force_single_image: bool = False):
+                 is_force_single_image: bool = False,
+                 output_dir: Optional[str] = None):
         """
         Initialize the multimodal wrapper.

@@ -176,8 +183,10 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
             sampling_params: Parameters for text generation
             streaming: Whether to use streaming generation
             max_images: Maximum number of images per prompt (currently unlimited in TRT-LLM), set to 999 from lm_eval's default value.
             chat_template_kwargs: Chat template kwargs as JSON string
+            output_dir: Directory to save results
         """
-        super().__init__(llm, sampling_params, streaming)
+        super().__init__(llm, sampling_params, streaming, output_dir=output_dir)

         # NOTE: Required by lm_eval to identify this as a multimodal model
         self.MULTIMODAL = True

@@ -315,6 +324,10 @@ class MultimodalLmEvalWrapper(LmEvalWrapper):
                            disable=disable_tqdm):
             outputs.append(output.result())
+
+        if self.output_dir:
+            dump_inference_results(self.output_dir, outputs,
+                                   getattr(self.llm, 'tokenizer', None))

         profiler.stop("trtllm exec")
         elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
         logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")

@@ -334,7 +347,8 @@ class LmEvalEvaluator(Evaluator):
                  fewshot_as_multiturn: bool = False,
                  system_prompt: Optional[str] = None,
                  is_multimodal: bool = False,
-                 chat_template_kwargs: Optional[dict[str, Any]] = None):
+                 chat_template_kwargs: Optional[dict[str, Any]] = None,
+                 output_dir: Optional[str] = None):
         try:
             import lm_eval
         except ImportError as e:

@@ -353,7 +367,8 @@ class LmEvalEvaluator(Evaluator):
                          apply_chat_template=apply_chat_template,
                          fewshot_as_multiturn=fewshot_as_multiturn,
                          system_prompt=system_prompt,
-                         chat_template_kwargs=chat_template_kwargs)
+                         chat_template_kwargs=chat_template_kwargs,
+                         output_dir=output_dir)
         self.task_name = task_name
         self.dataset_path = dataset_path
         self.num_samples = num_samples

@@ -445,13 +460,15 @@ class LmEvalEvaluator(Evaluator):
                  is_force_single_image: bool = False) -> float:
         import lm_eval
         lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper

         results = lm_eval.evaluate(
             lm=lm_cls(llm,
                       sampling_params=sampling_params,
                       streaming=streaming,
                       chat_template_kwargs=self.chat_template_kwargs,
                       model_type=model_type,
-                      is_force_single_image=is_force_single_image),
+                      is_force_single_image=is_force_single_image,
+                      output_dir=self.output_dir),
             task_dict=self.task_dict,
             limit=self.num_samples,
             apply_chat_template=self.apply_chat_template,
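
On the lm-eval path, the same output_dir ends up on LmEvalWrapper (and its multimodal subclass), which dumps each generate_until batch right before profiling stops. A rough sketch of driving the wrapper through lm-eval directly, assuming lm_eval is installed and that the wrapper lives in tensorrt_llm.evaluate.lm_eval as this diff suggests; model path, task choice, and directory are placeholders:

import lm_eval
from lm_eval.tasks import get_task_dict

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.evaluate.lm_eval import LmEvalWrapper  # import path assumed

llm = LLM(model="/path/to/model")
sampling_params = SamplingParams(max_tokens=256)

# output_dir is the new argument; generate_until() calls
# dump_inference_results() with it once all responses are fetched.
wrapper = LmEvalWrapper(llm,
                        sampling_params=sampling_params,
                        output_dir="./gsm8k_dump")

results = lm_eval.evaluate(lm=wrapper,
                           task_dict=get_task_dict(["gsm8k"]),
                           limit=32)
llm.shutdown()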
@@ -491,7 +508,8 @@ class LmEvalEvaluator(Evaluator):
             system_prompt=kwargs.pop("system_prompt", None),
             is_multimodal=kwargs.pop("is_multimodal", False),
             chat_template_kwargs=kwargs.pop("chat_template_kwargs",
-                                            None))
+                                            None),
+            output_dir=kwargs.pop("output_dir", None))
         sampling_params = SamplingParams(
             max_tokens=kwargs.pop("max_output_length"),
             truncate_prompt_tokens=kwargs.pop("max_input_length"),

@@ -548,6 +566,10 @@ class GSM8K(LmEvalEvaluator):
                   type=int,
                   default=256,
                   help="Maximum generation length.")
+    @click.option("--output_dir",
+                  type=str,
+                  default=None,
+                  help="Directory to save the results.")
     @click.pass_context
     @staticmethod
     def command(ctx, **kwargs) -> None:

@@ -602,6 +624,10 @@ class GPQADiamond(LmEvalEvaluator):
                   type=int,
                   default=32768,
                   help="Maximum generation length.")
+    @click.option("--output_dir",
+                  type=str,
+                  default=None,
+                  help="Directory to save the results.")
     @click.pass_context
     @staticmethod
     def command(ctx, **kwargs) -> None:

@@ -652,6 +678,10 @@ class GPQAMain(LmEvalEvaluator):
                   type=int,
                   default=32768,
                   help="Maximum generation length.")
+    @click.option("--output_dir",
+                  type=str,
+                  default=None,
+                  help="Directory to save the results.")
     @click.pass_context
     @staticmethod
     def command(ctx, **kwargs) -> None:

@@ -702,6 +732,10 @@ class GPQAExtended(LmEvalEvaluator):
                   type=int,
                   default=32768,
                   help="Maximum generation length.")
+    @click.option("--output_dir",
+                  type=str,
+                  default=None,
+                  help="Directory to save the results.")
     @click.pass_context
     @staticmethod
     def command(ctx, **kwargs) -> None:

@@ -753,6 +787,10 @@ class MMMU(LmEvalEvaluator):
                   default=
                   512, # NOTE: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/mmmu/_template_yaml#L13
                   help="Maximum generation length.")
+    @click.option("--output_dir",
+                  type=str,
+                  default=None,
+                  help="Directory to save the results.")
     @click.pass_context
     @staticmethod
     def command(ctx, **kwargs) -> None:

@@ -897,6 +935,10 @@ class LongBenchV1(LmEvalEvaluator):
                   type=str,
                   default=None,
                   help="System prompt.")
+    @click.option("--output_dir",
+                  type=str,
+                  default=None,
+                  help="Directory to save the results.")
     @click.pass_context
     @staticmethod
     def command(ctx, **kwargs) -> None:

@@ -908,7 +950,8 @@ class LongBenchV1(LmEvalEvaluator):
             random_seed=kwargs.pop("random_seed", 0),
             apply_chat_template=kwargs.pop("apply_chat_template", True),
             system_prompt=kwargs.pop("system_prompt", None),
-            chat_template_kwargs=kwargs.pop("chat_template_kwargs", None))
+            chat_template_kwargs=kwargs.pop("chat_template_kwargs", None),
+            output_dir=kwargs.pop("output_dir", None))

         # Let lm-eval task configs control sampling via gen_kwargs.
         sampling_params = None
@@ -62,7 +62,9 @@ class LongBenchV2(Evaluator):
                  cot: bool = False,
                  no_context: bool = False,
                  rag: int = 0,
+                 max_len: int = 128000,
                  max_input_length: int = 128000,
                  max_output_length: int = 32000,
                  output_dir: Optional[str] = None,
                  random_seed: int = 0,
                  apply_chat_template: bool = False,

@@ -81,7 +83,9 @@ class LongBenchV2(Evaluator):
             cot: Enable Chain-of-Thought reasoning
             no_context: Test without context (memorization test)
             rag: Number of top retrieved contexts to use (0 to disable)
-            max_input_length: Maximum prompt length in tokens for truncation
+            max_len: Maximum length (input + output) in tokens
+            max_input_length: Maximum context length in tokens. If exceeds, the prompt will be truncated in the middle.
             max_output_length: Maximum output length in tokens for truncation
             output_dir: Directory to save evaluation results
             random_seed: Random seed for reproducibility
             apply_chat_template: Whether to apply model's chat template

@@ -91,7 +95,8 @@ class LongBenchV2(Evaluator):
         super().__init__(random_seed=random_seed,
                          apply_chat_template=apply_chat_template,
                          system_prompt=system_prompt,
-                         chat_template_kwargs=chat_template_kwargs)
+                         chat_template_kwargs=chat_template_kwargs,
+                         output_dir=output_dir)

         self.dataset_path = dataset_path
         self.num_samples = num_samples

@@ -103,7 +108,9 @@ class LongBenchV2(Evaluator):
         self.no_context = no_context
         self.rag = rag
         self.output_dir = output_dir
-        self.max_input_length = max_input_length
+        # We need to minus max_output_length from max_len to reserve budget for output tokens.
+        self.max_input_length = min(max_input_length,
+                                    max_len - max_output_length)

         # Will be set during evaluation
         self.tokenizer = None
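
The min() above caps the prompt budget so that prompt plus generation fits in the model window: max_output_length is reserved out of max_len before truncation. A worked example using the constructor defaults shown in this diff:

# Constructor defaults from the hunk above.
max_len = 128000
max_input_length = 128000
max_output_length = 32000

# 128000 - 32000 = 96000 tokens remain for the prompt, so the effective
# truncation length drops below the requested max_input_length.
effective_input = min(max_input_length, max_len - max_output_length)
assert effective_input == 96000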
@@ -305,7 +312,6 @@ class LongBenchV2(Evaluator):

         If the prompt exceeds max_input_length, it takes the first half and last half
         to preserve both context beginning and end.
-        We need to minus max_output_length from max_len to reserve budget for output tokens.

         Args:
             prompt: The prompt string to truncate

@@ -727,12 +733,19 @@ class LongBenchV2(Evaluator):
                   type=str,
                   default=None,
                   help="System prompt.")
+    @click.option(
+        "--max_len",
+        type=int,
+        default=1024000,
+        help=
+        "Maximum length (input + output) in tokens which can be supported by the model."
+    )
     @click.option(
         "--max_input_length",
         type=int,
         default=128000,
         help=
-        "Maximum prompt length before apply chat template. If exceeds, the prompt will be truncated in the middle."
+        "Maximum context length in tokens. If exceeds, the prompt will be truncated in the middle."
     )
     @click.option("--max_output_length",
                   type=int,

@@ -763,7 +776,7 @@ class LongBenchV2(Evaluator):
                 cot: bool, no_context: bool, rag: int,
                 output_dir: Optional[str], random_seed: int,
                 apply_chat_template: bool, system_prompt: Optional[str],
-                max_input_length: int, max_output_length: int,
+                max_len: int, max_input_length: int, max_output_length: int,
                 chat_template_kwargs: Optional[dict[str, Any]],
                 temperature: float, top_p: float) -> None:
         llm: Union[LLM, PyTorchLLM] = ctx.obj

@@ -782,7 +795,9 @@ class LongBenchV2(Evaluator):
             cot=cot,
             no_context=no_context,
             rag=rag,
+            max_len=max_len,
             max_input_length=max_input_length,
             max_output_length=max_output_length,
             output_dir=output_dir,
             random_seed=random_seed,
             apply_chat_template=apply_chat_template,
@@ -121,11 +121,13 @@ class MMLU(Evaluator):
                  random_seed: int = 0,
                  apply_chat_template: bool = False,
                  system_prompt: Optional[str] = None,
-                 chat_template_kwargs: Optional[dict[str, Any]] = None):
+                 chat_template_kwargs: Optional[dict[str, Any]] = None,
+                 output_dir: Optional[str] = None):
         super().__init__(random_seed=random_seed,
                          apply_chat_template=apply_chat_template,
                          system_prompt=system_prompt,
-                         chat_template_kwargs=chat_template_kwargs)
+                         chat_template_kwargs=chat_template_kwargs,
+                         output_dir=output_dir)
         if dataset_path is None:
             dataset_path = self.dowload_dataset()
         self.dataset_path = dataset_path

@@ -302,6 +304,10 @@ class MMLU(Evaluator):
                   help="Maximum generation length.")
     @click.option("--check_accuracy", is_flag=True, default=False)
     @click.option("--accuracy_threshold", type=float, default=30)
+    @click.option("--output_dir",
+                  type=str,
+                  default=None,
+                  help="Directory to save results.")
     @click.pass_context
     @staticmethod
     def command(ctx, dataset_path: Optional[str], num_samples: int,

@@ -309,7 +315,7 @@ class MMLU(Evaluator):
                 chat_template_kwargs: Optional[dict[str, Any]],
                 system_prompt: Optional[str], max_input_length: int,
                 max_output_length: int, check_accuracy: bool,
-                accuracy_threshold: float) -> None:
+                accuracy_threshold: float, output_dir: Optional[str]) -> None:
         llm: Union[LLM, PyTorchLLM] = ctx.obj
         sampling_params = SamplingParams(
             max_tokens=max_output_length,

@@ -320,7 +326,8 @@ class MMLU(Evaluator):
             random_seed=random_seed,
             apply_chat_template=apply_chat_template,
             system_prompt=system_prompt,
-            chat_template_kwargs=chat_template_kwargs)
+            chat_template_kwargs=chat_template_kwargs,
+            output_dir=output_dir)
         accuracy = evaluator.evaluate(llm, sampling_params)
         llm.shutdown()