# tensorrt_llm/bench/dataset/prepare_synthetic_data.py
import random
import warnings
import click
from tensorrt_llm.bench.dataset.utils import (
gen_random_tokens,
generate_text_dataset,
get_norm_dist_lengths,
get_unif_dist_lengths,
write_dataset_to_file,
)
def _generate_task_ids_and_lora_config(root_args, num_reqs):
"""Generate task IDs and determine LoRA configuration based on root_args."""
if root_args.rand_task_id is None:
task_ids = [root_args.task_id for _ in range(num_reqs)]
else:
min_id, max_id = root_args.rand_task_id
task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]
use_task_ids = root_args.task_id != -1 or root_args.rand_task_id is not None
# Determine if LoRA should be used (requires both task IDs and lora_dir)
use_lora = use_task_ids and root_args.lora_dir is not None
# Warn if task IDs are specified but no LoRA directory is provided
if use_task_ids and not use_lora:
warnings.warn(
"Task IDs require LoRA directory. Use --lora-dir or omit task IDs.", UserWarning
)
return (
task_ids,
task_ids if use_task_ids else None,
{"lora_dir": root_args.lora_dir} if use_lora else None,
)
@click.command()
@click.option("--num-requests", required=True, type=int, help="Number of requests to be generated")
@click.option("--input-mean", required=True, type=int, help="normal dist mean for input tokens")
@click.option("--input-stdev", required=True, type=int, help="normal dist stdev for input tokens")
@click.option("--output-mean", required=True, type=int, help="normal dist mean for output tokens")
@click.option("--output-stdev", required=True, type=int, help="normal dist stdev for output tokens")
@click.pass_obj
def token_norm_dist(root_args, **kwargs):
    """Prepare a synthetic dataset with normally distributed token lengths.

    Input and output lengths are drawn independently from normal
    distributions parameterized by the --input-*/--output-* options; random
    input tokens are generated and the dataset is written to
    ``root_args.output``.
    """
    # Per-request input lengths ~ N(input_mean, input_stdev).
    input_lens = get_norm_dist_lengths(
        kwargs["input_mean"], kwargs["input_stdev"], kwargs["num_requests"], root_args.random_seed
    )
    num_reqs = len(input_lens)
    # Per-request output lengths ~ N(output_mean, output_stdev).
    output_lens = get_norm_dist_lengths(
        kwargs["output_mean"], kwargs["output_stdev"], num_reqs, root_args.random_seed
    )
    input_ids = gen_random_tokens(input_lens, root_args.tokenizer, root_args.random_seed)
    # Only the dataset-visible task-ID list and LoRA config are needed here.
    _, print_task_ids, lora_config = _generate_task_ids_and_lora_config(root_args, num_reqs)
    dataset_generator = generate_text_dataset(
        input_ids, output_lens, task_ids=print_task_ids, lora_config=lora_config
    )
    write_dataset_to_file(dataset_generator, root_args.output)
@click.command()
@click.option("--num-requests", required=True, type=int, help="Number of requests to be generated")
@click.option(
    "--input-min", required=True, type=int, help="uniform dist (inclusive) min for input tokens"
)
@click.option(
    "--input-max", required=True, type=int, help="uniform dist (inclusive) max for input tokens"
)
@click.option(
    "--output-min", required=True, type=int, help="uniform dist (inclusive) min for output tokens"
)
@click.option(
    "--output-max", required=True, type=int, help="uniform dist (inclusive) max for output tokens"
)
@click.pass_obj
def token_unif_dist(root_args, **kwargs):
    """Prepare a synthetic dataset with uniformly distributed token lengths.

    Input and output lengths are drawn independently from uniform
    distributions over the inclusive [min, max] ranges given by the
    --input-*/--output-* options; random input tokens are generated and the
    dataset is written to ``root_args.output``.
    """
    # Per-request input lengths ~ U[input_min, input_max] (inclusive).
    input_lens = get_unif_dist_lengths(
        kwargs["input_min"], kwargs["input_max"], kwargs["num_requests"], root_args.random_seed
    )
    num_reqs = len(input_lens)
    # Per-request output lengths ~ U[output_min, output_max] (inclusive).
    output_lens = get_unif_dist_lengths(
        kwargs["output_min"], kwargs["output_max"], num_reqs, root_args.random_seed
    )
    input_ids = gen_random_tokens(input_lens, root_args.tokenizer, root_args.random_seed)
    # Only the dataset-visible task-ID list and LoRA config are needed here.
    _, print_task_ids, lora_config = _generate_task_ids_and_lora_config(root_args, num_reqs)
    dataset_generator = generate_text_dataset(
        input_ids, output_lens, task_ids=print_task_ids, lora_config=lora_config
    )
    write_dataset_to_file(dataset_generator, root_args.output)