Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Synced 2026-01-14 06:27:45 +08:00
feat: adding multimodal (only image for now) support in trtllm-bench (#3490)
* feat: adding multimodal (only image for now) support in trtllm-bench
* fix: add in load_dataset() calls to maintain the v2.19.2 behavior
* re-adding prompt_token_ids and using that for prompt_len
* updating the datasets version in examples as well
* api changes are not needed
* moving datasets requirement and removing a missed api change
* addressing review comments
* refactoring the quickstart example

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

Parent: 26ebd95302
Commit: ff3b741045
@@ -1,12 +1,16 @@
 import logging
 import random
 import re
+import tempfile
 from typing import Optional

 import click
 from datasets import load_dataset
+from PIL import Image
 from pydantic import BaseModel, model_validator
-from utils.utils import dataset_dump, get_norm_dist_lengths, print_dataset
+from utils.utils import (get_norm_dist_lengths, multimodal_dataset_dump,
+                         print_multimodal_dataset, print_text_dataset,
+                         text_dataset_dump)


 def validate_output_len_dist(ctx, param, value):
@@ -31,8 +35,10 @@ class DatasetConfig(BaseModel):
     """Split of the dataset. Typical values: train, validation, test. Setting to None will include all splits."""
     split: Optional[str]
     """The dataset dictionary used for the input sentence."""
-    input_key: str
+    input_key: Optional[str] = None
     """The dataset dictionary key used for the prompt of the input sentence. Must not be set when prompt is set."""
+    image_key: Optional[str] = None
+    """The dataset dictionary key used for the images."""
     prompt_key: Optional[str] = None
     """The prompt sentence to be added to the input sentence. Must not be set when prompt_key is set."""
     prompt: Optional[str] = None
@@ -75,6 +81,20 @@ class DatasetConfig(BaseModel):
             f"{req.keys()}")
         return req[self.input_key]

+    def get_images(self, req):
+        """Get the images from the given request."""
+        image_keys = [self.image_key
+                      ] + [f"{self.image_key}_{i}" for i in range(1, 8)]
+        assert any(key in req for key in image_keys), (
+            f"Dataset {self.name} does not have key '{self.image_key}'. "
+            "Please set --dataset-image-key to one of the available keys: "
+            f"{req.keys()}")
+        images = []
+        for key in image_keys:
+            if key in req and req[key] is not None:
+                images.append(req[key])
+        return images
+
     def get_output(self, req):
         """Get the output sentence from the given request."""
         if self.output_key is None:
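For readers skimming the diff, here is a small self-contained sketch of the key-resolution behaviour that `get_images` implements above. The record layout imitates an MMMU-style sample and is invented for illustration; only the `image` / `image_1` ... `image_7` key convention and the PIL handling come from the code in this hunk.

```python
# Hypothetical record mimicking an lmms-lab/MMMU row; values are placeholders.
from PIL import Image

record = {
    "question": "What is the total contribution margin? <image 1>",
    "image_1": Image.new("RGB", (64, 64)),  # in-memory PIL image
    "image_2": None,                        # unused slot in this sample
}

image_key = "image"  # the value passed via --dataset-image-key
candidate_keys = [image_key] + [f"{image_key}_{i}" for i in range(1, 8)]
images = [
    record[key] for key in candidate_keys
    if key in record and record[key] is not None
]
print(len(images))  # 1: only image_1 holds an image
```

PIL images collected this way are later converted to RGB and written to temporary `.jpg` files (see the `tempfile.NamedTemporaryFile` branch further down in this commit), so the prepared dataset only ever stores file paths.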
@@ -105,7 +125,8 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
         dataset = iter(
             load_dataset(*dataset_config.query,
                          split=dataset_config.split,
-                         streaming=True))
+                         streaming=True,
+                         trust_remote_code=True))
     except ValueError as e:
         if "Config" in e:
             e += "\n Please add the config name to the dataset config yaml."
@@ -130,9 +151,12 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
               required=True,
               help=f"Split of the dataset to use.")
 @click.option("--dataset-input-key",
-              required=True,
               type=str,
               help=f"The dataset dictionary key for input.")
+@click.option("--dataset-image-key",
+              type=str,
+              default="image",
+              help=f"The dataset dictionary key for images.")
 @click.option("--dataset-prompt-key",
               type=str,
               default=None,
@@ -181,21 +205,54 @@ def dataset(root_args, **kwargs):
     output_lens = []
     task_ids = []
     req_cnt = 0
+    modality = None
+    multimodal_texts = []
+    multimodal_image_paths = []
     for req in load_dataset_from_hf(dataset_config):
-        # input
-        prompt = dataset_config.get_prompt(
-            req) + ' ' + dataset_config.get_input(req)
-        logging.debug(f"Input sequence: {prompt}")
-        line = root_args.tokenizer.encode(prompt)
-        if kwargs['max_input_len'] and len(line) > kwargs['max_input_len']:
-            continue
-        input_ids.append(line)
-        input_lens.append(len(line))
+        if any(key in req for key in ['image', 'image_1', 'video']):
+            # multimodal input
+            if 'video' in req and req['video'] is not None:
+                assert "Not supported yet"
+            assert kwargs['output_len_dist'] is not None, (
+                "Output length distribution must be set for multimodal requests."
+            )
+            modality = 'image'
+            text = dataset_config.get_prompt(req)
+            images = dataset_config.get_images(req)
+            image_paths = []
+            for image in images:
+                if image is not None:
+                    if isinstance(image, str):
+                        image_paths.append(image)
+                    elif isinstance(image, Image.Image):
+                        with tempfile.NamedTemporaryFile(
+                                suffix=".jpg", delete=False) as tmp_file:
+                            logging.debug(f"Saving image to {tmp_file.name}")
+                            image = image.convert("RGB")
+                            image.save(tmp_file, "JPEG")
+                            filepath = tmp_file.name
+                        image_paths.append(filepath)
+                    else:
+                        raise ValueError(f"Invalid image path: {image}")
+            multimodal_texts.append(text)
+            multimodal_image_paths.append(image_paths)
+        else:
+            # text input
+            prompt = dataset_config.get_prompt(
+                req) + ' ' + dataset_config.get_input(req)
+            logging.debug(f"Input sequence: {prompt}")
+            line = root_args.tokenizer.encode(prompt)
+            if kwargs['max_input_len'] and len(line) > kwargs['max_input_len']:
+                continue
+            input_ids.append(line)
+            input_lens.append(len(line))

         # output if fetch from golden
         if kwargs['output_len_dist'] is None:
             output_lens.append(
-                len(root_args.tokenizer.encode(dataset_config.get_output(req))))
+                len(
+                    root_args.tokenizer.encode(
+                        dataset_config.get_output(req))))

         # lora task id
         task_id = root_args.task_id
@@ -208,30 +265,53 @@ def dataset(root_args, **kwargs):
         if kwargs['num_requests'] and req_cnt >= kwargs['num_requests']:
             break

-    if kwargs['num_requests'] and len(input_ids) < kwargs['num_requests']:
+    if kwargs['num_requests'] and (len(input_ids) if modality is None else len(
+            multimodal_texts)) < kwargs['num_requests']:
         logging.warning(
-            "Number of requests is smaller than the num-requests user set.")
+            f"Number of requests={len(input_ids) if modality is None else len(multimodal_texts)} is"
+            f" smaller than the num-requests user set={kwargs['num_requests']}."
+        )

     # output if randomized
     if kwargs['output_len_dist'] is not None:
         osl_mean, osl_stdev = kwargs['output_len_dist']
-        output_lens = get_norm_dist_lengths(osl_mean, osl_stdev, len(input_ids),
-                                            root_args.random_seed)
+        output_lens = get_norm_dist_lengths(
+            osl_mean, osl_stdev,
+            len(input_ids) if modality is None else len(multimodal_texts),
+            root_args.random_seed)
     logging.debug(f"Input lengths: {[len(i) for i in input_ids]}")
     logging.debug(f"Output lengths: {output_lens}")
+    if modality is not None:
+        logging.debug(f"Modality: {modality}")

-    if not root_args.std_out:
-        dataset_dump(
-            input_lens, input_ids, output_lens, task_ids, {
-                "workload_type": "dataset",
-                "tokenizer": root_args.tokenizer.__class__.__name__,
-                "num_requests": len(input_ids),
-                "max_input_len": max(input_lens),
-                "max_output_len": max(output_lens)
-            }, root_args.output)
+    if modality is not None:
+        if not root_args.std_out:
+            multimodal_dataset_dump(
+                multimodal_texts, multimodal_image_paths, output_lens, task_ids,
+                {
+                    "workload_type": "dataset",
+                    "tokenizer": root_args.tokenizer.__class__.__name__,
+                    "num_requests": len(task_ids),
+                    "max_output_len": max(output_lens)
+                }, root_args.output)
+        else:
+            print_multimodal_dataset(
+                multimodal_texts,
+                multimodal_image_paths,
+                output_lens,
+            )
     else:
-        print_dataset(
-            input_ids,
-            output_lens,
-        )
+        if not root_args.std_out:
+            text_dataset_dump(
+                input_lens, input_ids, output_lens, task_ids, {
+                    "workload_type": "dataset",
+                    "tokenizer": root_args.tokenizer.__class__.__name__,
+                    "num_requests": len(input_ids),
+                    "max_input_len": max(input_lens),
+                    "max_output_len": max(output_lens)
+                }, root_args.output)
+        else:
+            print_text_dataset(
+                input_ids,
+                output_lens,
+            )
@@ -1,8 +1,9 @@
 import random

 import click
-from utils.utils import (dataset_dump, gen_random_tokens, get_norm_dist_lengths,
-                         get_unif_dist_lengths, print_dataset)
+from utils.utils import (gen_random_tokens, get_norm_dist_lengths,
+                         get_unif_dist_lengths, print_text_dataset,
+                         text_dataset_dump)


 @click.command()
@@ -57,7 +58,7 @@ def token_norm_dist(root_args, **kwargs):
     task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]

     if not root_args.std_out:
-        dataset_dump(
+        text_dataset_dump(
             input_lens, input_ids, output_lens, task_ids, {
                 "workload_type": "token-norm-dist",
                 "input_mean": kwargs['input_mean'],
@@ -70,7 +71,7 @@ def token_norm_dist(root_args, **kwargs):
                 "max_output_len": max_output_len
             }, root_args.output)
     else:
-        print_dataset(
+        print_text_dataset(
             input_ids,
             output_lens,
         )
@@ -127,7 +128,7 @@ def token_unif_dist(root_args, **kwargs):
     task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]

     if not root_args.std_out:
-        dataset_dump(
+        text_dataset_dump(
             input_lens, input_ids, output_lens, task_ids, {
                 "workload_type": "token-unif-dist",
                 "input_min": kwargs['input_min'],
@@ -140,7 +141,7 @@ def token_unif_dist(root_args, **kwargs):
                 "max_output_len": max_output_len
             }, root_args.output)
     else:
-        print_dataset(
+        print_text_dataset(
             input_ids,
             output_lens,
         )
@@ -2,22 +2,29 @@ import json
 import math
 import os
 import random
-from typing import List
+from typing import List, Union

 import numpy as np
 from pydantic import BaseModel


-class Sample(BaseModel):
+class TextSample(BaseModel):
     input_len: int
     input_ids: List[int]
     output_len: int
     task_id: int


+class MultimodalSample(BaseModel):
+    task_id: int
+    prompt: str
+    media_paths: List[str]
+    output_len: int
+
+
 class Workload(BaseModel):
     metadata: dict
-    samples: List[Sample] = []
+    samples: List[Union[TextSample, MultimodalSample]] = []

     def __init__(self, **kwargs) -> None:
         super().__init__(**kwargs)
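To make the new sample schema concrete, the following sketch builds a tiny `Workload` holding one `MultimodalSample` and serializes it the same way `multimodal_dataset_dump` does below. The class definitions are repeated here (without the metadata-defaulting `__init__`) so the snippet runs standalone; the field values are invented for illustration.

```python
# Standalone sketch of the new workload schema; values are illustrative only.
import json
from typing import List, Union

from pydantic import BaseModel


class TextSample(BaseModel):
    input_len: int
    input_ids: List[int]
    output_len: int
    task_id: int


class MultimodalSample(BaseModel):
    task_id: int
    prompt: str
    media_paths: List[str]
    output_len: int


class Workload(BaseModel):
    metadata: dict
    samples: List[Union[TextSample, MultimodalSample]] = []


workload = Workload(metadata={"workload_type": "dataset", "num_requests": 1},
                    samples=[
                        MultimodalSample(task_id=-1,
                                         prompt="Describe <image 1>.",
                                         media_paths=["/tmp/example.jpg"],
                                         output_len=128)
                    ])
print(json.dumps(workload.model_dump()))
```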
@@ -33,22 +40,37 @@ class Workload(BaseModel):
         self.metadata.setdefault('workload_name', workload_name)


-def dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
-                 output_file):
+def text_dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
+                      output_file):
     samples = []
     for i in range(len(input_ids)):
         samples.append(
-            Sample(input_len=input_lens[i],
-                   input_ids=input_ids[i],
-                   output_len=output_lens[i],
-                   task_id=task_ids[i]))
+            TextSample(input_len=input_lens[i],
+                       input_ids=input_ids[i],
+                       output_len=output_lens[i],
+                       task_id=task_ids[i]))
     workload = Workload(metadata=metadata, samples=samples)
     os.makedirs(os.path.dirname(output_file), exist_ok=True)
     with open(output_file, 'w') as f:
         json.dump(workload.model_dump(), f)


-def print_dataset(input_ids, output_lens):
+def multimodal_dataset_dump(multimodal_texts, multimodal_image_paths,
+                            output_lens, task_ids, metadata, output_file):
+    samples = []
+    for i in range(len(multimodal_texts)):
+        samples.append(
+            MultimodalSample(task_id=task_ids[i],
+                             prompt=multimodal_texts[i],
+                             media_paths=multimodal_image_paths[i],
+                             output_len=output_lens[i]))
+    workload = Workload(metadata=metadata, samples=samples)
+    os.makedirs(os.path.dirname(output_file), exist_ok=True)
+    with open(output_file, 'w') as f:
+        json.dump(workload.model_dump(), f)
+
+
+def print_text_dataset(input_ids, output_lens):
     for i, input_tokens in enumerate(input_ids):
         d = {
             "task_id": i,
@@ -58,6 +80,19 @@ def print_dataset(input_ids, output_lens):
         print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))


+def print_multimodal_dataset(multimodal_texts, multimodal_image_paths,
+                             output_lens):
+    for i, (text, image_paths) in enumerate(
+            zip(multimodal_texts, multimodal_image_paths)):
+        d = {
+            "task_id": i,
+            "prompt": text,
+            "media_paths": image_paths,
+            "output_tokens": output_lens[i]
+        }
+        print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))
+
+
 def get_list_of_delays(delay_dist, mean_time_bet_reqs, num_reqs, random_seed):
     if delay_dist == "constant":
         delays = [mean_time_bet_reqs] * num_reqs
@@ -475,6 +475,115 @@ Total Latency (ms): 18563.6825

 ```

+#### Running multi-modal models in the PyTorch Workflow
+
+To benchmark multi-modal models with PyTorch workflow, you can follow the similar approach as above.
+
+First, prepare the dataset:
+```
+python ./benchmarks/cpp/prepare_dataset.py \
+  --tokenizer Qwen/Qwen2-VL-2B-Instruct \
+  --stdout \
+  dataset \
+  --dataset-name lmms-lab/MMMU \
+  --dataset-split test \
+  --dataset-image-key image \
+  --dataset-prompt-key question \
+  --num-requests 10 \
+  --output-len-dist 128,5 > mm_data.jsonl
+```
+It will download the media files to `/tmp` directory and prepare the dataset with their paths. Note that the `prompt` fields are texts and not tokenized ids. This is due to the fact that
+the `prompt` and the media (image/video) are processed by a preprocessor for multimodal files.
+
+Sample dataset for multimodal:
+```
+{"task_id":0,"prompt":"Brahma Industries sells vinyl replacement windows to home improvement retailers nationwide. The national sales manager believes that if they invest an additional $25,000 in advertising, they would increase sales volume by 10,000 units. <image 1> What is the total contribution margin?","media_paths":["/tmp/tmp9so41y3r.jpg"],"output_tokens":126}
+{"task_id":1,"prompt":"Let us compute for the missing amounts under work in process inventory, what is the cost of goods manufactured? <image 1>","media_paths":["/tmp/tmpowsrb_f4.jpg"],"output_tokens":119}
+{"task_id":2,"prompt":"Tsuji is reviewing the price of a 3-month Japanese yen/U.S. dollar currency futures contract, using the currency and interest rate data shown below. Because the 3-month Japanese interest rate has just increased to .50%, Itsuji recognizes that an arbitrage opportunity exists nd decides to borrow $1 million U.S. dollars to purchase Japanese yen. Calculate the yen arbitrage profit from Itsuji's strategy, using the following data: <image 1> ","media_paths":["/tmp/tmpxhdvasex.jpg"],"output_tokens":126}
+...
+```
+
+Run the benchmark:
+```
+trtllm-bench --model Qwen/Qwen2-VL-2B-Instruct \
+  throughput \
+  --dataset mm_data.jsonl \
+  --backend pytorch \
+  --num_requests 10 \
+  --max_batch_size 4 \
+  --modality image
+```
+
+
+Sample output:
+```
+===========================================================
+= REQUEST DETAILS
+===========================================================
+Number of requests:             10
+Number of concurrent requests:  5.3019
+Average Input Length (tokens):  411.6000
+Average Output Length (tokens): 128.7000
+===========================================================
+= WORLD + RUNTIME INFORMATION
+===========================================================
+TP Size:                1
+PP Size:                1
+EP Size:                None
+Max Runtime Batch Size: 4
+Max Runtime Tokens:     12288
+Scheduling Policy:      GUARANTEED_NO_EVICT
+KV Memory Percentage:   90.00%
+Issue Rate (req/sec):   1.4117E+17
+
+===========================================================
+= PERFORMANCE OVERVIEW
+===========================================================
+Request Throughput (req/sec):                 1.4439
+Total Output Throughput (tokens/sec):         185.8351
+Per User Output Throughput (tokens/sec/user): 38.1959
+Per GPU Output Throughput (tokens/sec/gpu):   185.8351
+Total Token Throughput (tokens/sec):          780.1607
+Total Latency (ms):                           6925.4963
+Average request latency (ms):                 3671.8441
+
+-- Request Latency Breakdown (ms) -----------------------
+
+[Latency] P50    : 3936.3022
+[Latency] P90    : 5514.4701
+[Latency] P95    : 5514.4701
+[Latency] P99    : 5514.4701
+[Latency] MINIMUM: 2397.1047
+[Latency] MAXIMUM: 5514.4701
+[Latency] AVERAGE: 3671.8441
+
+===========================================================
+= DATASET DETAILS
+===========================================================
+Dataset Path:        /workspaces/tensorrt_llm/mm_data.jsonl
+Number of Sequences: 10
+
+-- Percentiles statistics ---------------------------------
+
+        Input      Output      Seq. Length
+-----------------------------------------------------------
+MIN:   167.0000   119.0000     300.0000
+MAX:  1059.0000   137.0000    1178.0000
+AVG:   411.6000   128.7000     540.3000
+P50:   299.0000   128.0000     427.0000
+P90:  1059.0000   137.0000    1178.0000
+P95:  1059.0000   137.0000    1178.0000
+P99:  1059.0000   137.0000    1178.0000
+===========================================================
+```
+
+**Notes and Limitations**:
+- Only image datasets are supported for now.
+- `--output-len-dist` is a required argument for multimodal datasets.
+- Tokenizer is unused during the prepare step but it is still a required argument.
+- Since the images are converted to tokens when the model is run, `trtllm-bench` uses a default large value for the maximum input sequence length when setting up the execution settings.
+  You can also modify the behavior by specifying a different value with the flag `--max_input_len` that suits your use-case.
+
 #### Quantization in the PyTorch Flow

 In order to run a quantized run with `trtllm-bench` utilizing the PyTorch flow, you will need to use a pre-quantized
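Because the prepared file is plain JSON Lines in the format shown in the sample above (one object per request with `prompt`, `media_paths` and `output_tokens` fields), it can be sanity-checked before benchmarking with a few lines of Python. This is an illustrative sketch, not part of the benchmark tooling; it assumes the `mm_data.jsonl` produced by the command above.

```python
# Sketch: verify that every media file referenced by mm_data.jsonl exists.
import json
import os

with open("mm_data.jsonl") as f:
    for line in f:
        sample = json.loads(line)
        missing = [p for p in sample["media_paths"] if not os.path.exists(p)]
        status = "OK" if not missing else f"missing {missing}"
        print(f"task {sample['task_id']}: "
              f"{len(sample['media_paths'])} image(s), "
              f"{sample['output_tokens']} output tokens, {status}")
```

Since the images are written under `/tmp`, re-run the prepare step if those files have been cleaned up between preparing the dataset and running the benchmark.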
|
|||||||
@ -17,7 +17,7 @@ def prepare_text_inputs(model_name, batch_size=8):
|
|||||||
f"HF_DATASETS_OFFLINE inside function: {datasets.config.HF_DATASETS_OFFLINE}"
|
f"HF_DATASETS_OFFLINE inside function: {datasets.config.HF_DATASETS_OFFLINE}"
|
||||||
)
|
)
|
||||||
if model_name == "BertForQuestionAnswering" or model_name == "RobertaForQuestionAnswering":
|
if model_name == "BertForQuestionAnswering" or model_name == "RobertaForQuestionAnswering":
|
||||||
squad_dataset = load_dataset("squad_v2")
|
squad_dataset = load_dataset("squad_v2", trust_remote_code=True)
|
||||||
val_dataset = squad_dataset["validation"]
|
val_dataset = squad_dataset["validation"]
|
||||||
samples = val_dataset.select(range(batch_size))
|
samples = val_dataset.select(range(batch_size))
|
||||||
|
|
||||||
@ -27,7 +27,8 @@ def prepare_text_inputs(model_name, batch_size=8):
|
|||||||
}
|
}
|
||||||
return qa_real_test_inputs
|
return qa_real_test_inputs
|
||||||
elif model_name == "BertForSequenceClassification" or model_name == "RobertaForSequenceClassification":
|
elif model_name == "BertForSequenceClassification" or model_name == "RobertaForSequenceClassification":
|
||||||
yelp_dataset = load_dataset("fancyzhx/yelp_polarity")
|
yelp_dataset = load_dataset("fancyzhx/yelp_polarity",
|
||||||
|
trust_remote_code=True)
|
||||||
val_dataset = yelp_dataset["test"]
|
val_dataset = yelp_dataset["test"]
|
||||||
samples = val_dataset.select(range(batch_size))
|
samples = val_dataset.select(range(batch_size))
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets==2.14.6
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
evaluate
|
evaluate
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
rouge_score
|
rouge_score
|
||||||
SentencePiece~=0.1.99
|
SentencePiece~=0.1.99
|
||||||
evaluate
|
evaluate
|
||||||
|
|||||||
@ -11,4 +11,4 @@ sentencepiece>=0.1.99
|
|||||||
h5py~=3.12.1
|
h5py~=3.12.1
|
||||||
rouge_score
|
rouge_score
|
||||||
nltk
|
nltk
|
||||||
datasets==2.14.6
|
datasets==3.1.0
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
protobuf
|
protobuf
|
||||||
rouge_score
|
rouge_score
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
SentencePiece>=0.1.99
|
SentencePiece>=0.1.99
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
transformers>=4.43.0
|
transformers>=4.43.0
|
||||||
datasets==2.14.6
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
|
|||||||
@ -312,7 +312,8 @@ def main(args):
|
|||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
dataset_openweb = load_dataset("stas/openwebtext-10k",
|
dataset_openweb = load_dataset("stas/openwebtext-10k",
|
||||||
cache_dir=args.dataset_path)
|
cache_dir=args.dataset_path,
|
||||||
|
trust_remote_code=True)
|
||||||
long_texts = get_long_texts(dataset_openweb) # generator
|
long_texts = get_long_texts(dataset_openweb) # generator
|
||||||
|
|
||||||
# get datapoints
|
# get datapoints
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
evaluate
|
evaluate
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
transformers>=4.39.0
|
transformers>=4.39.0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece
|
sentencepiece
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
evaluate
|
evaluate
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.15.0
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
protobuf
|
protobuf
|
||||||
rouge_score
|
rouge_score
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
protobuf
|
protobuf
|
||||||
rouge_score
|
rouge_score
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
protobuf
|
protobuf
|
||||||
rouge_score
|
rouge_score
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
tiktoken==0.6.0
|
tiktoken==0.6.0
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.6
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
datasets~=2.14.6
|
datasets==3.1.0
|
||||||
evaluate~=0.4.1
|
evaluate~=0.4.1
|
||||||
rouge_score~=0.1.2
|
rouge_score~=0.1.2
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
transformers>=4.31.0
|
transformers>=4.31.0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
rouge_score
|
rouge_score
|
||||||
evaluate
|
evaluate
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets==2.14.6
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece==0.2.0
|
sentencepiece==0.2.0
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets==2.14.5
|
datasets==3.1.0
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
evaluate
|
evaluate
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
SentencePiece>=0.1.99
|
SentencePiece>=0.1.99
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.16.1
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../../../constraints.txt
|
-c ../../../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets==2.14.6
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
|
|||||||
@ -108,6 +108,7 @@ def load_dataset(args) -> datasets.Dataset:
|
|||||||
'timeout': aiohttp.ClientTimeout(total=3600)
|
'timeout': aiohttp.ClientTimeout(total=3600)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,6 @@
|
|||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
nemo-toolkit[all]==2.0.0rc1
|
nemo-toolkit[all]==2.0.0rc1
|
||||||
megatron-core @ git+https://github.com/NVIDIA/Megatron-LM@core_r0.8.0
|
megatron-core @ git+https://github.com/NVIDIA/Megatron-LM@core_r0.8.0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
|
|||||||
@ -20,7 +20,7 @@ def create_trtllm_magpie_calibration_dataset(output_dir: str,
|
|||||||
calib_size: int = 512) -> None:
|
calib_size: int = 512) -> None:
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
dataset = load_dataset(DATASET, split="train")
|
dataset = load_dataset(DATASET, split="train", trust_remote_code=True)
|
||||||
|
|
||||||
def transform(conversation):
|
def transform(conversation):
|
||||||
value = '\n'.join(turn['value']
|
value = '\n'.join(turn['value']
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
einops~=0.7.0
|
einops~=0.7.0
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece~=0.1.99
|
sentencepiece~=0.1.99
|
||||||
evaluate
|
evaluate
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from quickstart_advanced import add_llm_args, setup_llm
|
from quickstart_advanced import add_llm_args, setup_llm
|
||||||
from transformers import AutoProcessor
|
|
||||||
|
|
||||||
from tensorrt_llm.inputs import load_image, load_video
|
from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
|
||||||
|
default_video_loader)
|
||||||
|
|
||||||
example_images = [
|
example_images = [
|
||||||
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
|
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
|
||||||
@ -27,92 +28,32 @@ example_video_prompts = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def prepare_vila(args, inputs):
|
def prepare_multimodal_inputs(model_dir: str,
|
||||||
|
model_type: str,
|
||||||
|
modality: str,
|
||||||
|
prompts: List[str],
|
||||||
|
media: List[str],
|
||||||
|
image_data_format: str = "pt",
|
||||||
|
num_frames: int = 8) -> List[Dict[str, Any]]:
|
||||||
|
|
||||||
def add_media_token(prompt, multi_modal_data):
|
inputs = []
|
||||||
mm_tokens = ""
|
if modality == "image":
|
||||||
if "image" in multi_modal_data:
|
inputs = default_image_loader(prompts, media, image_data_format)
|
||||||
for _ in multi_modal_data["image"]:
|
elif modality == "video":
|
||||||
mm_tokens += "<image>"
|
inputs = default_video_loader(prompts, media, image_data_format,
|
||||||
elif "video" in multi_modal_data:
|
num_frames)
|
||||||
for _ in multi_modal_data["video"]:
|
else:
|
||||||
mm_tokens += "<vila/video>"
|
raise ValueError(f"Unsupported modality: {modality}")
|
||||||
return mm_tokens + prompt
|
|
||||||
|
inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
|
||||||
|
|
||||||
for input in inputs:
|
|
||||||
input["prompt"] = add_media_token(input["prompt"],
|
|
||||||
input["multi_modal_data"])
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
|
||||||
def prepare_llava_next(args, inputs):
|
|
||||||
processor = AutoProcessor.from_pretrained(args.model_dir)
|
|
||||||
|
|
||||||
# Single-image inference chat template. For multi-image template,
|
|
||||||
# see https://huggingface.co/docs/transformers/en/model_doc/llava_next#multi-image-inference.
|
|
||||||
def apply_template(prompt, multimodal_data):
|
|
||||||
conversation = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": prompt
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image"
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
]
|
|
||||||
return processor.apply_chat_template(
|
|
||||||
conversation,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for input in inputs:
|
|
||||||
input["prompt"] = apply_template(input["prompt"],
|
|
||||||
input["multi_modal_data"])
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_qwen2_vl(args, inputs):
|
|
||||||
processor = AutoProcessor.from_pretrained(args.model_dir)
|
|
||||||
|
|
||||||
def apply_template(prompt, multimodal_data):
|
|
||||||
content = [{
|
|
||||||
"type": media_type
|
|
||||||
} for media_type, items in multimodal_data.items()
|
|
||||||
for _ in items] + [{
|
|
||||||
"type": "text",
|
|
||||||
"text": prompt
|
|
||||||
}]
|
|
||||||
|
|
||||||
conversation = [{"role": "user", "content": content}]
|
|
||||||
return processor.apply_chat_template(
|
|
||||||
conversation,
|
|
||||||
tokenize=False,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for input in inputs:
|
|
||||||
input["prompt"] = apply_template(input["prompt"],
|
|
||||||
input["multi_modal_data"])
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
|
|
||||||
MODEL_TYPE_MAP = {
|
|
||||||
"llava_llama": prepare_vila,
|
|
||||||
"llava_next": prepare_llava_next,
|
|
||||||
"qwen2_vl": prepare_qwen2_vl,
|
|
||||||
"qwen2_5_vl": prepare_qwen2_vl,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def add_multimodal_args(parser):
|
def add_multimodal_args(parser):
|
||||||
parser.add_argument("--model_type",
|
parser.add_argument("--model_type",
|
||||||
type=str,
|
type=str,
|
||||||
choices=MODEL_TYPE_MAP.keys(),
|
choices=INPUT_FORMATTER_MAP.keys(),
|
||||||
help="Model type.")
|
help="Model type.")
|
||||||
parser.add_argument("--modality",
|
parser.add_argument("--modality",
|
||||||
type=str,
|
type=str,
|
||||||
@ -150,50 +91,16 @@ def main():
|
|||||||
llm, sampling_params = setup_llm(args)
|
llm, sampling_params = setup_llm(args)
|
||||||
|
|
||||||
image_format = "pt" # ["pt", "pil"]
|
image_format = "pt" # ["pt", "pil"]
|
||||||
if args.modality == "image":
|
if args.model_type is not None:
|
||||||
prompts = args.prompt if args.prompt else example_image_prompts
|
model_type = args.model_type
|
||||||
images = args.media if args.media else example_images
|
|
||||||
if len(images) > len(prompts) and len(prompts) == 1:
|
|
||||||
# 1 prompt + N media
|
|
||||||
images = [images]
|
|
||||||
inputs = [{
|
|
||||||
"prompt": prompt,
|
|
||||||
"multi_modal_data": {
|
|
||||||
"image": [
|
|
||||||
load_image(i, format=image_format, device="cuda")
|
|
||||||
for i in image
|
|
||||||
] if isinstance(image, list) else
|
|
||||||
[load_image(image, format=image_format, device="cuda")]
|
|
||||||
}
|
|
||||||
} for prompt, image in zip(prompts, images)]
|
|
||||||
elif args.modality == "video":
|
|
||||||
prompts = args.prompt if args.prompt else example_video_prompts
|
|
||||||
videos = args.media if args.media else example_videos
|
|
||||||
if len(videos) > len(prompts) and len(prompts) == 1:
|
|
||||||
# 1 prompt + N media
|
|
||||||
videos = [videos]
|
|
||||||
inputs = [{
|
|
||||||
"prompt": prompt,
|
|
||||||
"multi_modal_data": {
|
|
||||||
"video": [
|
|
||||||
load_video(
|
|
||||||
i, args.num_frames, format=image_format, device="cuda")
|
|
||||||
for i in video
|
|
||||||
] if isinstance(video, list) else [
|
|
||||||
load_video(video,
|
|
||||||
args.num_frames,
|
|
||||||
format=image_format,
|
|
||||||
device="cuda")
|
|
||||||
]
|
|
||||||
}
|
|
||||||
} for prompt, video in zip(prompts, videos)]
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported modality: {args.modality}")
|
model_type = json.load(
|
||||||
|
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
|
||||||
|
assert model_type in INPUT_FORMATTER_MAP, f"Unsupported model_type: {model_type}"
|
||||||
|
|
||||||
model_type = json.load(open(os.path.join(llm._hf_model_dir,
|
inputs = prepare_multimodal_inputs(args.model_dir, model_type,
|
||||||
'config.json')))['model_type']
|
args.modality, args.prompt, args.media,
|
||||||
assert model_type in MODEL_TYPE_MAP, f"Unsupported model_type: {model_type}"
|
image_format, args.num_frames)
|
||||||
inputs = MODEL_TYPE_MAP[model_type](args, inputs)
|
|
||||||
|
|
||||||
outputs = llm.generate(inputs, sampling_params)
|
outputs = llm.generate(inputs, sampling_params)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets>=2.14.4
|
datasets==3.1.0
|
||||||
nemo-toolkit[all]==2.0.0rc1
|
nemo-toolkit[all]==2.0.0rc1
|
||||||
rouge_score
|
rouge_score
|
||||||
transformers_stream_generator==0.0.4
|
transformers_stream_generator==0.0.4
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.16.0
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
transformers>=4.40.1
|
transformers>=4.40.1
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.dev0
|
tensorrt_llm>=0.0.dev0
|
||||||
datasets~=2.16.0
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
transformers>=4.45.0
|
transformers>=4.45.0
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.16.0
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
transformers-stream-generator
|
transformers-stream-generator
|
||||||
|
|||||||
@ -5,7 +5,7 @@ flax>=0.8.2
|
|||||||
jax~=0.4.23
|
jax~=0.4.23
|
||||||
orbax-checkpoint==0.5.7
|
orbax-checkpoint==0.5.7
|
||||||
transformers>=4.40.0
|
transformers>=4.40.0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece
|
sentencepiece
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
datasets~=2.14.5
|
datasets==3.1.0
|
||||||
rouge_score
|
rouge_score
|
||||||
sentencepiece>=0.1.99
|
sentencepiece>=0.1.99
|
||||||
evaluate
|
evaluate
|
||||||
|
|||||||
@ -110,7 +110,8 @@ def main(args):
|
|||||||
dataset = load_dataset(dataset_name,
|
dataset = load_dataset(dataset_name,
|
||||||
dataset_revision,
|
dataset_revision,
|
||||||
cache_dir=args.dataset_cache_dir,
|
cache_dir=args.dataset_cache_dir,
|
||||||
split=dataset_split)
|
split=dataset_split,
|
||||||
|
trust_remote_code=True)
|
||||||
dataset = dataset.shuffle(args.random_seed)
|
dataset = dataset.shuffle(args.random_seed)
|
||||||
|
|
||||||
max_batch_size = args.batch_size
|
max_batch_size = args.batch_size
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
-c ../constraints.txt
|
-c ../constraints.txt
|
||||||
tensorrt_llm>=0.0.0.dev0
|
tensorrt_llm>=0.0.0.dev0
|
||||||
tiktoken
|
tiktoken
|
||||||
datasets
|
datasets==3.1.0
|
||||||
kaldialign
|
kaldialign
|
||||||
openai-whisper
|
openai-whisper
|
||||||
librosa
|
librosa
|
||||||
|
|||||||
@ -564,7 +564,8 @@ if __name__ == '__main__':
|
|||||||
normalizer = EnglishTextNormalizer()
|
normalizer = EnglishTextNormalizer()
|
||||||
dataset = load_dataset(args.dataset,
|
dataset = load_dataset(args.dataset,
|
||||||
args.dataset_name,
|
args.dataset_name,
|
||||||
split=args.dataset_split)
|
split=args.dataset_split,
|
||||||
|
trust_remote_code=True)
|
||||||
if args.enable_warmup:
|
if args.enable_warmup:
|
||||||
results, total_duration = decode_dataset(
|
results, total_duration = decode_dataset(
|
||||||
model,
|
model,
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
-r requirements.txt
|
-r requirements.txt
|
||||||
datasets==2.19.2
|
|
||||||
einops
|
einops
|
||||||
graphviz
|
graphviz
|
||||||
mypy
|
mypy
|
||||||
|
|||||||
@ -31,6 +31,8 @@ pydantic>=2.9.1
|
|||||||
pillow==10.3.0
|
pillow==10.3.0
|
||||||
wheel<=0.45.1
|
wheel<=0.45.1
|
||||||
optimum
|
optimum
|
||||||
|
# evaluate needs datasets>=2.0.0 which triggers datasets>3.1.0 which is not stable: https://github.com/huggingface/datasets/issues/7467
|
||||||
|
datasets==3.1.0
|
||||||
evaluate
|
evaluate
|
||||||
mpmath>=1.3.0
|
mpmath>=1.3.0
|
||||||
click
|
click
|
||||||
|
|||||||
@ -437,10 +437,7 @@ class Qwen2VLModelBase(PreTrainedModel):
|
|||||||
inputs_embeds=input_embeds,
|
inputs_embeds=input_embeds,
|
||||||
return_context_logits=return_context_logits,
|
return_context_logits=return_context_logits,
|
||||||
mrope_config=mrope_config)
|
mrope_config=mrope_config)
|
||||||
logger.debug(
|
logger.debug(f'output shape: {output_prob.shape}')
|
||||||
f"output_ids: {(output_prob if output_prob.dim() == 2 else output_prob.unsqueeze(0)).argmax(dim=1).tolist()}"
|
|
||||||
)
|
|
||||||
logger.info(f'output shape: {output_prob.shape}')
|
|
||||||
return output_prob
|
return output_prob
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup,
|
|||||||
|
|
||||||
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
|
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
|
||||||
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
|
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
|
||||||
|
from tensorrt_llm.bench.build.build import get_model_config
|
||||||
|
|
||||||
# isort: off
|
# isort: off
|
||||||
from tensorrt_llm.bench.benchmark.utils.general import (
|
from tensorrt_llm.bench.benchmark.utils.general import (
|
||||||
@ -21,7 +22,8 @@ from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
|
|||||||
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
|
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
|
||||||
from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
|
from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
|
||||||
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
|
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
|
||||||
initialize_tokenizer)
|
initialize_tokenizer,
|
||||||
|
update_metadata_for_multimodal)
|
||||||
from tensorrt_llm.llmapi import LLM, CapacitySchedulerPolicy
|
from tensorrt_llm.llmapi import LLM, CapacitySchedulerPolicy
|
||||||
from tensorrt_llm.logger import logger
|
from tensorrt_llm.logger import logger
|
||||||
from tensorrt_llm.sampling_params import SamplingParams
|
from tensorrt_llm.sampling_params import SamplingParams
|
||||||
@ -92,12 +94,26 @@ from tensorrt_llm.sampling_params import SamplingParams
|
|||||||
required=False,
|
required=False,
|
||||||
help="Pass in a dataset file for parsing instead of stdin.",
|
help="Pass in a dataset file for parsing instead of stdin.",
|
||||||
)
|
)
|
||||||
|
@optgroup.option(
|
||||||
|
"--modality",
|
||||||
|
type=click.Choice(["image", "video"]),
|
||||||
|
default=None,
|
||||||
|
help="Modality of the multimodal requests.",
|
||||||
|
)
|
||||||
|
@optgroup.option(
|
||||||
|
"--max_input_len",
|
||||||
|
type=int,
|
||||||
|
default=4096,
|
||||||
|
help=
|
||||||
|
"Maximum input sequence length to use for multimodal models. This is used only when --modality "
|
||||||
|
"is specified since the actual number of vision tokens is unknown before the model is run.",
|
||||||
|
)
|
||||||
@optgroup.option(
|
@optgroup.option(
|
||||||
"--num_requests",
|
"--num_requests",
|
||||||
type=int,
|
type=int,
|
||||||
default=0,
|
default=0,
|
||||||
help=
|
help=
|
||||||
"Number of requests to cap benchmark run at. If not specified or set to 0, it will be the"
|
"Number of requests to cap benchmark run at. If not specified or set to 0, it will be the "
|
||||||
"length of dataset.",
|
"length of dataset.",
|
||||||
)
|
)
|
||||||
@optgroup.option(
|
@optgroup.option(
|
||||||
@ -194,6 +210,9 @@ def throughput_command(
|
|||||||
engine_dir: Path = params.pop("engine_dir")
|
engine_dir: Path = params.pop("engine_dir")
|
||||||
concurrency: int = params.pop("concurrency")
|
concurrency: int = params.pop("concurrency")
|
||||||
backend: str = params.get("backend")
|
backend: str = params.get("backend")
|
||||||
|
modality: str = params.pop("modality")
|
||||||
|
max_input_len: int = params.pop("max_input_len")
|
||||||
|
model_type = get_model_config(model, checkpoint_path).model_type
|
||||||
|
|
||||||
# Reporting options
|
# Reporting options
|
||||||
report_json: Path = params.pop("report_json")
|
report_json: Path = params.pop("report_json")
|
||||||
@ -209,15 +228,24 @@ def throughput_command(
|
|||||||
# Dataset Loading and Preparation
|
# Dataset Loading and Preparation
|
||||||
with open(dataset_path, "r") as dataset:
|
with open(dataset_path, "r") as dataset:
|
||||||
metadata, requests = create_dataset_from_stream(
|
metadata, requests = create_dataset_from_stream(
|
||||||
tokenizer, dataset, num_requests=num_requests)
|
tokenizer,
|
||||||
|
dataset,
|
||||||
|
num_requests=num_requests,
|
||||||
|
model_dir=checkpoint_path,
|
||||||
|
model_type=model_type,
|
||||||
|
modality=modality,
|
||||||
|
max_input_seq_len_for_multimodal=max_input_len)
|
||||||
metadata.dataset_path = dataset_path
|
metadata.dataset_path = dataset_path
|
||||||
params["target_input_len"] = params.get(
|
params["target_input_len"] = params.get(
|
||||||
"target_input_len") or metadata.avg_isl
|
"target_input_len") or metadata.avg_isl
|
||||||
params["target_output_len"] = params.get(
|
params["target_output_len"] = params.get(
|
||||||
"target_output_len") or metadata.avg_osl
|
"target_output_len") or metadata.avg_osl
|
||||||
|
|
||||||
# Log dataset info
|
if modality is None:
|
||||||
logger.info(metadata.get_summary_for_print())
|
# Log dataset info
|
||||||
|
# NOTE: This table is only accurate for non-multimodal models.
|
||||||
|
# The accurate table for multimodal models will be logged after the benchmark is done.
|
||||||
|
logger.info(metadata.get_summary_for_print())
|
||||||
|
|
||||||
# Engine configuration parsing
|
# Engine configuration parsing
|
||||||
if backend and backend.lower() in ["pytorch", "autodeploy"]:
|
if backend and backend.lower() in ["pytorch", "autodeploy"]:
|
||||||
@ -294,8 +322,12 @@ def throughput_command(
|
|||||||
warmup_dataset = generate_warmup_dataset(requests, warmup)
|
warmup_dataset = generate_warmup_dataset(requests, warmup)
|
||||||
logger.info("Running warmup.")
|
logger.info("Running warmup.")
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
async_benchmark(llm, sampling_params, warmup_dataset, False,
|
async_benchmark(llm,
|
||||||
concurrency))
|
sampling_params,
|
||||||
|
warmup_dataset,
|
||||||
|
False,
|
||||||
|
concurrency,
|
||||||
|
modality=modality))
|
||||||
# WAR: IterationResult is a singleton tied to the executor.
|
# WAR: IterationResult is a singleton tied to the executor.
|
||||||
# Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
|
# Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
|
||||||
# we must reset it to ensure it attaches to the correct event loop.
|
# we must reset it to ensure it attaches to the correct event loop.
|
||||||
@ -304,10 +336,19 @@ def throughput_command(
|
|||||||
|
|
||||||
with iteration_writer.capture():
|
with iteration_writer.capture():
|
||||||
statistics = asyncio.run(
|
statistics = asyncio.run(
|
||||||
async_benchmark(llm, sampling_params, requests, streaming,
|
async_benchmark(llm,
|
||||||
concurrency, iteration_writer.full_address))
|
sampling_params,
|
||||||
|
requests,
|
||||||
|
streaming,
|
||||||
|
concurrency,
|
||||||
|
iteration_writer.full_address,
|
||||||
|
modality=modality))
|
||||||
|
|
||||||
logger.info(f"Benchmark done. Reporting results...")
|
logger.info(f"Benchmark done. Reporting results...")
|
||||||
|
if modality is not None:
|
||||||
|
# For multimodal models, we need to update the metadata with the correct input lengths
|
||||||
|
metadata = update_metadata_for_multimodal(metadata, statistics)
|
||||||
|
|
||||||
report_utility = ReportUtility(statistics, metadata, runtime_config,
|
report_utility = ReportUtility(statistics, metadata, runtime_config,
|
||||||
logger, kwargs, streaming)
|
logger, kwargs, streaming)
|
||||||
if report_json:
|
if report_json:
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ class LlmManager:
                  llm: LLM,
                  outbox: asyncio.Queue[PerfItemTuple],
                  streaming: bool,
-                 concurrency: int = -1) -> None:
+                 concurrency: int = -1,
+                 modality: Optional[str] = None) -> None:
         self.llm = llm
         self._inbox: asyncio.Queue[Tuple[InferenceRequest,
                                          SamplingParams]] = asyncio.Queue()
@@ -38,6 +39,7 @@ class LlmManager:
             concurrency) if concurrency > 0 else None
         self.streaming = streaming
         self.request_seen = asyncio.Event()
+        self.modality = modality

     async def process_request(self, request: InferenceRequest,
                               sampling_params: SamplingParams):
@@ -50,7 +52,7 @@ class LlmManager:
         time_on_first_token = None
         # Schedule the request in the LLM API (asynchronously)
         output: RequestOutput = self.llm.generate_async(
-            request.input_ids,
+            request.input_ids if self.modality is None else request.prompt,
             sampling_params=sampling_params,
             streaming=self.streaming)
         if self.streaming:
@@ -70,7 +72,7 @@ class LlmManager:
                     start_timestamp=request_start_timestamp,
                     end_timestamp=response_end_timestamp,
                     request_id=response.request_id,
-                    num_input_tokens=len(request.input_ids),
+                    num_input_tokens=len(output.prompt_token_ids),
                     response_is_final=response.finished,
                     error=False,
                     tokens=tokens,
@@ -201,6 +203,7 @@ async def async_benchmark(
     streaming: bool,
     concurrency: int = -1,
     iteration_log_addr: str = None,
+    modality: Optional[str] = None,
 ) -> StatsKeeper:
     outbox = asyncio.Queue()
     statistics = StatsKeeper()
@@ -208,7 +211,11 @@ async def async_benchmark(

     try:
         logger.info("Starting benchmarking async task.")
-        backend = LlmManager(llm, outbox, streaming, concurrency=concurrency)
+        backend = LlmManager(llm,
+                             outbox,
+                             streaming,
+                             concurrency=concurrency,
+                             modality=modality)
         backend.run(iteration_addr=iteration_log_addr)

         enqueue_task = asyncio.create_task(
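For orientation, here is a minimal sketch (not part of this commit) of how the extended entry point is driven once a modality is selected; `llm`, `sampling_params`, and `requests` are assumed to have been built by the surrounding trtllm-bench code:

    import asyncio

    # Hypothetical invocation: non-streaming run over an image dataset with a
    # concurrency cap of 32. Passing modality switches LlmManager to
    # prompt-based generate_async() calls instead of pre-tokenized input_ids.
    statistics = asyncio.run(
        async_benchmark(llm,
                        sampling_params,
                        requests,
                        streaming=False,
                        concurrency=32,
                        modality="image"))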
@@ -116,6 +116,7 @@ class ModelConfig(BaseModel):
     setting calculation.
     """
     name: str
+    model_type: str
     param_count: int
     num_hidden_layers: int = Field(validation_alias=AliasChoices(
         "num_hidden_layers",
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from pathlib import Path
-from typing import List, Optional
+from typing import Any, List, Optional, Union

 from pydantic import (AliasChoices, BaseModel, Field, computed_field,
                       model_validator)
@@ -17,7 +17,7 @@ class BenchmarkEnvironment(BaseModel):

 class InferenceRequest(BaseModel):
     task_id: int
-    prompt: Optional[str] = None
+    prompt: Optional[Union[str, Any]] = None
     output_tokens: int
     input_ids: Optional[List[int]] = Field(
         alias=AliasChoices("input_ids", "logits"))
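As a hedged illustration (the field values below are placeholders, not taken from the commit), a multimodal request can now carry the formatted multimodal input in `prompt` while `input_ids` stays unset, since tokenization is deferred to the model's preprocessor:

    example_request = InferenceRequest(
        task_id=0,
        prompt={
            "prompt": "<chat-templated text>",            # placeholder
            "multi_modal_data": {"image": ["<tensor>"]},  # placeholder
        },
        output_tokens=128,
        input_ids=None,
    )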
@@ -1,12 +1,36 @@
 import json
 from functools import partial
-from typing import List, TextIO, Tuple
+from typing import Any, Dict, List, TextIO, Tuple

 from transformers import AutoTokenizer, PreTrainedTokenizer

 from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata,
                                                     InferenceRequest)
 from tensorrt_llm.bench.dataclasses.statistics import PercentileStats
+from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
+                                 default_video_loader)
+
+
+def prepare_multimodal_inputs(model_dir: str,
+                              model_type: str,
+                              modality: str,
+                              prompts: List[str],
+                              media: List[str],
+                              image_data_format: str = "pt",
+                              num_frames: int = 8) -> List[Dict[str, Any]]:
+
+    inputs = []
+    if modality == "image":
+        inputs = default_image_loader(prompts, media, image_data_format)
+    elif modality == "video":
+        inputs = default_video_loader(prompts, media, image_data_format,
+                                      num_frames)
+    else:
+        raise ValueError(f"Unsupported modality: {modality}")
+
+    inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
+
+    return inputs


 def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer:
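A minimal usage sketch of the new helper, with a hypothetical model directory and image path: it first wraps the media via the default loader and then dispatches to the per-model formatter registered in INPUT_FORMATTER_MAP:

    inputs = prepare_multimodal_inputs(
        model_dir="/models/llava-v1.6-7b-hf",     # hypothetical path
        model_type="llava_next",
        modality="image",
        prompts=["Describe the image."],
        media=["/data/images/example.png"])       # hypothetical path
    # Each entry is a dict with "prompt" and "multi_modal_data" keys.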
@@ -36,6 +60,10 @@ def create_dataset_from_stream(
     max_input_length: int = 0,
     max_output_length: int = 0,
     num_requests: int = 0,
+    model_dir: str = None,
+    model_type: str = None,
+    modality: str = None,
+    max_input_seq_len_for_multimodal: int = 4096,
 ) -> Tuple[DatasetMetadata, List[InferenceRequest]]:
     """Generate metadata and a list of requests to drive benchmarking.

@@ -83,13 +111,30 @@ def create_dataset_from_stream(
         # Each line should be a complete JSON dictionary with no indentation
         # or newline characters.
         data = json.loads(line)
-        logits = data.get("input_ids", data.get("logits", None))
-        prompt = data.get("prompt", None)
+        if modality is not None:
+            # Multimodal data
+            assert modality in [
+                "image", "video"
+            ], f"Modality must be one of ['image', 'video'] but got {modality}."
+
+            prompt = data.get("prompt")  # cannot be None
+            media_paths = data.get("media_paths", None)
+            inputs = prepare_multimodal_inputs(
+                model_dir,
+                model_type,
+                modality,
+                prompts=[prompt],
+                media=media_paths)  # list of dicts
+            logits = None  # cannot tokenize multi-modal data, handled by preprocessor
+            prompt = inputs[0]
+        else:
+            logits = data.get("input_ids", data.get("logits", None))
+            prompt = data.get("prompt", None)
+            # If the request comes in with logits, just use the provided.
+            # Otherwise we need to tokenize it.
+            logits = tokenize(prompt)["input_ids"] if logits is None else logits
         task_id = data["task_id"]
         osl = data["output_tokens"]
-        # If the request comes in with logits, just use the provided.
-        # Otherwise we need to tokenize it.
-        logits = tokenize(prompt)["input_ids"] if logits is None else logits

         request = InferenceRequest(
             task_id=task_id,
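For reference, a hypothetical dataset line that the multimodal branch above would accept; the field names follow the keys read from `data`, and the image path is a placeholder:

    {"task_id": 0, "prompt": "Describe the image.", "media_paths": ["/data/images/example.png"], "output_tokens": 128}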
@@ -97,9 +142,14 @@ def create_dataset_from_stream(
             output_tokens=output_limiter(osl),
             input_ids=logits,
         )
-        all_isl.append(len(logits))
         all_osl.append(osl)
-        all_seq_len.append(len(logits) + osl)
+        if modality is not None:
+            cur_isl = max_input_seq_len_for_multimodal  # NOTE: actual sequence length is unknown until the model is run
+            all_isl.append(cur_isl)
+            all_seq_len.append(cur_isl + osl)
+        else:
+            all_isl.append(len(logits))
+            all_seq_len.append(len(logits) + osl)
         dataset.append(request)

     isl_stats = PercentileStats.from_iterable(all_isl)
@@ -115,3 +165,31 @@ def create_dataset_from_stream(
     )

     return metadata, dataset
+
+
+def update_metadata_for_multimodal(metadata, statistics) -> DatasetMetadata:
+    """Update the metadata from benchmark statistics. Only used for multimodal models.
+
+    Args:
+        metadata (DatasetMetadata): The metadata to update.
+        statistics (StatsKeeper): The statistics to update the metadata with.
+
+    Returns:
+        DatasetMetadata: The updated metadata.
+    """
+    all_isl = []
+    all_osl = []
+    all_seq_len = []
+    for request in statistics.requests.values():
+        all_isl.append(request.num_input_tokens)
+        all_osl.append(request.num_total_output_tokens)
+        all_seq_len.append(request.num_input_tokens +
+                           request.num_total_output_tokens)
+    isl_stats = PercentileStats.from_iterable(all_isl)
+    osl_stats = PercentileStats.from_iterable(all_osl)
+    seq_len_stats = PercentileStats.from_iterable(all_seq_len)
+    metadata.isl_stats = isl_stats
+    metadata.osl_stats = osl_stats
+    metadata.seq_len_stats = seq_len_stats
+
+    return metadata
@@ -36,7 +36,10 @@ class CnnDailymail(Evaluator):
                  system_prompt: Optional[str] = None):
         super().__init__(apply_chat_template=apply_chat_template,
                          system_prompt=system_prompt)
-        self.data = datasets.load_dataset(dataset_path, "3.0.0", split="test")
+        self.data = datasets.load_dataset(dataset_path,
+                                          "3.0.0",
+                                          split="test",
+                                          trust_remote_code=True)
         self.data = self.data.shuffle(random_seed)
         if num_samples is None:
             self.num_samples = self.data.num_rows
@@ -1,10 +1,15 @@
 from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
 from .registry import (ExtraProcessedInputs, InputProcessor,
                        create_input_processor, register_input_processor)
-from .utils import load_image, load_video
+from .utils import (INPUT_FORMATTER_MAP, default_image_loader,
+                    default_video_loader, format_llava_next_input,
+                    format_qwen2_vl_input, format_vila_input, load_image,
+                    load_video)

 __all__ = [
     "PromptInputs", "prompt_inputs", "TextPrompt", "TokensPrompt",
     "InputProcessor", "create_input_processor", "register_input_processor",
-    "ExtraProcessedInputs", "load_image", "load_video"
+    "ExtraProcessedInputs", "load_image", "load_video", "INPUT_FORMATTER_MAP",
+    "default_image_loader", "default_video_loader", "format_vila_input",
+    "format_llava_next_input", "format_qwen2_vl_input"
 ]
@@ -6,6 +6,7 @@ import requests
 import torch
 from PIL import Image
 from torchvision.transforms import ToTensor
+from transformers import AutoProcessor


 def load_image(image: str,
@@ -67,3 +68,158 @@ def load_video(
             device=device) if format == "pt" else frames[index]
         for index in indices if index in frames
     ]
+
+
+"""
+VLM input preparation.
+"""
+
+
+def format_vila_input(model_dir, inputs):
+    """
+    This function formats the input for the VILA/NVILA VL model.
+
+    Arguments:
+        model_dir: The directory of the model to load any preprocessor.
+        inputs: The list of inputs to format.
+
+    Returns:
+        A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
+    """
+
+    def add_media_token(prompt, multi_modal_data):
+        mm_tokens = ""
+        if "image" in multi_modal_data:
+            for _ in multi_modal_data["image"]:
+                mm_tokens += "<image>"
+        elif "video" in multi_modal_data:
+            for _ in multi_modal_data["video"]:
+                mm_tokens += "<vila/video>"
+        return mm_tokens + prompt
+
+    for input in inputs:
+        input["prompt"] = add_media_token(input["prompt"],
+                                          input["multi_modal_data"])
+    return inputs
+
+
+def format_llava_next_input(model_dir, inputs):
+    """
+    This function formats the input for the Llava Next VL model.
+
+    Arguments:
+        model_dir: The directory of the model to load any preprocessor.
+        inputs: The list of inputs to format.
+
+    Returns:
+        A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
+    """
+    processor = AutoProcessor.from_pretrained(model_dir)
+
+    # Single-image inference chat template. For multi-image template,
+    # see https://huggingface.co/docs/transformers/en/model_doc/llava_next#multi-image-inference.
+    def apply_template(prompt, multimodal_data):
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": prompt
+                    },
+                    {
+                        "type": "image"
+                    },
+                ],
+            },
+        ]
+        return processor.apply_chat_template(
+            conversation,
+            add_generation_prompt=True,
+        )
+
+    for input in inputs:
+        input["prompt"] = apply_template(input["prompt"],
+                                         input["multi_modal_data"])
+    return inputs
+
+
+def format_qwen2_vl_input(model_dir, inputs):
+    """
+    This function formats the input for the Qwen2/Qwen2.5 VL model.
+
+    Arguments:
+        model_dir: The directory of the model to load any preprocessor.
+        inputs: The list of inputs to format.
+
+    Returns:
+        A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
+    """
+    processor = AutoProcessor.from_pretrained(model_dir)
+
+    def apply_template(prompt, multimodal_data):
+        content = [{
+            "type": media_type
+        } for media_type, items in multimodal_data.items()
+                   for _ in items] + [{
+                       "type": "text",
+                       "text": prompt
+                   }]
+
+        conversation = [{"role": "user", "content": content}]
+        # print(conversation)
+        return processor.apply_chat_template(
+            conversation,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
+
+    for input in inputs:
+        input["prompt"] = apply_template(input["prompt"],
+                                         input["multi_modal_data"])
+    return inputs
+
+
+def default_image_loader(prompts, images, image_data_format="pt"):
+    if len(images) > len(prompts) and len(prompts) == 1:
+        # 1 prompt + N media
+        images = [images]
+    inputs = [{
+        "prompt": prompt,
+        "multi_modal_data": {
+            "image": [
+                load_image(i, format=image_data_format, device="cuda")
+                for i in image
+            ] if isinstance(image, list) else
+            [load_image(image, format=image_data_format, device="cuda")]
+        }
+    } for prompt, image in zip(prompts, images)]
+    return inputs
+
+
+def default_video_loader(prompts, videos, image_data_format="pt", num_frames=8):
+    if len(videos) > len(prompts) and len(prompts) == 1:
+        # 1 prompt + N media
+        videos = [videos]
+    inputs = [{
+        "prompt": prompt,
+        "multi_modal_data": {
+            "video": [
+                load_video(
+                    i, num_frames, format=image_data_format, device="cuda")
+                for i in video
+            ] if isinstance(video, list) else [
+                load_video(
+                    video, num_frames, format=image_data_format, device="cuda")
+            ]
+        }
+    } for prompt, video in zip(prompts, videos)]
+    return inputs
+
+
+INPUT_FORMATTER_MAP = {
+    "llava_llama": format_vila_input,
+    "llava_next": format_llava_next_input,
+    "qwen2_vl": format_qwen2_vl_input,
+    "qwen2_5_vl": format_qwen2_vl_input,
+}
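To show how the new building blocks compose, here is a sketch with hypothetical paths (not part of the commit): an image prompt is first wrapped by the default loader and then chat-templated by the formatter looked up in INPUT_FORMATTER_MAP. Note that default_image_loader loads tensors onto "cuda", so a GPU is assumed:

    raw_inputs = default_image_loader(
        prompts=["What is in this picture?"],
        images=["/data/images/example.jpg"],    # hypothetical path
        image_data_format="pt")
    formatted = INPUT_FORMATTER_MAP["qwen2_vl"](
        "/models/Qwen2-VL-7B-Instruct",         # hypothetical checkpoint dir
        raw_inputs)
    # formatted[0]["prompt"] now holds the chat-templated string and
    # formatted[0]["multi_modal_data"]["image"] the loaded image tensor(s).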
@@ -306,6 +306,7 @@ def load_calib_dataset(dataset_name_or_dir: str,
     dataset = load_dataset(dataset_name_or_dir,
                            name=config_name,
                            split=split,
+                           trust_remote_code=trust_remote_code,
                            **kwargs)
     return dataset[key]

@@ -384,20 +384,26 @@ def get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
         dataset = load_dataset(
             "json",
             data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
-            split="train")
+            split="train",
+            trust_remote_code=True)
         dataset = dataset["text"][:calib_size]
     elif "scienceqa" in dataset_name_or_dir.lower(
     ) or "science_qa" in dataset_name_or_dir.lower():
         if os.path.isdir(dataset_name_or_dir):
-            dataset = load_dataset(dataset_name_or_dir, split="train")
+            dataset = load_dataset(dataset_name_or_dir,
+                                   split="train",
+                                   trust_remote_code=True)
         else:
-            dataset = load_dataset("derek-thomas/ScienceQA", split="train")
+            dataset = load_dataset("derek-thomas/ScienceQA",
+                                   split="train",
+                                   trust_remote_code=True)
         dataset = dataset.select(range(calib_size))
     elif "cnn_dailymail" in dataset_name_or_dir:
         dataset = load_dataset(
             dataset_name_or_dir,
             name="3.0.0",
             split="train",
+            trust_remote_code=True,
         )
         dataset = dataset["article"][:calib_size]
     elif os.path.isdir(dataset_name_or_dir):
@@ -405,7 +411,9 @@ def get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
             f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
             "assuming the calibration data are in the train split and text column."
         )
-        dataset = load_dataset(dataset_name_or_dir, split="train")
+        dataset = load_dataset(dataset_name_or_dir,
+                               split="train",
+                               trust_remote_code=True)
         dataset = dataset["text"][:calib_size]
     else:
         raise NotImplementedError(
@@ -993,22 +1001,29 @@ def get_nemo_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
         dataset = load_dataset(
             "json",
             data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
-            split="train")
+            split="train",
+            trust_remote_code=True)
         text_column = "text"
     elif "wikitext" in dataset_name_or_dir:
         dataset = load_dataset(dataset_name_or_dir,
                                "wikitext-103-v1",
-                               split="train")
+                               split="train",
+                               trust_remote_code=True)
         text_column = "text"
     elif "cnn_dailymail" in dataset_name_or_dir:
-        dataset = load_dataset(dataset_name_or_dir, name="3.0.0", split="train")
+        dataset = load_dataset(dataset_name_or_dir,
+                               name="3.0.0",
+                               split="train",
+                               trust_remote_code=True)
         text_column = "article"
     elif os.path.isdir(dataset_name_or_dir):
         logger.info(
             f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
             "assuming the calibration data are in the train split and text column."
         )
-        dataset = load_dataset(dataset_name_or_dir, split="train")
+        dataset = load_dataset(dataset_name_or_dir,
+                               split="train",
+                               trust_remote_code=True)
         text_column = "text"
     else:
         raise NotImplementedError(
@@ -366,8 +366,11 @@ def test_tokenizer_decode_incrementally(tokenizer_dir: str, threshold: float):
     num_samples = 100
     cnn_dailymail = datasets.load_dataset(cnn_dailymail_path,
                                           name='3.0.0',
-                                          split='train')
-    alpaca_chinese = datasets.load_dataset(alpaca_chinese_path, split='train')
+                                          split='train',
+                                          trust_remote_code=True)
+    alpaca_chinese = datasets.load_dataset(alpaca_chinese_path,
+                                           split='train',
+                                           trust_remote_code=True)
    dataset = cnn_dailymail['article'][:num_samples // 2] + alpaca_chinese[
        'output_zh'][:num_samples // 2]
