# TensorRT-LLMs/benchmarks/cpp/utils/prepare_real_data.py

import logging
import random
import re
from typing import Optional

import click
from datasets import load_dataset
from pydantic import BaseModel, model_validator

from utils.utils import dataset_dump, get_norm_dist_lengths, print_dataset
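
# Example invocation (hypothetical: the parent click group that supplies
# `root_args`, such as the tokenizer, output path, and random seed, lives in a
# separate prepare script and may use different flag names):
#   python prepare_dataset.py --tokenizer /path/to/tokenizer --output data.json \
#       dataset --dataset-name ccdv/cnn_dailymail --dataset-config-name 3.0.0 \
#       --dataset-split test --dataset-input-key article \
#       --dataset-prompt "Summarize the following article:" \
#       --output-len-dist 100,10
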
def validate_output_len_dist(ctx, param, value):
"""Validate the --output-len-dist option."""
if value is None:
return value
    m = re.fullmatch(r"(\d+),(\d+)", value)
    if m:
        return int(m.group(1)), int(m.group(2))
    # Reject malformed values (including trailing junk) with a proper click error.
    raise click.BadParameter(
        "Incorrect specification for --output-len-dist. Correct format: "
        "--output-len-dist <output_len_mean>,<output_len_stdev>")
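
# For example, "--output-len-dist 100,10" is parsed by the callback above into
# the tuple (100, 10), i.e. (mean, stdev).
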

class DatasetConfig(BaseModel):
    """Dataset configurations."""

    """Name of the dataset on HuggingFace."""
    name: str

    """Config name of the dataset, if one exists."""
    config_name: Optional[str] = None

    """Split of the dataset. Typical values: train, validation, test. Setting to None will include all splits."""
    split: Optional[str]

    """The dataset dictionary key used for the input sentence."""
    input_key: str

    """The dataset dictionary key used for the prompt of the input sentence. Must not be set when prompt is set."""
    prompt_key: Optional[str] = None

    """The prompt sentence to be added to the input sentence. Must not be set when prompt_key is set."""
    prompt: Optional[str] = None

    """The dataset dictionary key used to derive the output sequence length. Set to None if the dataset does not have a key for output."""
    output_key: Optional[str]

    @model_validator(mode='after')
def check_prompt(self) -> 'DatasetConfig':
if self.prompt_key and self.prompt:
raise AssertionError(
"--prompt-key and --prompt cannot be set at the same time.")
if (not self.prompt_key) and (not self.prompt):
raise AssertionError("Either --prompt-key or --prompt must be set.")
return self

    @property
    def query(self):
        """Generate the query for HuggingFace `datasets.load_dataset()`."""
if self.config_name:
return [self.name, self.config_name]
else:
return [self.name]

    def get_prompt(self, req):
"""Get the prompt sentence from the given request."""
if self.prompt_key:
assert self.prompt_key in req, (
f"Dataset {self.name} does not have key '{self.prompt_key}'. "
"Please set --prompt-key to one of the available keys: "
f"{req.keys()}")
return req[self.prompt_key]
else:
return self.prompt

    def get_input(self, req):
"""Get the input sentence from the given request."""
assert self.input_key in req, (
f"Dataset {self.name} does not have key '{self.input_key}'. "
"Please set --input-key to one of the available keys: "
f"{req.keys()}")
return req[self.input_key]

    def get_output(self, req):
        """Get the output sentence from the given request."""
        if self.output_key is None:
            raise RuntimeError(
                "--output-key is not set. Please either:\n"
                "1. Define the output length through --output-len-dist, or\n"
                f"2. If the dataset {self.name} has a key for the golden "
                "output and you wish to use the golden output's length as "
                "the output length, set --output-key.")
        assert self.output_key in req, (
            f"Dataset {self.name} does not have key '{self.output_key}'. "
            "Please set --output-key to one of the available keys: "
            f"{req.keys()}")
        return req[self.output_key]


def load_dataset_from_hf(dataset_config: DatasetConfig):
    """Load dataset from HuggingFace.

    Args:
        dataset_config: A `DatasetConfig` object that defines the dataset to load.

    Returns:
        Dataset iterator.

    Raises:
        ValueError: When dataset loading fails due to an incorrect dataset config setting.
    """
    try:
        # Stream the dataset so it is read lazily instead of downloaded in full.
        dataset = iter(
            load_dataset(*dataset_config.query,
                         split=dataset_config.split,
                         streaming=True))
    except ValueError as e:
        # `e` is an exception object, not a string, so build the augmented
        # message from str(e) before re-raising.
        msg = str(e)
        if "Config" in msg:
            msg += "\n Please specify the config name via --dataset-config-name."
        elif "split" in msg:
            msg += "\n Please specify a supported split via --dataset-split."
        raise ValueError(msg) from e
    return dataset
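
# A minimal usage sketch of DatasetConfig + load_dataset_from_hf (the dataset
# name and keys below are illustrative, not required by this script):
#   config = DatasetConfig(name="ccdv/cnn_dailymail", config_name="3.0.0",
#                          split="test", input_key="article",
#                          prompt="Summarize: ", output_key="highlights")
#   req = next(load_dataset_from_hf(config))
#   text = config.get_prompt(req) + ' ' + config.get_input(req)
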
@click.command()
@click.option("--dataset-name",
required=True,
type=str,
help=f"Dataset name in HuggingFace.")
@click.option("--dataset-config-name",
type=str,
default=None,
help=f"Dataset config name in HuggingFace (if exists).")
@click.option("--dataset-split",
type=str,
required=True,
help=f"Split of the dataset to use.")
@click.option("--dataset-input-key",
required=True,
type=str,
help=f"The dataset dictionary key for input.")
@click.option("--dataset-prompt-key",
type=str,
default=None,
help=f"The dataset dictionary key for prompt (if exists).")
@click.option(
"--dataset-prompt",
type=str,
default=None,
help=f"The prompt string when there is no prompt key for the dataset.")
@click.option("--dataset-output-key",
type=str,
default=None,
help=f"The dataset dictionary key for output (if exists).")
@click.option(
"--num-requests",
type=int,
default=None,
help=
"Number of requests to be generated. Will be capped to min(dataset.num_rows, num_requests)."
)
@click.option(
"--max-input-len",
type=int,
default=None,
help=
"Maximum input sequence length for a given request. This will be used to filter out the requests with long input sequence length. Default will include all the requests."
)
@click.option(
"--output-len-dist",
type=str,
default=None,
callback=validate_output_len_dist,
help=
"Output length distribution. Default will be the length of the golden output from the dataset. Format: <output_len_mean>,<output_len_stdev>. E.g. 100,10 will randomize the output length with mean=100 and variance=10."
)
@click.pass_obj
def dataset(root_args, **kwargs):
    """Prepare a benchmark dataset from a real HuggingFace dataset."""
    # Collect the dataset_* CLI options into a DatasetConfig, stripping the
    # 'dataset_' prefix (len('dataset_') == 8) from each key.
    dataset_config = DatasetConfig(**{
        k[8:]: v
        for k, v in kwargs.items() if k.startswith('dataset_')
    })

input_ids = []
input_lens = []
output_lens = []
task_ids = []
req_cnt = 0
for req in load_dataset_from_hf(dataset_config):
        # Input: prompt sentence + input sentence.
        prompt = (dataset_config.get_prompt(req) + ' ' +
                  dataset_config.get_input(req))
logging.debug(f"Input sequence: {prompt}")
line = root_args.tokenizer.encode(prompt)
if kwargs['max_input_len'] and len(line) > kwargs['max_input_len']:
continue
input_ids.append(line)
input_lens.append(len(line))

        # Output length: taken from the golden output, unless --output-len-dist
        # overrides it after the loop.
        if kwargs['output_len_dist'] is None:
            output_lens.append(
                len(root_args.tokenizer.encode(dataset_config.get_output(req))))

        # LoRA task id
task_id = root_args.task_id
if root_args.rand_task_id is not None:
min_id, max_id = root_args.rand_task_id
task_id = random.randint(min_id, max_id)
task_ids.append(task_id)
req_cnt += 1
if kwargs['num_requests'] and req_cnt >= kwargs['num_requests']:
break

    if kwargs['num_requests'] and len(input_ids) < kwargs['num_requests']:
        logging.warning(
            "The dataset yielded fewer requests than the --num-requests value.")
    # Output lengths: sampled from a normal distribution when --output-len-dist is set.
if kwargs['output_len_dist'] is not None:
osl_mean, osl_stdev = kwargs['output_len_dist']
output_lens = get_norm_dist_lengths(osl_mean, osl_stdev, len(input_ids),
root_args.random_seed)
logging.debug(f"Input lengths: {[len(i) for i in input_ids]}")
logging.debug(f"Output lengths: {output_lens}")
if not root_args.std_out:
dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "dataset",
"tokenizer": root_args.tokenizer.__class__.__name__,
"num_requests": len(input_ids),
"max_input_len": max(input_lens),
"max_output_len": max(output_lens)
}, root_args.output)
else:
print_dataset(
input_ids,
output_lens,
)