feat: adding multimodal (only image for now) support in trtllm-bench (#3490)

* feat: adding multimodal (only image for now) support in trtllm-bench

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* fix: add trust_remote_code=True in load_dataset() calls to maintain the v2.19.2 behavior

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* re-adding prompt_token_ids and using that for prompt_len

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* updating the datasets version in examples as well

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* api changes are not needed

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* moving datasets requirement and removing a missed api change

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* addressing review comments

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* refactoring the quickstart example

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

---------

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
rakib-hasan 2025-04-17 16:06:16 -07:00 committed by GitHub
parent 26ebd95302
commit ff3b741045
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
63 changed files with 703 additions and 258 deletions

View File

@ -1,12 +1,16 @@
import logging
import random
import re
import tempfile
from typing import Optional
import click
from datasets import load_dataset
from PIL import Image
from pydantic import BaseModel, model_validator
from utils.utils import dataset_dump, get_norm_dist_lengths, print_dataset
from utils.utils import (get_norm_dist_lengths, multimodal_dataset_dump,
print_multimodal_dataset, print_text_dataset,
text_dataset_dump)
def validate_output_len_dist(ctx, param, value):
@ -31,8 +35,10 @@ class DatasetConfig(BaseModel):
"""Split of the dataset. Typical values: train, validation, test. Setting to None will include all splits."""
split: Optional[str]
"""The dataset dictionary used for the input sentence."""
input_key: str
input_key: Optional[str] = None
"""The dataset dictionary key used for the prompt of the input sentence. Must not be set when prompt is set."""
image_key: Optional[str] = None
"""The dataset dictionary key used for the images."""
prompt_key: Optional[str] = None
"""The prompt sentence to be added to the input sentence. Must not be set when prompt_key is set."""
prompt: Optional[str] = None
@ -75,6 +81,20 @@ class DatasetConfig(BaseModel):
f"{req.keys()}")
return req[self.input_key]
def get_images(self, req):
"""Get the images from the given request."""
image_keys = [self.image_key
] + [f"{self.image_key}_{i}" for i in range(1, 8)]
assert any(key in req for key in image_keys), (
f"Dataset {self.name} does not have key '{self.image_key}'. "
"Please set --dataset-image-key to one of the available keys: "
f"{req.keys()}")
images = []
for key in image_keys:
if key in req and req[key] is not None:
images.append(req[key])
return images
def get_output(self, req):
"""Get the output sentence from the given request."""
if self.output_key is None:
@ -105,7 +125,8 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
dataset = iter(
load_dataset(*dataset_config.query,
split=dataset_config.split,
streaming=True))
streaming=True,
trust_remote_code=True))
except ValueError as e:
if "Config" in e:
e += "\n Please add the config name to the dataset config yaml."
@ -130,9 +151,12 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
required=True,
help=f"Split of the dataset to use.")
@click.option("--dataset-input-key",
required=True,
type=str,
help=f"The dataset dictionary key for input.")
@click.option("--dataset-image-key",
type=str,
default="image",
help=f"The dataset dictionary key for images.")
@click.option("--dataset-prompt-key",
type=str,
default=None,
@ -181,8 +205,39 @@ def dataset(root_args, **kwargs):
output_lens = []
task_ids = []
req_cnt = 0
modality = None
multimodal_texts = []
multimodal_image_paths = []
for req in load_dataset_from_hf(dataset_config):
# input
if any(key in req for key in ['image', 'image_1', 'video']):
# multimodal input
if 'video' in req and req['video'] is not None:
assert "Not supported yet"
assert kwargs['output_len_dist'] is not None, (
"Output length distribution must be set for multimodal requests."
)
modality = 'image'
text = dataset_config.get_prompt(req)
images = dataset_config.get_images(req)
image_paths = []
for image in images:
if image is not None:
if isinstance(image, str):
image_paths.append(image)
elif isinstance(image, Image.Image):
with tempfile.NamedTemporaryFile(
suffix=".jpg", delete=False) as tmp_file:
logging.debug(f"Saving image to {tmp_file.name}")
image = image.convert("RGB")
image.save(tmp_file, "JPEG")
filepath = tmp_file.name
image_paths.append(filepath)
else:
raise ValueError(f"Invalid image path: {image}")
multimodal_texts.append(text)
multimodal_image_paths.append(image_paths)
else:
# text input
prompt = dataset_config.get_prompt(
req) + ' ' + dataset_config.get_input(req)
logging.debug(f"Input sequence: {prompt}")
@ -195,7 +250,9 @@ def dataset(root_args, **kwargs):
# output if fetch from golden
if kwargs['output_len_dist'] is None:
output_lens.append(
len(root_args.tokenizer.encode(dataset_config.get_output(req))))
len(
root_args.tokenizer.encode(
dataset_config.get_output(req))))
# lora task id
task_id = root_args.task_id
@ -208,21 +265,44 @@ def dataset(root_args, **kwargs):
if kwargs['num_requests'] and req_cnt >= kwargs['num_requests']:
break
if kwargs['num_requests'] and len(input_ids) < kwargs['num_requests']:
if kwargs['num_requests'] and (len(input_ids) if modality is None else len(
multimodal_texts)) < kwargs['num_requests']:
logging.warning(
"Number of requests is smaller than the num-requests user set.")
f"Number of requests={len(input_ids) if modality is None else len(multimodal_texts)} is"
f" smaller than the num-requests user set={kwargs['num_requests']}."
)
# output if randomized
if kwargs['output_len_dist'] is not None:
osl_mean, osl_stdev = kwargs['output_len_dist']
output_lens = get_norm_dist_lengths(osl_mean, osl_stdev, len(input_ids),
output_lens = get_norm_dist_lengths(
osl_mean, osl_stdev,
len(input_ids) if modality is None else len(multimodal_texts),
root_args.random_seed)
logging.debug(f"Input lengths: {[len(i) for i in input_ids]}")
logging.debug(f"Output lengths: {output_lens}")
if modality is not None:
logging.debug(f"Modality: {modality}")
if modality is not None:
if not root_args.std_out:
dataset_dump(
multimodal_dataset_dump(
multimodal_texts, multimodal_image_paths, output_lens, task_ids,
{
"workload_type": "dataset",
"tokenizer": root_args.tokenizer.__class__.__name__,
"num_requests": len(task_ids),
"max_output_len": max(output_lens)
}, root_args.output)
else:
print_multimodal_dataset(
multimodal_texts,
multimodal_image_paths,
output_lens,
)
else:
if not root_args.std_out:
text_dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "dataset",
"tokenizer": root_args.tokenizer.__class__.__name__,
@ -231,7 +311,7 @@ def dataset(root_args, **kwargs):
"max_output_len": max(output_lens)
}, root_args.output)
else:
print_dataset(
print_text_dataset(
input_ids,
output_lens,
)

View File

@ -1,8 +1,9 @@
import random
import click
from utils.utils import (dataset_dump, gen_random_tokens, get_norm_dist_lengths,
get_unif_dist_lengths, print_dataset)
from utils.utils import (gen_random_tokens, get_norm_dist_lengths,
get_unif_dist_lengths, print_text_dataset,
text_dataset_dump)
@click.command()
@ -57,7 +58,7 @@ def token_norm_dist(root_args, **kwargs):
task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]
if not root_args.std_out:
dataset_dump(
text_dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "token-norm-dist",
"input_mean": kwargs['input_mean'],
@ -70,7 +71,7 @@ def token_norm_dist(root_args, **kwargs):
"max_output_len": max_output_len
}, root_args.output)
else:
print_dataset(
print_text_dataset(
input_ids,
output_lens,
)
@ -127,7 +128,7 @@ def token_unif_dist(root_args, **kwargs):
task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]
if not root_args.std_out:
dataset_dump(
text_dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "token-unif-dist",
"input_min": kwargs['input_min'],
@ -140,7 +141,7 @@ def token_unif_dist(root_args, **kwargs):
"max_output_len": max_output_len
}, root_args.output)
else:
print_dataset(
print_text_dataset(
input_ids,
output_lens,
)

View File

@ -2,22 +2,29 @@ import json
import math
import os
import random
from typing import List
from typing import List, Union
import numpy as np
from pydantic import BaseModel
class Sample(BaseModel):
class TextSample(BaseModel):
input_len: int
input_ids: List[int]
output_len: int
task_id: int
class MultimodalSample(BaseModel):
task_id: int
prompt: str
media_paths: List[str]
output_len: int
class Workload(BaseModel):
metadata: dict
samples: List[Sample] = []
samples: List[Union[TextSample, MultimodalSample]] = []
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
@ -33,12 +40,12 @@ class Workload(BaseModel):
self.metadata.setdefault('workload_name', workload_name)
def dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
def text_dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
output_file):
samples = []
for i in range(len(input_ids)):
samples.append(
Sample(input_len=input_lens[i],
TextSample(input_len=input_lens[i],
input_ids=input_ids[i],
output_len=output_lens[i],
task_id=task_ids[i]))
@ -48,7 +55,22 @@ def dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
json.dump(workload.model_dump(), f)
def print_dataset(input_ids, output_lens):
def multimodal_dataset_dump(multimodal_texts, multimodal_image_paths,
output_lens, task_ids, metadata, output_file):
samples = []
for i in range(len(multimodal_texts)):
samples.append(
MultimodalSample(task_id=task_ids[i],
prompt=multimodal_texts[i],
media_paths=multimodal_image_paths[i],
output_len=output_lens[i]))
workload = Workload(metadata=metadata, samples=samples)
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, 'w') as f:
json.dump(workload.model_dump(), f)
def print_text_dataset(input_ids, output_lens):
for i, input_tokens in enumerate(input_ids):
d = {
"task_id": i,
@ -58,6 +80,19 @@ def print_dataset(input_ids, output_lens):
print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))
def print_multimodal_dataset(multimodal_texts, multimodal_image_paths,
output_lens):
for i, (text, image_paths) in enumerate(
zip(multimodal_texts, multimodal_image_paths)):
d = {
"task_id": i,
"prompt": text,
"media_paths": image_paths,
"output_tokens": output_lens[i]
}
print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))
def get_list_of_delays(delay_dist, mean_time_bet_reqs, num_reqs, random_seed):
if delay_dist == "constant":
delays = [mean_time_bet_reqs] * num_reqs

View File

@ -475,6 +475,115 @@ Total Latency (ms): 18563.6825
```
#### Running multi-modal models in the PyTorch Workflow
To benchmark multi-modal models with the PyTorch workflow, you can follow a similar approach to the one above.
First, prepare the dataset:
```
python ./benchmarks/cpp/prepare_dataset.py \
--tokenizer Qwen/Qwen2-VL-2B-Instruct \
--stdout \
dataset \
--dataset-name lmms-lab/MMMU \
--dataset-split test \
--dataset-image-key image \
--dataset-prompt-key question \
--num-requests 10 \
--output-len-dist 128,5 > mm_data.jsonl
```
This command downloads the media files to the `/tmp` directory and prepares the dataset with their local paths. Note that the `prompt` fields contain raw text rather than tokenized IDs, because the prompt
and the media (image/video) are processed together by the model's multimodal preprocessor; a sketch of how one such record is consumed follows the sample below.
Sample dataset for multimodal:
```
{"task_id":0,"prompt":"Brahma Industries sells vinyl replacement windows to home improvement retailers nationwide. The national sales manager believes that if they invest an additional $25,000 in advertising, they would increase sales volume by 10,000 units. <image 1> What is the total contribution margin?","media_paths":["/tmp/tmp9so41y3r.jpg"],"output_tokens":126}
{"task_id":1,"prompt":"Let us compute for the missing amounts under work in process inventory, what is the cost of goods manufactured? <image 1>","media_paths":["/tmp/tmpowsrb_f4.jpg"],"output_tokens":119}
{"task_id":2,"prompt":"Tsuji is reviewing the price of a 3-month Japanese yen/U.S. dollar currency futures contract, using the currency and interest rate data shown below. Because the 3-month Japanese interest rate has just increased to .50%, Itsuji recognizes that an arbitrage opportunity exists nd decides to borrow $1 million U.S. dollars to purchase Japanese yen. Calculate the yen arbitrage profit from Itsuji's strategy, using the following data: <image 1> ","media_paths":["/tmp/tmpxhdvasex.jpg"],"output_tokens":126}
...
```
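The benchmark feeds each record to the LLM API as a multimodal request rather than as pre-tokenized IDs. Below is a minimal sketch of that step (not part of the benchmark CLI), using the `default_image_loader` and `INPUT_FORMATTER_MAP` helpers added in this change; it assumes a CUDA device is available, since the default loader places image tensors on the GPU.
```
import json

from tensorrt_llm.inputs import INPUT_FORMATTER_MAP, default_image_loader

# Read one record from the prepared dataset.
with open("mm_data.jsonl") as f:
    record = json.loads(f.readline())

# Wrap the raw text prompt and image paths into a prompt + multi_modal_data dict.
inputs = default_image_loader(prompts=[record["prompt"]],
                              images=record["media_paths"])

# Apply the model-specific formatting (media tokens / chat template), e.g. Qwen2-VL.
inputs = INPUT_FORMATTER_MAP["qwen2_vl"]("Qwen/Qwen2-VL-2B-Instruct", inputs)

# `inputs` can now be passed to LLM.generate() / generate_async().
```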
Run the benchmark:
```
trtllm-bench --model Qwen/Qwen2-VL-2B-Instruct \
throughput \
--dataset mm_data.jsonl \
--backend pytorch \
--num_requests 10 \
--max_batch_size 4 \
--modality image
```
Sample output:
```
===========================================================
= REQUEST DETAILS
===========================================================
Number of requests: 10
Number of concurrent requests: 5.3019
Average Input Length (tokens): 411.6000
Average Output Length (tokens): 128.7000
===========================================================
= WORLD + RUNTIME INFORMATION
===========================================================
TP Size: 1
PP Size: 1
EP Size: None
Max Runtime Batch Size: 4
Max Runtime Tokens: 12288
Scheduling Policy: GUARANTEED_NO_EVICT
KV Memory Percentage: 90.00%
Issue Rate (req/sec): 1.4117E+17
===========================================================
= PERFORMANCE OVERVIEW
===========================================================
Request Throughput (req/sec): 1.4439
Total Output Throughput (tokens/sec): 185.8351
Per User Output Throughput (tokens/sec/user): 38.1959
Per GPU Output Throughput (tokens/sec/gpu): 185.8351
Total Token Throughput (tokens/sec): 780.1607
Total Latency (ms): 6925.4963
Average request latency (ms): 3671.8441
-- Request Latency Breakdown (ms) -----------------------
[Latency] P50 : 3936.3022
[Latency] P90 : 5514.4701
[Latency] P95 : 5514.4701
[Latency] P99 : 5514.4701
[Latency] MINIMUM: 2397.1047
[Latency] MAXIMUM: 5514.4701
[Latency] AVERAGE: 3671.8441
===========================================================
= DATASET DETAILS
===========================================================
Dataset Path: /workspaces/tensorrt_llm/mm_data.jsonl
Number of Sequences: 10
-- Percentiles statistics ---------------------------------
Input Output Seq. Length
-----------------------------------------------------------
MIN: 167.0000 119.0000 300.0000
MAX: 1059.0000 137.0000 1178.0000
AVG: 411.6000 128.7000 540.3000
P50: 299.0000 128.0000 427.0000
P90: 1059.0000 137.0000 1178.0000
P95: 1059.0000 137.0000 1178.0000
P99: 1059.0000 137.0000 1178.0000
===========================================================
```
**Notes and Limitations**:
- Only image datasets are supported for now.
- `--output-len-dist` is a required argument for multimodal datasets.
- The tokenizer is unused during the prepare step, but it is still a required argument.
- Since images are converted to tokens only when the model runs, `trtllm-bench` uses a large default value for the maximum input sequence length when setting up the execution settings.
You can override this by passing a value that suits your use case with the `--max_input_len` flag, as shown below.
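For example, a variant of the command above that raises the cap for requests with many or large images (the value is illustrative):
```
trtllm-bench --model Qwen/Qwen2-VL-2B-Instruct \
    throughput \
    --dataset mm_data.jsonl \
    --backend pytorch \
    --modality image \
    --max_input_len 8192
```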
#### Quantization in the PyTorch Flow
In order to run a quantized run with `trtllm-bench` utilizing the PyTorch flow, you will need to use a pre-quantized

View File

@ -17,7 +17,7 @@ def prepare_text_inputs(model_name, batch_size=8):
f"HF_DATASETS_OFFLINE inside function: {datasets.config.HF_DATASETS_OFFLINE}"
)
if model_name == "BertForQuestionAnswering" or model_name == "RobertaForQuestionAnswering":
squad_dataset = load_dataset("squad_v2")
squad_dataset = load_dataset("squad_v2", trust_remote_code=True)
val_dataset = squad_dataset["validation"]
samples = val_dataset.select(range(batch_size))
@ -27,7 +27,8 @@ def prepare_text_inputs(model_name, batch_size=8):
}
return qa_real_test_inputs
elif model_name == "BertForSequenceClassification" or model_name == "RobertaForSequenceClassification":
yelp_dataset = load_dataset("fancyzhx/yelp_polarity")
yelp_dataset = load_dataset("fancyzhx/yelp_polarity",
trust_remote_code=True)
val_dataset = yelp_dataset["test"]
samples = val_dataset.select(range(batch_size))

View File

@ -1,5 +1,5 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets==2.14.6
datasets==3.1.0
evaluate
rouge_score

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
sentencepiece>=0.1.99
evaluate

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
SentencePiece~=0.1.99
evaluate

View File

@ -11,4 +11,4 @@ sentencepiece>=0.1.99
h5py~=3.12.1
rouge_score
nltk
datasets==2.14.6
datasets==3.1.0

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
protobuf
rouge_score

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
SentencePiece>=0.1.99

View File

@ -1,7 +1,7 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
transformers>=4.43.0
datasets==2.14.6
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -312,7 +312,8 @@ def main(args):
tokenizer.pad_token = tokenizer.eos_token
dataset_openweb = load_dataset("stas/openwebtext-10k",
cache_dir=args.dataset_path)
cache_dir=args.dataset_path,
trust_remote_code=True)
long_texts = get_long_texts(dataset_openweb) # generator
# get datapoints

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
sentencepiece>=0.1.99
evaluate

View File

@ -1,7 +1,7 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
transformers>=4.39.0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
sentencepiece

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
sentencepiece>=0.1.99
evaluate

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.15.0
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
protobuf
rouge_score

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
protobuf
rouge_score

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
protobuf
rouge_score

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
tiktoken==0.6.0

View File

@ -1,5 +1,5 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.6
datasets==3.1.0
evaluate
rouge_score

View File

@ -1,3 +1,3 @@
datasets~=2.14.6
datasets==3.1.0
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1,7 +1,7 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
transformers>=4.31.0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -1,5 +1,5 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score

View File

@ -1,5 +1,5 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
evaluate

View File

@ -1,7 +1,7 @@
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets==2.14.6
datasets==3.1.0
evaluate
rouge_score
sentencepiece==0.2.0

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets==2.14.5
datasets==3.1.0
rouge_score
sentencepiece>=0.1.99
evaluate

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
SentencePiece>=0.1.99

View File

@ -1,5 +1,5 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score

View File

@ -1,5 +1,5 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.16.1
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets==2.14.6
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -108,6 +108,7 @@ def load_dataset(args) -> datasets.Dataset:
'timeout': aiohttp.ClientTimeout(total=3600)
}
},
trust_remote_code=True,
)
return dataset

View File

@ -2,6 +2,6 @@
tensorrt_llm>=0.0.0.dev0
nemo-toolkit[all]==2.0.0rc1
megatron-core @ git+https://github.com/NVIDIA/Megatron-LM@core_r0.8.0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score

View File

@ -20,7 +20,7 @@ def create_trtllm_magpie_calibration_dataset(output_dir: str,
calib_size: int = 512) -> None:
from datasets import load_dataset
dataset = load_dataset(DATASET, split="train")
dataset = load_dataset(DATASET, split="train", trust_remote_code=True)
def transform(conversation):
value = '\n'.join(turn['value']

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
einops~=0.7.0

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
sentencepiece~=0.1.99
evaluate

View File

@ -1,11 +1,12 @@
import argparse
import json
import os
from typing import Any, Dict, List
from quickstart_advanced import add_llm_args, setup_llm
from transformers import AutoProcessor
from tensorrt_llm.inputs import load_image, load_video
from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
default_video_loader)
example_images = [
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
@ -27,92 +28,32 @@ example_video_prompts = [
]
def prepare_vila(args, inputs):
def prepare_multimodal_inputs(model_dir: str,
model_type: str,
modality: str,
prompts: List[str],
media: List[str],
image_data_format: str = "pt",
num_frames: int = 8) -> List[Dict[str, Any]]:
def add_media_token(prompt, multi_modal_data):
mm_tokens = ""
if "image" in multi_modal_data:
for _ in multi_modal_data["image"]:
mm_tokens += "<image>"
elif "video" in multi_modal_data:
for _ in multi_modal_data["video"]:
mm_tokens += "<vila/video>"
return mm_tokens + prompt
inputs = []
if modality == "image":
inputs = default_image_loader(prompts, media, image_data_format)
elif modality == "video":
inputs = default_video_loader(prompts, media, image_data_format,
num_frames)
else:
raise ValueError(f"Unsupported modality: {modality}")
inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
for input in inputs:
input["prompt"] = add_media_token(input["prompt"],
input["multi_modal_data"])
return inputs
def prepare_llava_next(args, inputs):
processor = AutoProcessor.from_pretrained(args.model_dir)
# Single-image inference chat template. For multi-image template,
# see https://huggingface.co/docs/transformers/en/model_doc/llava_next#multi-image-inference.
def apply_template(prompt, multimodal_data):
conversation = [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image"
},
],
},
]
return processor.apply_chat_template(
conversation,
add_generation_prompt=True,
)
for input in inputs:
input["prompt"] = apply_template(input["prompt"],
input["multi_modal_data"])
return inputs
def prepare_qwen2_vl(args, inputs):
processor = AutoProcessor.from_pretrained(args.model_dir)
def apply_template(prompt, multimodal_data):
content = [{
"type": media_type
} for media_type, items in multimodal_data.items()
for _ in items] + [{
"type": "text",
"text": prompt
}]
conversation = [{"role": "user", "content": content}]
return processor.apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True,
)
for input in inputs:
input["prompt"] = apply_template(input["prompt"],
input["multi_modal_data"])
return inputs
MODEL_TYPE_MAP = {
"llava_llama": prepare_vila,
"llava_next": prepare_llava_next,
"qwen2_vl": prepare_qwen2_vl,
"qwen2_5_vl": prepare_qwen2_vl,
}
def add_multimodal_args(parser):
parser.add_argument("--model_type",
type=str,
choices=MODEL_TYPE_MAP.keys(),
choices=INPUT_FORMATTER_MAP.keys(),
help="Model type.")
parser.add_argument("--modality",
type=str,
@ -150,50 +91,16 @@ def main():
llm, sampling_params = setup_llm(args)
image_format = "pt" # ["pt", "pil"]
if args.modality == "image":
prompts = args.prompt if args.prompt else example_image_prompts
images = args.media if args.media else example_images
if len(images) > len(prompts) and len(prompts) == 1:
# 1 prompt + N media
images = [images]
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"image": [
load_image(i, format=image_format, device="cuda")
for i in image
] if isinstance(image, list) else
[load_image(image, format=image_format, device="cuda")]
}
} for prompt, image in zip(prompts, images)]
elif args.modality == "video":
prompts = args.prompt if args.prompt else example_video_prompts
videos = args.media if args.media else example_videos
if len(videos) > len(prompts) and len(prompts) == 1:
# 1 prompt + N media
videos = [videos]
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"video": [
load_video(
i, args.num_frames, format=image_format, device="cuda")
for i in video
] if isinstance(video, list) else [
load_video(video,
args.num_frames,
format=image_format,
device="cuda")
]
}
} for prompt, video in zip(prompts, videos)]
if args.model_type is not None:
model_type = args.model_type
else:
raise ValueError(f"Unsupported modality: {args.modality}")
model_type = json.load(
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
assert model_type in INPUT_FORMATTER_MAP, f"Unsupported model_type: {model_type}"
model_type = json.load(open(os.path.join(llm._hf_model_dir,
'config.json')))['model_type']
assert model_type in MODEL_TYPE_MAP, f"Unsupported model_type: {model_type}"
inputs = MODEL_TYPE_MAP[model_type](args, inputs)
inputs = prepare_multimodal_inputs(args.model_dir, model_type,
args.modality, args.prompt, args.media,
image_format, args.num_frames)
outputs = llm.generate(inputs, sampling_params)

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets>=2.14.4
datasets==3.1.0
nemo-toolkit[all]==2.0.0rc1
rouge_score
transformers_stream_generator==0.0.4

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.16.0
datasets==3.1.0
evaluate
rouge_score
transformers>=4.40.1

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.dev0
datasets~=2.16.0
datasets==3.1.0
evaluate
rouge_score
transformers>=4.45.0

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.16.0
datasets==3.1.0
evaluate
rouge_score
transformers-stream-generator

View File

@ -5,7 +5,7 @@ flax>=0.8.2
jax~=0.4.23
orbax-checkpoint==0.5.7
transformers>=4.40.0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
sentencepiece

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
sentencepiece>=0.1.99
evaluate

View File

@ -110,7 +110,8 @@ def main(args):
dataset = load_dataset(dataset_name,
dataset_revision,
cache_dir=args.dataset_cache_dir,
split=dataset_split)
split=dataset_split,
trust_remote_code=True)
dataset = dataset.shuffle(args.random_seed)
max_batch_size = args.batch_size

View File

@ -1,7 +1,7 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
tiktoken
datasets
datasets==3.1.0
kaldialign
openai-whisper
librosa

View File

@ -564,7 +564,8 @@ if __name__ == '__main__':
normalizer = EnglishTextNormalizer()
dataset = load_dataset(args.dataset,
args.dataset_name,
split=args.dataset_split)
split=args.dataset_split,
trust_remote_code=True)
if args.enable_warmup:
results, total_duration = decode_dataset(
model,

View File

@ -1,5 +1,4 @@
-r requirements.txt
datasets==2.19.2
einops
graphviz
mypy

View File

@ -31,6 +31,8 @@ pydantic>=2.9.1
pillow==10.3.0
wheel<=0.45.1
optimum
# evaluate needs datasets>=2.0.0 which triggers datasets>3.1.0 which is not stable: https://github.com/huggingface/datasets/issues/7467
datasets==3.1.0
evaluate
mpmath>=1.3.0
click

View File

@ -437,10 +437,7 @@ class Qwen2VLModelBase(PreTrainedModel):
inputs_embeds=input_embeds,
return_context_logits=return_context_logits,
mrope_config=mrope_config)
logger.debug(
f"output_ids: {(output_prob if output_prob.dim() == 2 else output_prob.unsqueeze(0)).argmax(dim=1).tolist()}"
)
logger.info(f'output shape: {output_prob.shape}')
logger.debug(f'output shape: {output_prob.shape}')
return output_prob

View File

@ -10,6 +10,7 @@ from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup,
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
# isort: off
from tensorrt_llm.bench.benchmark.utils.general import (
@ -21,7 +22,8 @@ from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
initialize_tokenizer)
initialize_tokenizer,
update_metadata_for_multimodal)
from tensorrt_llm.llmapi import LLM, CapacitySchedulerPolicy
from tensorrt_llm.logger import logger
from tensorrt_llm.sampling_params import SamplingParams
@ -92,12 +94,26 @@ from tensorrt_llm.sampling_params import SamplingParams
required=False,
help="Pass in a dataset file for parsing instead of stdin.",
)
@optgroup.option(
"--modality",
type=click.Choice(["image", "video"]),
default=None,
help="Modality of the multimodal requests.",
)
@optgroup.option(
"--max_input_len",
type=int,
default=4096,
help=
"Maximum input sequence length to use for multimodal models. This is used only when --modality "
"is specified since the actual number of vision tokens is unknown before the model is run.",
)
@optgroup.option(
"--num_requests",
type=int,
default=0,
help=
"Number of requests to cap benchmark run at. If not specified or set to 0, it will be the"
"Number of requests to cap benchmark run at. If not specified or set to 0, it will be the "
"length of dataset.",
)
@optgroup.option(
@ -194,6 +210,9 @@ def throughput_command(
engine_dir: Path = params.pop("engine_dir")
concurrency: int = params.pop("concurrency")
backend: str = params.get("backend")
modality: str = params.pop("modality")
max_input_len: int = params.pop("max_input_len")
model_type = get_model_config(model, checkpoint_path).model_type
# Reporting options
report_json: Path = params.pop("report_json")
@ -209,14 +228,23 @@ def throughput_command(
# Dataset Loading and Preparation
with open(dataset_path, "r") as dataset:
metadata, requests = create_dataset_from_stream(
tokenizer, dataset, num_requests=num_requests)
tokenizer,
dataset,
num_requests=num_requests,
model_dir=checkpoint_path,
model_type=model_type,
modality=modality,
max_input_seq_len_for_multimodal=max_input_len)
metadata.dataset_path = dataset_path
params["target_input_len"] = params.get(
"target_input_len") or metadata.avg_isl
params["target_output_len"] = params.get(
"target_output_len") or metadata.avg_osl
if modality is None:
# Log dataset info
# NOTE: This table is only accurate for non-multimodal models.
# The accurate table for multimodal models will be logged after the benchmark is done.
logger.info(metadata.get_summary_for_print())
# Engine configuration parsing
@ -294,8 +322,12 @@ def throughput_command(
warmup_dataset = generate_warmup_dataset(requests, warmup)
logger.info("Running warmup.")
asyncio.run(
async_benchmark(llm, sampling_params, warmup_dataset, False,
concurrency))
async_benchmark(llm,
sampling_params,
warmup_dataset,
False,
concurrency,
modality=modality))
# WAR: IterationResult is a singleton tied to the executor.
# Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
# we must reset it to ensure it attaches to the correct event loop.
@ -304,10 +336,19 @@ def throughput_command(
with iteration_writer.capture():
statistics = asyncio.run(
async_benchmark(llm, sampling_params, requests, streaming,
concurrency, iteration_writer.full_address))
async_benchmark(llm,
sampling_params,
requests,
streaming,
concurrency,
iteration_writer.full_address,
modality=modality))
logger.info(f"Benchmark done. Reporting results...")
if modality is not None:
# For multimodal models, we need to update the metadata with the correct input lengths
metadata = update_metadata_for_multimodal(metadata, statistics)
report_utility = ReportUtility(statistics, metadata, runtime_config,
logger, kwargs, streaming)
if report_json:

View File

@ -23,7 +23,8 @@ class LlmManager:
llm: LLM,
outbox: asyncio.Queue[PerfItemTuple],
streaming: bool,
concurrency: int = -1) -> None:
concurrency: int = -1,
modality: Optional[str] = None) -> None:
self.llm = llm
self._inbox: asyncio.Queue[Tuple[InferenceRequest,
SamplingParams]] = asyncio.Queue()
@ -38,6 +39,7 @@ class LlmManager:
concurrency) if concurrency > 0 else None
self.streaming = streaming
self.request_seen = asyncio.Event()
self.modality = modality
async def process_request(self, request: InferenceRequest,
sampling_params: SamplingParams):
@ -50,7 +52,7 @@ class LlmManager:
time_on_first_token = None
# Schedule the request in the LLM API (asynchronously)
output: RequestOutput = self.llm.generate_async(
request.input_ids,
request.input_ids if self.modality is None else request.prompt,
sampling_params=sampling_params,
streaming=self.streaming)
if self.streaming:
@ -70,7 +72,7 @@ class LlmManager:
start_timestamp=request_start_timestamp,
end_timestamp=response_end_timestamp,
request_id=response.request_id,
num_input_tokens=len(request.input_ids),
num_input_tokens=len(output.prompt_token_ids),
response_is_final=response.finished,
error=False,
tokens=tokens,
@ -201,6 +203,7 @@ async def async_benchmark(
streaming: bool,
concurrency: int = -1,
iteration_log_addr: str = None,
modality: Optional[str] = None,
) -> StatsKeeper:
outbox = asyncio.Queue()
statistics = StatsKeeper()
@ -208,7 +211,11 @@ async def async_benchmark(
try:
logger.info("Starting benchmarking async task.")
backend = LlmManager(llm, outbox, streaming, concurrency=concurrency)
backend = LlmManager(llm,
outbox,
streaming,
concurrency=concurrency,
modality=modality)
backend.run(iteration_addr=iteration_log_addr)
enqueue_task = asyncio.create_task(

View File

@ -116,6 +116,7 @@ class ModelConfig(BaseModel):
setting calculation.
"""
name: str
model_type: str
param_count: int
num_hidden_layers: int = Field(validation_alias=AliasChoices(
"num_hidden_layers",

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
from typing import Any, List, Optional, Union
from pydantic import (AliasChoices, BaseModel, Field, computed_field,
model_validator)
@ -17,7 +17,7 @@ class BenchmarkEnvironment(BaseModel):
class InferenceRequest(BaseModel):
task_id: int
prompt: Optional[str] = None
prompt: Optional[Union[str, Any]] = None
output_tokens: int
input_ids: Optional[List[int]] = Field(
alias=AliasChoices("input_ids", "logits"))

View File

@ -1,12 +1,36 @@
import json
from functools import partial
from typing import List, TextIO, Tuple
from typing import Any, Dict, List, TextIO, Tuple
from transformers import AutoTokenizer, PreTrainedTokenizer
from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata,
InferenceRequest)
from tensorrt_llm.bench.dataclasses.statistics import PercentileStats
from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
default_video_loader)
def prepare_multimodal_inputs(model_dir: str,
model_type: str,
modality: str,
prompts: List[str],
media: List[str],
image_data_format: str = "pt",
num_frames: int = 8) -> List[Dict[str, Any]]:
inputs = []
if modality == "image":
inputs = default_image_loader(prompts, media, image_data_format)
elif modality == "video":
inputs = default_video_loader(prompts, media, image_data_format,
num_frames)
else:
raise ValueError(f"Unsupported modality: {modality}")
inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
return inputs
def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer:
@ -36,6 +60,10 @@ def create_dataset_from_stream(
max_input_length: int = 0,
max_output_length: int = 0,
num_requests: int = 0,
model_dir: str = None,
model_type: str = None,
modality: str = None,
max_input_seq_len_for_multimodal: int = 4096,
) -> Tuple[DatasetMetadata, List[InferenceRequest]]:
"""Generate metadata and a list of requests to drive benchmarking.
@ -83,13 +111,30 @@ def create_dataset_from_stream(
# Each line should be a complete JSON dictionary with no indentation
# or newline characters.
data = json.loads(line)
if modality is not None:
# Multimodal data
assert modality in [
"image", "video"
], f"Modality must be one of ['image', 'video'] but got {modality}."
prompt = data.get("prompt") # cannot be None
media_paths = data.get("media_paths", None)
inputs = prepare_multimodal_inputs(
model_dir,
model_type,
modality,
prompts=[prompt],
media=media_paths) # list of dicts
logits = None # cannot tokenize multi-modal data, handled by preprocessor
prompt = inputs[0]
else:
logits = data.get("input_ids", data.get("logits", None))
prompt = data.get("prompt", None)
task_id = data["task_id"]
osl = data["output_tokens"]
# If the request comes in with logits, just use the provided.
# Otherwise we need to tokenize it.
logits = tokenize(prompt)["input_ids"] if logits is None else logits
task_id = data["task_id"]
osl = data["output_tokens"]
request = InferenceRequest(
task_id=task_id,
@ -97,8 +142,13 @@ def create_dataset_from_stream(
output_tokens=output_limiter(osl),
input_ids=logits,
)
all_isl.append(len(logits))
all_osl.append(osl)
if modality is not None:
cur_isl = max_input_seq_len_for_multimodal # NOTE: actual sequence length is unknown until the model is run
all_isl.append(cur_isl)
all_seq_len.append(cur_isl + osl)
else:
all_isl.append(len(logits))
all_seq_len.append(len(logits) + osl)
dataset.append(request)
@ -115,3 +165,31 @@ def create_dataset_from_stream(
)
return metadata, dataset
def update_metadata_for_multimodal(metadata, statistics) -> DatasetMetadata:
"""Update the metadata from benchmark statistics. Only used for multimodal models.
Args:
metadata (DatasetMetadata): The metadata to update.
statistics (StatsKeeper): The statistics to update the metadata with.
Returns:
DatasetMetadata: The updated metadata.
"""
all_isl = []
all_osl = []
all_seq_len = []
for request in statistics.requests.values():
all_isl.append(request.num_input_tokens)
all_osl.append(request.num_total_output_tokens)
all_seq_len.append(request.num_input_tokens +
request.num_total_output_tokens)
isl_stats = PercentileStats.from_iterable(all_isl)
osl_stats = PercentileStats.from_iterable(all_osl)
seq_len_stats = PercentileStats.from_iterable(all_seq_len)
metadata.isl_stats = isl_stats
metadata.osl_stats = osl_stats
metadata.seq_len_stats = seq_len_stats
return metadata

View File

@ -36,7 +36,10 @@ class CnnDailymail(Evaluator):
system_prompt: Optional[str] = None):
super().__init__(apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
self.data = datasets.load_dataset(dataset_path, "3.0.0", split="test")
self.data = datasets.load_dataset(dataset_path,
"3.0.0",
split="test",
trust_remote_code=True)
self.data = self.data.shuffle(random_seed)
if num_samples is None:
self.num_samples = self.data.num_rows

View File

@ -1,10 +1,15 @@
from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
from .registry import (ExtraProcessedInputs, InputProcessor,
create_input_processor, register_input_processor)
from .utils import load_image, load_video
from .utils import (INPUT_FORMATTER_MAP, default_image_loader,
default_video_loader, format_llava_next_input,
format_qwen2_vl_input, format_vila_input, load_image,
load_video)
__all__ = [
"PromptInputs", "prompt_inputs", "TextPrompt", "TokensPrompt",
"InputProcessor", "create_input_processor", "register_input_processor",
"ExtraProcessedInputs", "load_image", "load_video"
"ExtraProcessedInputs", "load_image", "load_video", "INPUT_FORMATTER_MAP",
"default_image_loader", "default_video_loader", "format_vila_input",
"format_llava_next_input", "format_qwen2_vl_input"
]

View File

@ -6,6 +6,7 @@ import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import AutoProcessor
def load_image(image: str,
@ -67,3 +68,158 @@ def load_video(
device=device) if format == "pt" else frames[index]
for index in indices if index in frames
]
"""
VLM input preparation.
"""
def format_vila_input(model_dir, inputs):
"""
This function formats the input for the VILA/NVILA VL model.
Arguments:
model_dir: The directory of the model to load any preprocessor.
inputs: The list of inputs to format.
Returns:
A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
"""
def add_media_token(prompt, multi_modal_data):
mm_tokens = ""
if "image" in multi_modal_data:
for _ in multi_modal_data["image"]:
mm_tokens += "<image>"
elif "video" in multi_modal_data:
for _ in multi_modal_data["video"]:
mm_tokens += "<vila/video>"
return mm_tokens + prompt
for input in inputs:
input["prompt"] = add_media_token(input["prompt"],
input["multi_modal_data"])
return inputs
def format_llava_next_input(model_dir, inputs):
"""
This function formats the input for the Llava Next VL model.
Arguments:
model_dir: The directory of the model to load any preprocessor.
inputs: The list of inputs to format.
Returns:
A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
"""
processor = AutoProcessor.from_pretrained(model_dir)
# Single-image inference chat template. For multi-image template,
# see https://huggingface.co/docs/transformers/en/model_doc/llava_next#multi-image-inference.
def apply_template(prompt, multimodal_data):
conversation = [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image"
},
],
},
]
return processor.apply_chat_template(
conversation,
add_generation_prompt=True,
)
for input in inputs:
input["prompt"] = apply_template(input["prompt"],
input["multi_modal_data"])
return inputs
def format_qwen2_vl_input(model_dir, inputs):
"""
This function formats the input for the Qwen2/Qwen2.5 VL model.
Arguments:
model_dir: The directory of the model to load any preprocessor.
inputs: The list of inputs to format.
Returns:
A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
"""
processor = AutoProcessor.from_pretrained(model_dir)
def apply_template(prompt, multimodal_data):
content = [{
"type": media_type
} for media_type, items in multimodal_data.items()
for _ in items] + [{
"type": "text",
"text": prompt
}]
conversation = [{"role": "user", "content": content}]
# print(conversation)
return processor.apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True,
)
for input in inputs:
input["prompt"] = apply_template(input["prompt"],
input["multi_modal_data"])
return inputs
def default_image_loader(prompts, images, image_data_format="pt"):
if len(images) > len(prompts) and len(prompts) == 1:
# 1 prompt + N media
images = [images]
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"image": [
load_image(i, format=image_data_format, device="cuda")
for i in image
] if isinstance(image, list) else
[load_image(image, format=image_data_format, device="cuda")]
}
} for prompt, image in zip(prompts, images)]
return inputs
def default_video_loader(prompts, videos, image_data_format="pt", num_frames=8):
if len(videos) > len(prompts) and len(prompts) == 1:
# 1 prompt + N media
videos = [videos]
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"video": [
load_video(
i, num_frames, format=image_data_format, device="cuda")
for i in video
] if isinstance(video, list) else [
load_video(
video, num_frames, format=image_data_format, device="cuda")
]
}
} for prompt, video in zip(prompts, videos)]
return inputs
INPUT_FORMATTER_MAP = {
"llava_llama": format_vila_input,
"llava_next": format_llava_next_input,
"qwen2_vl": format_qwen2_vl_input,
"qwen2_5_vl": format_qwen2_vl_input,
}

View File

@ -306,6 +306,7 @@ def load_calib_dataset(dataset_name_or_dir: str,
dataset = load_dataset(dataset_name_or_dir,
name=config_name,
split=split,
trust_remote_code=trust_remote_code,
**kwargs)
return dataset[key]

View File

@ -384,20 +384,26 @@ def get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
dataset = load_dataset(
"json",
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
split="train")
split="train",
trust_remote_code=True)
dataset = dataset["text"][:calib_size]
elif "scienceqa" in dataset_name_or_dir.lower(
) or "science_qa" in dataset_name_or_dir.lower():
if os.path.isdir(dataset_name_or_dir):
dataset = load_dataset(dataset_name_or_dir, split="train")
dataset = load_dataset(dataset_name_or_dir,
split="train",
trust_remote_code=True)
else:
dataset = load_dataset("derek-thomas/ScienceQA", split="train")
dataset = load_dataset("derek-thomas/ScienceQA",
split="train",
trust_remote_code=True)
dataset = dataset.select(range(calib_size))
elif "cnn_dailymail" in dataset_name_or_dir:
dataset = load_dataset(
dataset_name_or_dir,
name="3.0.0",
split="train",
trust_remote_code=True,
)
dataset = dataset["article"][:calib_size]
elif os.path.isdir(dataset_name_or_dir):
@ -405,7 +411,9 @@ def get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
"assuming the calibration data are in the train split and text column."
)
dataset = load_dataset(dataset_name_or_dir, split="train")
dataset = load_dataset(dataset_name_or_dir,
split="train",
trust_remote_code=True)
dataset = dataset["text"][:calib_size]
else:
raise NotImplementedError(
@ -993,22 +1001,29 @@ def get_nemo_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
dataset = load_dataset(
"json",
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
split="train")
split="train",
trust_remote_code=True)
text_column = "text"
elif "wikitext" in dataset_name_or_dir:
dataset = load_dataset(dataset_name_or_dir,
"wikitext-103-v1",
split="train")
split="train",
trust_remote_code=True)
text_column = "text"
elif "cnn_dailymail" in dataset_name_or_dir:
dataset = load_dataset(dataset_name_or_dir, name="3.0.0", split="train")
dataset = load_dataset(dataset_name_or_dir,
name="3.0.0",
split="train",
trust_remote_code=True)
text_column = "article"
elif os.path.isdir(dataset_name_or_dir):
logger.info(
f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
"assuming the calibration data are in the train split and text column."
)
dataset = load_dataset(dataset_name_or_dir, split="train")
dataset = load_dataset(dataset_name_or_dir,
split="train",
trust_remote_code=True)
text_column = "text"
else:
raise NotImplementedError(

View File

@ -366,8 +366,11 @@ def test_tokenizer_decode_incrementally(tokenizer_dir: str, threshold: float):
num_samples = 100
cnn_dailymail = datasets.load_dataset(cnn_dailymail_path,
name='3.0.0',
split='train')
alpaca_chinese = datasets.load_dataset(alpaca_chinese_path, split='train')
split='train',
trust_remote_code=True)
alpaca_chinese = datasets.load_dataset(alpaca_chinese_path,
split='train',
trust_remote_code=True)
dataset = cnn_dailymail['article'][:num_samples // 2] + alpaca_chinese[
'output_zh'][:num_samples // 2]