mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 10:11:47 +08:00
105 lines
4.0 KiB
Python
105 lines
4.0 KiB
Python
import random
|
|
import warnings
|
|
|
|
import click
|
|
|
|
from tensorrt_llm.bench.dataset.utils import (
|
|
gen_random_tokens,
|
|
generate_text_dataset,
|
|
get_norm_dist_lengths,
|
|
get_unif_dist_lengths,
|
|
write_dataset_to_file,
|
|
)
|
|
|
|
|
|
def _generate_task_ids_and_lora_config(root_args, num_reqs):
    """Build the per-request task IDs and the optional LoRA configuration.

    Args:
        root_args: Object exposing ``task_id``, ``rand_task_id`` and ``lora_dir``.
        num_reqs: Number of requests to produce a task ID for.

    Returns:
        A 3-tuple ``(task_ids, print_task_ids, lora_config)``:

        * ``task_ids`` - one ID per request; drawn with ``random.randint`` from the
          inclusive ``rand_task_id`` range when one is given, otherwise ``task_id`` repeated.
        * ``print_task_ids`` - the same list when task IDs are in use (``task_id`` is not
          ``-1`` or a random range was given), else ``None``.
        * ``lora_config`` - ``{"lora_dir": ...}`` when task IDs are in use and ``lora_dir``
          is not ``None``, else ``None``.

    Warns:
        UserWarning: When task IDs are in use but ``lora_dir`` is ``None``.
    """
    id_range = root_args.rand_task_id
    if id_range is not None:
        # A random range takes precedence over the fixed task ID.
        low, high = id_range
        task_ids = [random.randint(low, high) for _ in range(num_reqs)]
    else:
        task_ids = [root_args.task_id] * num_reqs

    use_task_ids = root_args.task_id != -1 or id_range is not None
    if not use_task_ids:
        return task_ids, None, None

    lora_dir = root_args.lora_dir
    if lora_dir is None:
        # Task IDs without a LoRA directory are still emitted, but the caller is warned.
        warnings.warn(
            "Task IDs require LoRA directory. Use --lora-dir or omit task IDs.", UserWarning
        )
        return task_ids, task_ids, None

    return task_ids, task_ids, {"lora_dir": lora_dir}
|
|
|
|
|
|
# CLI sub-command: writes a synthetic token dataset whose input/output lengths come from
# get_norm_dist_lengths. ``root_args`` is the object click injects via ``pass_obj``; this
# function reads ``random_seed``, ``tokenizer`` and ``output`` from it. ``kwargs`` holds the
# parsed options declared below. The docstring stays a single line because click uses a
# command's docstring as its help text.
@click.command()
@click.option("--num-requests", required=True, type=int, help="Number of requests to be generated")
@click.option("--input-mean", required=True, type=int, help="normal dist mean for input tokens")
@click.option("--input-stdev", required=True, type=int, help="normal dist stdev for input tokens")
@click.option("--output-mean", required=True, type=int, help="normal dist mean for output tokens")
@click.option("--output-stdev", required=True, type=int, help="normal dist stdev for output tokens")
@click.pass_obj
def token_norm_dist(root_args, **kwargs):
    """Prepare synthetic dataset by generating random tokens with normal dist lengths."""
    input_lens = get_norm_dist_lengths(
        kwargs["input_mean"], kwargs["input_stdev"], kwargs["num_requests"], root_args.random_seed
    )

    # Size everything else from the lengths actually returned, not from the CLI value.
    num_reqs = len(input_lens)
    # NOTE(review): the same seed is passed for the input and the output draw; whether that
    # correlates the two length lists depends on how the helper seeds its RNG - confirm.
    output_lens = get_norm_dist_lengths(
        kwargs["output_mean"], kwargs["output_stdev"], num_reqs, root_args.random_seed
    )
    input_ids = gen_random_tokens(input_lens, root_args.tokenizer, root_args.random_seed)
    # The first element (the raw per-request IDs) is not needed here; only the IDs to emit
    # (None when task IDs are disabled) and the LoRA config are forwarded.
    _, print_task_ids, lora_config = _generate_task_ids_and_lora_config(root_args, num_reqs)
    dataset_generator = generate_text_dataset(
        input_ids, output_lens, task_ids=print_task_ids, lora_config=lora_config
    )
    write_dataset_to_file(dataset_generator, root_args.output)
|
|
|
|
|
|
# CLI sub-command: writes a synthetic token dataset whose input/output lengths come from
# get_unif_dist_lengths. ``root_args`` is the object click injects via ``pass_obj``; this
# function reads ``random_seed``, ``tokenizer`` and ``output`` from it. ``kwargs`` holds the
# parsed options declared below. The docstring stays a single line because click uses a
# command's docstring as its help text.
@click.command()
@click.option("--num-requests", required=True, type=int, help="Number of requests to be generated")
@click.option(
    "--input-min", required=True, type=int, help="uniform dist (inclusive) min for input tokens"
)
@click.option(
    "--input-max", required=True, type=int, help="uniform dist (inclusive) max for input tokens"
)
@click.option(
    "--output-min", required=True, type=int, help="uniform dist (inclusive) min for output tokens"
)
@click.option(
    "--output-max", required=True, type=int, help="uniform dist (inclusive) max for output tokens"
)
@click.pass_obj
def token_unif_dist(root_args, **kwargs):
    """Prepare synthetic dataset by generating random tokens with uniform dist lengths."""
    input_lens = get_unif_dist_lengths(
        kwargs["input_min"], kwargs["input_max"], kwargs["num_requests"], root_args.random_seed
    )

    # Size everything else from the lengths actually returned, not from the CLI value.
    num_reqs = len(input_lens)
    # NOTE(review): the same seed is passed for the input and the output draw; whether that
    # correlates the two length lists depends on how the helper seeds its RNG - confirm.
    output_lens = get_unif_dist_lengths(
        kwargs["output_min"], kwargs["output_max"], num_reqs, root_args.random_seed
    )
    input_ids = gen_random_tokens(input_lens, root_args.tokenizer, root_args.random_seed)
    # The first element (the raw per-request IDs) is not needed here; only the IDs to emit
    # (None when task IDs are disabled) and the LoRA config are forwarded.
    _, print_task_ids, lora_config = _generate_task_ids_and_lora_config(root_args, num_reqs)
    dataset_generator = generate_text_dataset(
        input_ids, output_lens, task_ids=print_task_ids, lora_config=lora_config
    )
    write_dataset_to_file(dataset_generator, root_args.output)
|