import json
from functools import partial
from typing import List, Optional, TextIO, Tuple

from transformers import AutoTokenizer, PreTrainedTokenizer

from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata,
                                                    InferenceRequest)
from tensorrt_llm.bench.dataclasses.statistics import PercentileStats
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.inputs import default_multimodal_input_loader

def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer:
    """Initialize a tokenizer.

    Args:
        model_name (str): The name of the HuggingFace model to pull a
        tokenizer from.

    Returns:
        PreTrainedTokenizer: An initialized HuggingFace tokenizer.
    """
    # Initialize the tokenizer specific to the model that we are planning
    # to benchmark.
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              padding_side="left",
                                              trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    return tokenizer
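
# Example usage (a minimal sketch, assuming access to the HuggingFace Hub;
# "gpt2" is an illustrative checkpoint name, not one this module requires):
#
#     tokenizer = initialize_tokenizer("gpt2")
#     input_ids = tokenizer("There once was a man who")["input_ids"]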

def create_dataset_from_stream(
    tokenizer: PreTrainedTokenizer,
    stream: TextIO,
    max_input_length: int = 0,
    max_output_length: int = 0,
    num_requests: int = 0,
    model_dir: Optional[str] = None,
    model_type: Optional[str] = None,
    modality: Optional[str] = None,
    image_data_format: str = "pt",
    data_device: str = "cpu",
    max_input_seq_len_for_multimodal: int = 4096,
) -> Tuple[DatasetMetadata, List[InferenceRequest]]:
    """Generate metadata and a list of requests to drive benchmarking.

    Args:
        tokenizer (PreTrainedTokenizer): HuggingFace tokenizer.
        stream (TextIO): Stream of input requests.
        max_input_length (int, optional): Maximum input length to cap prompts to. Defaults to 0.
        max_output_length (int, optional): Maximum output length to cap responses to. Defaults to 0.
        num_requests (int, optional): Number of requests to limit to. Defaults to 0.
        model_dir (str, optional): Model directory, needed for multimodal preprocessing. Defaults to None.
        model_type (str, optional): Model type, passed to the multimodal input loader. Defaults to None.
        modality (str, optional): Multimodal modality, either "image" or "video". Defaults to None.
        image_data_format (str, optional): Format to load image data in. Defaults to "pt".
        data_device (str, optional): Device to load media data onto. Defaults to "cpu".
        max_input_seq_len_for_multimodal (int, optional): Assumed input sequence length
            for multimodal requests, since their true length is unknown until the model
            runs. Defaults to 4096.

    Returns:
        Tuple[DatasetMetadata, List[InferenceRequest]]: A tuple containing a dataclass of dataset
        statistics and a list of inference requests for benchmarking.
    """
    # Initialize the dataset list and metadata tracking variables.
    dataset = []
    max_requests = num_requests if num_requests > 0 else float("inf")

    # If we're limiting the input length to a certain size, then set up
    # a partial to truncate the data down to size. Otherwise, just use the
    # unmodified tokenizer callable.
    tokenize = (partial(
        tokenizer,
        padding="max_length",
        max_length=max_input_length,
        truncation=True,
    ) if max_input_length > 0 else tokenizer)

    # If we need to cap the output length, fill in a partial that clamps
    # each value to the cap; otherwise, use a lambda that returns x unbounded.
    output_limiter = (partial(min, max_output_length)
                      if max_output_length > 0 else lambda x: x)

    # For each line in the stream, parse out the JSON record we expect to see.
    # Note the := walrus -- we're assigning and checking the condition.
    all_osl = []
    prompts = []
    media_paths = []
    all_logits = []
    task_ids = []
    lora_requests = []
    while (line := stream.readline()) and len(task_ids) < max_requests:
        # We expect the data to come in as a JSON string.
        # For example:
        # {"task_id": 1, "prompt": "Generate an infinite response to the following:
        # There once was a man who.", "output_tokens": 1000}
        #
        # For multimodal data, the data should be of the form:
        # {"task_id": 1, "prompt": "Generate an infinite response to the following:
        # There once was a man who.", "output_tokens": 1000,
        # "media_paths": ["/path/to/image1.jpg", "/path/to/image2.jpg"]}
        #
        # For LoRA data, the data should be of the form:
        # {"task_id": 1, "prompt": "Generate an infinite response to the following:
        # There once was a man who.", "output_tokens": 1000,
        # "lora_request": {"lora_name": "my_lora", "lora_int_id": 1, "lora_path": "/path/to/lora"}}
        #
        # Each line should be a complete JSON dictionary with no indentation
        # or newline characters. The task_id field is required.
        data = json.loads(line)
        prompts.append(data.get("prompt"))
        media_paths.append(data.get("media_paths", None))
        all_logits.append(data.get("input_ids", data.get("logits", None)))
        all_osl.append(data.get("output_tokens"))
        task_ids.append(data.get("task_id"))

        # Parse LoRA request if present
        lora_data = data.get("lora_request", None)
        if lora_data:
            lora_request = LoRARequest(lora_name=lora_data["lora_name"],
                                       lora_int_id=lora_data["lora_int_id"],
                                       lora_path=lora_data.get("lora_path", ""))
            lora_requests.append(lora_request)
        else:
            lora_requests.append(None)

    if modality is not None:
        # Multimodal data needs extra preprocessing.
        assert modality in [
            "image", "video"
        ], f"Modality must be one of ['image', 'video'] but got {modality}."
        prompts = default_multimodal_input_loader(
            tokenizer=tokenizer,
            model_dir=model_dir,
            model_type=model_type,
            modality=modality,
            prompts=prompts,
            media=media_paths,  # one list of media paths per request
            image_data_format=image_data_format,
            device=data_device)

    all_isl = []
    all_seq_len = []
    for prompt, logits, osl, task_id, lora_request in zip(
            prompts, all_logits, all_osl, task_ids, lora_requests):
        if modality is not None:
            # NOTE: We cannot tokenize multimodal data here; it is handled by
            # the preprocessor, so the actual sequence length is unknown until
            # the model is run.
            logits = None
            cur_isl = max_input_seq_len_for_multimodal
        else:
            # If the request comes in with token ids, use them as provided.
            # Otherwise, we need to tokenize the prompt.
            logits = tokenize(prompt)["input_ids"] if logits is None else logits
            cur_isl = len(logits)
        all_isl.append(cur_isl)
        all_seq_len.append(cur_isl + osl)

        request = InferenceRequest(
            task_id=task_id,
            prompt=prompt,
            output_tokens=output_limiter(osl),
            input_ids=logits,
            lora_request=lora_request,
        )
        dataset.append(request)

    isl_stats = PercentileStats.from_iterable(all_isl)
    osl_stats = PercentileStats.from_iterable(all_osl)
    seq_len_stats = PercentileStats.from_iterable(all_seq_len)

    # Fill in basic dataset metrics here
    metadata = DatasetMetadata(
        isl_stats=isl_stats,
        osl_stats=osl_stats,
        seq_len_stats=seq_len_stats,
        num_requests=len(dataset),
    )

    return metadata, dataset
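
# Example usage (a minimal sketch; the JSONL record mirrors the format
# documented in the loop above, and io.StringIO stands in for stdin or a
# dataset file):
#
#     import io
#     tokenizer = initialize_tokenizer("gpt2")
#     stream = io.StringIO(
#         '{"task_id": 0, "prompt": "There once was a man who", '
#         '"output_tokens": 128}\n')
#     metadata, requests = create_dataset_from_stream(tokenizer, stream)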

def update_metadata_for_multimodal(metadata, statistics) -> DatasetMetadata:
    """Update the metadata from benchmark statistics. Only used for multimodal models.

    Args:
        metadata (DatasetMetadata): The metadata to update.
        statistics (StatsKeeper): The statistics to update the metadata with.

    Returns:
        DatasetMetadata: The updated metadata.
    """
    all_isl = []
    all_osl = []
    all_seq_len = []
    for request in statistics.requests.values():
        all_isl.append(request.num_input_tokens)
        all_osl.append(request.num_total_output_tokens)
        all_seq_len.append(request.num_input_tokens +
                           request.num_total_output_tokens)
    isl_stats = PercentileStats.from_iterable(all_isl)
    osl_stats = PercentileStats.from_iterable(all_osl)
    seq_len_stats = PercentileStats.from_iterable(all_seq_len)
    metadata.isl_stats = isl_stats
    metadata.osl_stats = osl_stats
    metadata.seq_len_stats = seq_len_stats

    return metadata
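
# Example usage (a minimal sketch; `stats_keeper` stands in for a StatsKeeper
# instance populated during a multimodal benchmark run):
#
#     metadata = update_metadata_for_multimodal(metadata, stats_keeper)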