TensorRT-LLM/tensorrt_llm/bench/dataset/prepare_real_data.py


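"""Prepare a benchmark request dataset from a real HuggingFace dataset.

This module implements the ``real-dataset`` click subcommand defined below; the
resulting dataset is written to the path carried in ``root_args.output``. A minimal,
illustrative invocation (the dataset name, keys, and prompt are only examples and
assumptions, not taken from this file):

    real-dataset --dataset-name cnn_dailymail --dataset-config-name 3.0.0
        --dataset-split validation --dataset-input-key article
        --dataset-prompt "Summarize the following article:"
        --dataset-output-key highlights --num-requests 100
"""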
import logging
import random
import re
import tempfile
from functools import partial
from typing import Optional

import click
from datasets import load_dataset
from PIL import Image
from pydantic import BaseModel, model_validator

from tensorrt_llm.bench.dataset.utils import (
    generate_multimodal_dataset,
    generate_text_dataset,
    get_norm_dist_lengths,
    write_dataset_to_file,
)


def validate_output_len_dist(ctx, param, value):
    """Validate the --output-len-dist option.

    Expects "<output_len_mean>,<output_len_stdev>", e.g. "100,10" -> (100, 10).
    """
    if value is None:
        return value
    m = re.match(r"(\d+),(\d+)", value)
    if m:
        return int(m.group(1)), int(m.group(2))
    else:
        raise AssertionError(
            "Incorrect specification for --output-len-dist. Correct format: "
            "--output-len-dist <output_len_mean>,<output_len_stdev>"
        )


class DatasetConfig(BaseModel):
    """Dataset configurations."""

    """Name of the dataset on HuggingFace."""
    name: str
    """Config name of the dataset if existing."""
    config_name: Optional[str] = None
    """Split of the dataset. Typical values: train, validation, test. Setting to None will include all splits."""
    split: Optional[str]
    """The dataset dictionary key used for the input sentence."""
    input_key: Optional[str] = None
    """The dataset dictionary key used for the images."""
    image_key: Optional[str] = None
    """The dataset dictionary key used for the prompt of the input sentence. Must not be set when prompt is set."""
    prompt_key: Optional[str] = None
    """The prompt sentence to be added to the input sentence. Must not be set when prompt_key is set."""
    prompt: Optional[str] = None
    """The dataset dictionary key used to derive the output sequence length. Set to None if no output key."""
    output_key: Optional[str]

    @model_validator(mode="after")
    def check_prompt(self) -> "DatasetConfig":
        if self.prompt_key and self.prompt:
            raise AssertionError("--prompt-key and --prompt cannot be set at the same time.")
        if (not self.prompt_key) and (not self.prompt):
            raise AssertionError("Either --prompt-key or --prompt must be set.")
        return self

    @property
    def query(self):
        """Generate the query for HuggingFace `datasets.load_dataset()`."""
        if self.config_name:
            return [self.name, self.config_name]
        else:
            return [self.name]

    def get_prompt(self, req):
        """Get the prompt sentence from the given request."""
        if self.prompt_key:
            assert self.prompt_key in req, (
                f"Dataset {self.name} does not have key '{self.prompt_key}'. "
                "Please set --prompt-key to one of the available keys: "
                f"{req.keys()}"
            )
            return req[self.prompt_key]
        else:
            return self.prompt

    def get_input(self, req):
        """Get the input sentence from the given request."""
        assert self.input_key in req, (
            f"Dataset {self.name} does not have key '{self.input_key}'. "
            "Please set --input-key to one of the available keys: "
            f"{req.keys()}"
        )
        return req[self.input_key]

    def get_images(self, req):
        """Get the images from the given request."""
        image_keys = [self.image_key] + [f"{self.image_key}_{i}" for i in range(1, 8)]
        assert any(key in req for key in image_keys), (
            f"Dataset {self.name} does not have key '{self.image_key}'. "
            "Please set --dataset-image-key to one of the available keys: "
            f"{req.keys()}"
        )
        images = []
        for key in image_keys:
            if key in req and req[key] is not None:
                images.append(req[key])
        return images

    def get_output(self, req):
        """Get the output sentence from the given request."""
        if self.output_key is None:
            raise RuntimeError(
                "--output-key is not set. Please either:\n"
                "1. Define output length through --output-len-dist.\n"
                f"2. If the dataset {self.name} has a key for the golden output and "
                "you wish to set output length to the length of the golden "
                "output, set --output-key."
            )
        assert self.output_key in req, (
            f"Dataset {self.name} does not have key '{self.output_key}'. "
            "Please set --output-key to one of the available keys: "
            f"{req.keys()}"
        )
        return req[self.output_key]


def load_dataset_from_hf(dataset_config: DatasetConfig):
    """Load dataset from HuggingFace.

    Args:
        dataset_config: A `DatasetConfig` object that defines the dataset to load.

    Returns:
        Dataset iterator.

    Raises:
        ValueError: When dataset loading fails due to an incorrect dataset config setting.
    """
    try:
        dataset = iter(
            load_dataset(
                *dataset_config.query,
                split=dataset_config.split,
                streaming=True,
                trust_remote_code=True,
            )
        )
    except ValueError as e:
        # Exceptions do not support `in` or `+=`; work on the message string instead.
        msg = str(e)
        if "Config" in msg:
            msg += "\n Please pass the dataset config name via --dataset-config-name."
        elif "split" in msg:
            msg += "\n Please specify a supported split via --dataset-split."
        raise ValueError(msg) from e
    return dataset


@click.command(name="real-dataset")
@click.option("--dataset-name", required=True, type=str, help="Dataset name in HuggingFace.")
@click.option(
    "--dataset-config-name",
    type=str,
    default=None,
    help="Dataset config name in HuggingFace (if exists).",
)
@click.option("--dataset-split", type=str, required=True, help="Split of the dataset to use.")
@click.option("--dataset-input-key", type=str, help="The dataset dictionary key for input.")
@click.option(
    "--dataset-image-key", type=str, default="image", help="The dataset dictionary key for images."
)
@click.option(
    "--dataset-prompt-key",
    type=str,
    default=None,
    help="The dataset dictionary key for prompt (if exists).",
)
@click.option(
    "--dataset-prompt",
    type=str,
    default=None,
    help="The prompt string when there is no prompt key for the dataset.",
)
@click.option(
    "--dataset-output-key",
    type=str,
    default=None,
    help="The dataset dictionary key for output (if exists).",
)
@click.option(
    "--num-requests",
    type=int,
    default=None,
    help="Number of requests to be generated. Will be capped to min(dataset.num_rows, num_requests).",
)
@click.option(
    "--max-input-len",
    type=int,
    default=None,
    help="Maximum input sequence length for a given request. This will be used to filter out the "
    "requests with long input sequence length. Default will include all the requests.",
)
@click.option(
    "--output-len-dist",
    type=str,
    default=None,
    callback=validate_output_len_dist,
    help="Output length distribution. Default will be the length of the golden output from "
    "the dataset. Format: <output_len_mean>,<output_len_stdev>. E.g. 100,10 will randomize "
    "the output length with mean=100 and standard deviation=10.",
)
@click.pass_obj
def real_dataset(root_args, **kwargs):
"""Prepare dataset from real dataset."""
    dataset_config = DatasetConfig(
        **{k[8:]: v for k, v in kwargs.items() if k.startswith("dataset_")}
    )

    input_ids = []
    input_lens = []
    output_lens = []
    task_ids = []
    req_cnt = 0
    modality = None
    multimodal_texts = []
    multimodal_image_paths = []
    for req in load_dataset_from_hf(dataset_config):
        if any(key in req for key in ["image", "image_1", "video"]):
            # multimodal input
            if "video" in req and req["video"] is not None:
                raise NotImplementedError("Video inputs are not supported yet.")
            assert kwargs["output_len_dist"] is not None, (
                "Output length distribution must be set for multimodal requests."
            )
modality = "image"
text = dataset_config.get_prompt(req)
images = dataset_config.get_images(req)
image_paths = []
for image in images:
if image is not None:
if isinstance(image, str):
image_paths.append(image)
elif isinstance(image, Image.Image):
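                        # Materialize in-memory PIL images as temporary JPEG files so
                        # the generated dataset can reference them by file path;
                        # delete=False keeps the files around after this block exits.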
                        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
                            logging.debug(f"Saving image to {tmp_file.name}")
                            image = image.convert("RGB")
                            image.save(tmp_file, "JPEG")
                            filepath = tmp_file.name
                        image_paths.append(filepath)
                    else:
                        raise ValueError(f"Invalid image path: {image}")
            multimodal_texts.append(text)
            multimodal_image_paths.append(image_paths)
        else:
            # text input
            prompt = dataset_config.get_prompt(req) + " " + dataset_config.get_input(req)
            logging.debug(f"Input sequence: {prompt}")
            line = root_args.tokenizer.encode(prompt)
            if kwargs["max_input_len"] and len(line) > kwargs["max_input_len"]:
                continue
            input_ids.append(line)
            input_lens.append(len(line))
            # Output length: taken from the golden output when --output-len-dist is not set.
            if kwargs["output_len_dist"] is None:
                output_lens.append(len(root_args.tokenizer.encode(dataset_config.get_output(req))))
            # LoRA task id
            task_id = root_args.task_id
            if root_args.rand_task_id is not None:
                min_id, max_id = root_args.rand_task_id
                task_id = random.randint(min_id, max_id)
            task_ids.append(task_id)

        req_cnt += 1
        if kwargs["num_requests"] and req_cnt >= kwargs["num_requests"]:
            break

    if (
        kwargs["num_requests"]
        and (len(input_ids) if modality is None else len(multimodal_texts)) < kwargs["num_requests"]
    ):
        logging.warning(
            f"Number of requests={len(input_ids) if modality is None else len(multimodal_texts)} is"
            f" smaller than the num-requests user set={kwargs['num_requests']}."
        )

    # Output lengths: sampled from a normal distribution when --output-len-dist is set.
    if kwargs["output_len_dist"] is not None:
        osl_mean, osl_stdev = kwargs["output_len_dist"]
        output_lens = get_norm_dist_lengths(
            osl_mean,
            osl_stdev,
            len(input_ids) if modality is None else len(multimodal_texts),
            root_args.random_seed,
        )

    logging.debug(f"Input lengths: {[len(i) for i in input_ids]}")
    logging.debug(f"Output lengths: {output_lens}")
    if modality is not None:
        logging.debug(f"Modality: {modality}")

    dataset_generator = None
    if modality is not None:
        dataset_generator = partial(
            generate_multimodal_dataset, multimodal_texts, multimodal_image_paths
        )
    else:
        dataset_generator = partial(generate_text_dataset, input_ids)
    write_dataset_to_file(dataset_generator(output_lens), root_args.output)