feat: adding multimodal (only image for now) support in trtllm-bench (#3490)

* feat: adding multimodal (only image for now) support in trtllm-bench

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* fix: add trust_remote_code=True in load_dataset() calls to maintain the v2.19.2 behavior

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* re-adding prompt_token_ids and using that for prompt_len

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* updating the datasets version in examples as well

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* api changes are not needed

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* moving datasets requirement and removing a missed api change

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* addressing review comments

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

* refactoring the quickstart example

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>

---------

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
rakib-hasan 2025-04-17 16:06:16 -07:00 committed by GitHub
parent 26ebd95302
commit ff3b741045
63 changed files with 703 additions and 258 deletions

View File

@ -1,12 +1,16 @@
import logging
import random
import re
import tempfile
from typing import Optional
import click
from datasets import load_dataset
from PIL import Image
from pydantic import BaseModel, model_validator
from utils.utils import dataset_dump, get_norm_dist_lengths, print_dataset
from utils.utils import (get_norm_dist_lengths, multimodal_dataset_dump,
print_multimodal_dataset, print_text_dataset,
text_dataset_dump)
def validate_output_len_dist(ctx, param, value):
@ -31,8 +35,10 @@ class DatasetConfig(BaseModel):
"""Split of the dataset. Typical values: train, validation, test. Setting to None will include all splits."""
split: Optional[str]
"""The dataset dictionary used for the input sentence."""
input_key: str
input_key: Optional[str] = None
"""The dataset dictionary key used for the prompt of the input sentence. Must not be set when prompt is set."""
image_key: Optional[str] = None
"""The dataset dictionary key used for the images."""
prompt_key: Optional[str] = None
"""The prompt sentence to be added to the input sentence. Must not be set when prompt_key is set."""
prompt: Optional[str] = None
@ -75,6 +81,20 @@ class DatasetConfig(BaseModel):
f"{req.keys()}")
return req[self.input_key]
def get_images(self, req):
"""Get the images from the given request."""
image_keys = [self.image_key
] + [f"{self.image_key}_{i}" for i in range(1, 8)]
assert any(key in req for key in image_keys), (
f"Dataset {self.name} does not have key '{self.image_key}'. "
"Please set --dataset-image-key to one of the available keys: "
f"{req.keys()}")
images = []
for key in image_keys:
if key in req and req[key] is not None:
images.append(req[key])
return images
def get_output(self, req):
"""Get the output sentence from the given request."""
if self.output_key is None:
@ -105,7 +125,8 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
dataset = iter(
load_dataset(*dataset_config.query,
split=dataset_config.split,
streaming=True))
streaming=True,
trust_remote_code=True))
except ValueError as e:
if "Config" in e:
e += "\n Please add the config name to the dataset config yaml."
@ -130,9 +151,12 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
required=True,
help=f"Split of the dataset to use.")
@click.option("--dataset-input-key",
required=True,
type=str,
help=f"The dataset dictionary key for input.")
@click.option("--dataset-image-key",
type=str,
default="image",
help=f"The dataset dictionary key for images.")
@click.option("--dataset-prompt-key",
type=str,
default=None,
@ -181,8 +205,39 @@ def dataset(root_args, **kwargs):
output_lens = []
task_ids = []
req_cnt = 0
modality = None
multimodal_texts = []
multimodal_image_paths = []
for req in load_dataset_from_hf(dataset_config):
# input
if any(key in req for key in ['image', 'image_1', 'video']):
# multimodal input
if 'video' in req and req['video'] is not None:
assert "Not supported yet"
assert kwargs['output_len_dist'] is not None, (
"Output length distribution must be set for multimodal requests."
)
modality = 'image'
text = dataset_config.get_prompt(req)
images = dataset_config.get_images(req)
image_paths = []
for image in images:
if image is not None:
if isinstance(image, str):
image_paths.append(image)
elif isinstance(image, Image.Image):
with tempfile.NamedTemporaryFile(
suffix=".jpg", delete=False) as tmp_file:
logging.debug(f"Saving image to {tmp_file.name}")
image = image.convert("RGB")
image.save(tmp_file, "JPEG")
filepath = tmp_file.name
image_paths.append(filepath)
else:
raise ValueError(f"Invalid image path: {image}")
multimodal_texts.append(text)
multimodal_image_paths.append(image_paths)
else:
# text input
prompt = dataset_config.get_prompt(
req) + ' ' + dataset_config.get_input(req)
logging.debug(f"Input sequence: {prompt}")
@ -195,7 +250,9 @@ def dataset(root_args, **kwargs):
# output if fetch from golden
if kwargs['output_len_dist'] is None:
output_lens.append(
len(root_args.tokenizer.encode(dataset_config.get_output(req))))
len(
root_args.tokenizer.encode(
dataset_config.get_output(req))))
# lora task id
task_id = root_args.task_id
@ -208,21 +265,44 @@ def dataset(root_args, **kwargs):
if kwargs['num_requests'] and req_cnt >= kwargs['num_requests']:
break
if kwargs['num_requests'] and len(input_ids) < kwargs['num_requests']:
if kwargs['num_requests'] and (len(input_ids) if modality is None else len(
multimodal_texts)) < kwargs['num_requests']:
logging.warning(
"Number of requests is smaller than the num-requests user set.")
f"Number of requests={len(input_ids) if modality is None else len(multimodal_texts)} is"
f" smaller than the num-requests user set={kwargs['num_requests']}."
)
# output if randomized
if kwargs['output_len_dist'] is not None:
osl_mean, osl_stdev = kwargs['output_len_dist']
output_lens = get_norm_dist_lengths(osl_mean, osl_stdev, len(input_ids),
output_lens = get_norm_dist_lengths(
osl_mean, osl_stdev,
len(input_ids) if modality is None else len(multimodal_texts),
root_args.random_seed)
logging.debug(f"Input lengths: {[len(i) for i in input_ids]}")
logging.debug(f"Output lengths: {output_lens}")
if modality is not None:
logging.debug(f"Modality: {modality}")
if modality is not None:
if not root_args.std_out:
dataset_dump(
multimodal_dataset_dump(
multimodal_texts, multimodal_image_paths, output_lens, task_ids,
{
"workload_type": "dataset",
"tokenizer": root_args.tokenizer.__class__.__name__,
"num_requests": len(task_ids),
"max_output_len": max(output_lens)
}, root_args.output)
else:
print_multimodal_dataset(
multimodal_texts,
multimodal_image_paths,
output_lens,
)
else:
if not root_args.std_out:
text_dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "dataset",
"tokenizer": root_args.tokenizer.__class__.__name__,
@ -231,7 +311,7 @@ def dataset(root_args, **kwargs):
"max_output_len": max(output_lens)
}, root_args.output)
else:
print_dataset(
print_text_dataset(
input_ids,
output_lens,
)

View File

@ -1,8 +1,9 @@
import random
import click
from utils.utils import (dataset_dump, gen_random_tokens, get_norm_dist_lengths,
get_unif_dist_lengths, print_dataset)
from utils.utils import (gen_random_tokens, get_norm_dist_lengths,
get_unif_dist_lengths, print_text_dataset,
text_dataset_dump)
@click.command()
@ -57,7 +58,7 @@ def token_norm_dist(root_args, **kwargs):
task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]
if not root_args.std_out:
dataset_dump(
text_dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "token-norm-dist",
"input_mean": kwargs['input_mean'],
@ -70,7 +71,7 @@ def token_norm_dist(root_args, **kwargs):
"max_output_len": max_output_len
}, root_args.output)
else:
print_dataset(
print_text_dataset(
input_ids,
output_lens,
)
@ -127,7 +128,7 @@ def token_unif_dist(root_args, **kwargs):
task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]
if not root_args.std_out:
dataset_dump(
text_dataset_dump(
input_lens, input_ids, output_lens, task_ids, {
"workload_type": "token-unif-dist",
"input_min": kwargs['input_min'],
@ -140,7 +141,7 @@ def token_unif_dist(root_args, **kwargs):
"max_output_len": max_output_len
}, root_args.output)
else:
print_dataset(
print_text_dataset(
input_ids,
output_lens,
)

View File

@ -2,22 +2,29 @@ import json
import math
import os
import random
from typing import List
from typing import List, Union
import numpy as np
from pydantic import BaseModel
class Sample(BaseModel):
class TextSample(BaseModel):
input_len: int
input_ids: List[int]
output_len: int
task_id: int
class MultimodalSample(BaseModel):
task_id: int
prompt: str
media_paths: List[str]
output_len: int
class Workload(BaseModel):
metadata: dict
samples: List[Sample] = []
samples: List[Union[TextSample, MultimodalSample]] = []
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
@ -33,12 +40,12 @@ class Workload(BaseModel):
self.metadata.setdefault('workload_name', workload_name)
def dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
def text_dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
output_file):
samples = []
for i in range(len(input_ids)):
samples.append(
Sample(input_len=input_lens[i],
TextSample(input_len=input_lens[i],
input_ids=input_ids[i],
output_len=output_lens[i],
task_id=task_ids[i]))
@ -48,7 +55,22 @@ def dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
json.dump(workload.model_dump(), f)
def print_dataset(input_ids, output_lens):
def multimodal_dataset_dump(multimodal_texts, multimodal_image_paths,
output_lens, task_ids, metadata, output_file):
samples = []
for i in range(len(multimodal_texts)):
samples.append(
MultimodalSample(task_id=task_ids[i],
prompt=multimodal_texts[i],
media_paths=multimodal_image_paths[i],
output_len=output_lens[i]))
workload = Workload(metadata=metadata, samples=samples)
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, 'w') as f:
json.dump(workload.model_dump(), f)
def print_text_dataset(input_ids, output_lens):
for i, input_tokens in enumerate(input_ids):
d = {
"task_id": i,
@ -58,6 +80,19 @@ def print_dataset(input_ids, output_lens):
print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))
def print_multimodal_dataset(multimodal_texts, multimodal_image_paths,
output_lens):
for i, (text, image_paths) in enumerate(
zip(multimodal_texts, multimodal_image_paths)):
d = {
"task_id": i,
"prompt": text,
"media_paths": image_paths,
"output_tokens": output_lens[i]
}
print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))
def get_list_of_delays(delay_dist, mean_time_bet_reqs, num_reqs, random_seed):
if delay_dist == "constant":
delays = [mean_time_bet_reqs] * num_reqs

View File

@ -475,6 +475,115 @@ Total Latency (ms): 18563.6825
```
#### Running multi-modal models in the PyTorch Workflow
To benchmark multi-modal models with the PyTorch workflow, you can follow a similar approach to the one above.
First, prepare the dataset:
```
python ./benchmarks/cpp/prepare_dataset.py \
--tokenizer Qwen/Qwen2-VL-2B-Instruct \
--stdout \
dataset \
--dataset-name lmms-lab/MMMU \
--dataset-split test \
--dataset-image-key image \
--dataset-prompt-key question \
--num-requests 10 \
--output-len-dist 128,5 > mm_data.jsonl
```
This downloads the media files to the `/tmp` directory and prepares the dataset with their paths. Note that the `prompt` fields contain raw text rather than tokenized IDs, because the prompt and the media (image/video) are processed together by the model's multimodal preprocessor.
Sample dataset for multimodal:
```
{"task_id":0,"prompt":"Brahma Industries sells vinyl replacement windows to home improvement retailers nationwide. The national sales manager believes that if they invest an additional $25,000 in advertising, they would increase sales volume by 10,000 units. <image 1> What is the total contribution margin?","media_paths":["/tmp/tmp9so41y3r.jpg"],"output_tokens":126}
{"task_id":1,"prompt":"Let us compute for the missing amounts under work in process inventory, what is the cost of goods manufactured? <image 1>","media_paths":["/tmp/tmpowsrb_f4.jpg"],"output_tokens":119}
{"task_id":2,"prompt":"Tsuji is reviewing the price of a 3-month Japanese yen/U.S. dollar currency futures contract, using the currency and interest rate data shown below. Because the 3-month Japanese interest rate has just increased to .50%, Itsuji recognizes that an arbitrage opportunity exists nd decides to borrow $1 million U.S. dollars to purchase Japanese yen. Calculate the yen arbitrage profit from Itsuji's strategy, using the following data: <image 1> ","media_paths":["/tmp/tmpxhdvasex.jpg"],"output_tokens":126}
...
```
Run the benchmark:
```
trtllm-bench --model Qwen/Qwen2-VL-2B-Instruct \
throughput \
--dataset mm_data.jsonl \
--backend pytorch \
--num_requests 10 \
--max_batch_size 4 \
--modality image
```
Sample output:
```
===========================================================
= REQUEST DETAILS
===========================================================
Number of requests: 10
Number of concurrent requests: 5.3019
Average Input Length (tokens): 411.6000
Average Output Length (tokens): 128.7000
===========================================================
= WORLD + RUNTIME INFORMATION
===========================================================
TP Size: 1
PP Size: 1
EP Size: None
Max Runtime Batch Size: 4
Max Runtime Tokens: 12288
Scheduling Policy: GUARANTEED_NO_EVICT
KV Memory Percentage: 90.00%
Issue Rate (req/sec): 1.4117E+17
===========================================================
= PERFORMANCE OVERVIEW
===========================================================
Request Throughput (req/sec): 1.4439
Total Output Throughput (tokens/sec): 185.8351
Per User Output Throughput (tokens/sec/user): 38.1959
Per GPU Output Throughput (tokens/sec/gpu): 185.8351
Total Token Throughput (tokens/sec): 780.1607
Total Latency (ms): 6925.4963
Average request latency (ms): 3671.8441
-- Request Latency Breakdown (ms) -----------------------
[Latency] P50 : 3936.3022
[Latency] P90 : 5514.4701
[Latency] P95 : 5514.4701
[Latency] P99 : 5514.4701
[Latency] MINIMUM: 2397.1047
[Latency] MAXIMUM: 5514.4701
[Latency] AVERAGE: 3671.8441
===========================================================
= DATASET DETAILS
===========================================================
Dataset Path: /workspaces/tensorrt_llm/mm_data.jsonl
Number of Sequences: 10
-- Percentiles statistics ---------------------------------
Input Output Seq. Length
-----------------------------------------------------------
MIN: 167.0000 119.0000 300.0000
MAX: 1059.0000 137.0000 1178.0000
AVG: 411.6000 128.7000 540.3000
P50: 299.0000 128.0000 427.0000
P90: 1059.0000 137.0000 1178.0000
P95: 1059.0000 137.0000 1178.0000
P99: 1059.0000 137.0000 1178.0000
===========================================================
```
**Notes and Limitations**:
- Only image datasets are supported for now.
- `--output-len-dist` is a required argument for multimodal datasets.
- The tokenizer is unused during the prepare step, but it is still a required argument.
- Since images are converted to tokens only when the model runs, `trtllm-bench` uses a large default value for the maximum input sequence length when setting up the execution settings.
You can modify this behavior by passing a value that suits your use case with the `--max_input_len` flag, as shown below.
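For example, a minimal sketch of such an override (the `--max_input_len` value of 8192 is illustrative, not a recommendation; pick a budget that fits your model and data):
```
trtllm-bench --model Qwen/Qwen2-VL-2B-Instruct \
throughput \
--dataset mm_data.jsonl \
--backend pytorch \
--num_requests 10 \
--max_batch_size 4 \
--modality image \
--max_input_len 8192
```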
#### Quantization in the PyTorch Flow
To run a quantized benchmark with `trtllm-bench` using the PyTorch flow, you will need to use a pre-quantized

View File

@ -17,7 +17,7 @@ def prepare_text_inputs(model_name, batch_size=8):
f"HF_DATASETS_OFFLINE inside function: {datasets.config.HF_DATASETS_OFFLINE}"
)
if model_name == "BertForQuestionAnswering" or model_name == "RobertaForQuestionAnswering":
squad_dataset = load_dataset("squad_v2")
squad_dataset = load_dataset("squad_v2", trust_remote_code=True)
val_dataset = squad_dataset["validation"]
samples = val_dataset.select(range(batch_size))
@ -27,7 +27,8 @@ def prepare_text_inputs(model_name, batch_size=8):
}
return qa_real_test_inputs
elif model_name == "BertForSequenceClassification" or model_name == "RobertaForSequenceClassification":
yelp_dataset = load_dataset("fancyzhx/yelp_polarity")
yelp_dataset = load_dataset("fancyzhx/yelp_polarity",
trust_remote_code=True)
val_dataset = yelp_dataset["test"]
samples = val_dataset.select(range(batch_size))

View File

@ -1,5 +1,5 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets==2.14.6
datasets==3.1.0
evaluate
rouge_score

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
sentencepiece>=0.1.99
evaluate

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
SentencePiece~=0.1.99
evaluate

View File

@ -11,4 +11,4 @@ sentencepiece>=0.1.99
h5py~=3.12.1
rouge_score
nltk
datasets==2.14.6
datasets==3.1.0

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
protobuf
rouge_score

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
SentencePiece>=0.1.99

View File

@ -1,7 +1,7 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
transformers>=4.43.0
datasets==2.14.6
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -312,7 +312,8 @@ def main(args):
tokenizer.pad_token = tokenizer.eos_token
dataset_openweb = load_dataset("stas/openwebtext-10k",
cache_dir=args.dataset_path)
cache_dir=args.dataset_path,
trust_remote_code=True)
long_texts = get_long_texts(dataset_openweb) # generator
# get datapoints

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
sentencepiece>=0.1.99
evaluate

View File

@ -1,7 +1,7 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
transformers>=4.39.0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
sentencepiece

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
sentencepiece>=0.1.99
evaluate

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.15.0
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
protobuf
rouge_score

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
protobuf
rouge_score

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
protobuf
rouge_score

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
tiktoken==0.6.0

View File

@ -1,5 +1,5 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.6
datasets==3.1.0
evaluate
rouge_score

View File

@ -1,3 +1,3 @@
datasets~=2.14.6
datasets==3.1.0
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -1,7 +1,7 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
transformers>=4.31.0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -1,5 +1,5 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score

View File

@ -1,5 +1,5 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
evaluate

View File

@ -1,7 +1,7 @@
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets==2.14.6
datasets==3.1.0
evaluate
rouge_score
sentencepiece==0.2.0

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets==2.14.5
datasets==3.1.0
rouge_score
sentencepiece>=0.1.99
evaluate

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
SentencePiece>=0.1.99

View File

@ -1,5 +1,5 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score

View File

@ -1,5 +1,5 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.16.1
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -1,6 +1,6 @@
-c ../../../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets==2.14.6
datasets==3.1.0
evaluate
rouge_score
sentencepiece>=0.1.99

View File

@ -108,6 +108,7 @@ def load_dataset(args) -> datasets.Dataset:
'timeout': aiohttp.ClientTimeout(total=3600)
}
},
trust_remote_code=True,
)
return dataset

View File

@ -2,6 +2,6 @@
tensorrt_llm>=0.0.0.dev0
nemo-toolkit[all]==2.0.0rc1
megatron-core @ git+https://github.com/NVIDIA/Megatron-LM@core_r0.8.0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score

View File

@ -20,7 +20,7 @@ def create_trtllm_magpie_calibration_dataset(output_dir: str,
calib_size: int = 512) -> None:
from datasets import load_dataset
dataset = load_dataset(DATASET, split="train")
dataset = load_dataset(DATASET, split="train", trust_remote_code=True)
def transform(conversation):
value = '\n'.join(turn['value']

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
einops~=0.7.0

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
sentencepiece~=0.1.99
evaluate

View File

@ -1,11 +1,12 @@
import argparse
import json
import os
from typing import Any, Dict, List
from quickstart_advanced import add_llm_args, setup_llm
from transformers import AutoProcessor
from tensorrt_llm.inputs import load_image, load_video
from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
default_video_loader)
example_images = [
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
@ -27,92 +28,32 @@ example_video_prompts = [
]
def prepare_vila(args, inputs):
def prepare_multimodal_inputs(model_dir: str,
model_type: str,
modality: str,
prompts: List[str],
media: List[str],
image_data_format: str = "pt",
num_frames: int = 8) -> List[Dict[str, Any]]:
def add_media_token(prompt, multi_modal_data):
mm_tokens = ""
if "image" in multi_modal_data:
for _ in multi_modal_data["image"]:
mm_tokens += "<image>"
elif "video" in multi_modal_data:
for _ in multi_modal_data["video"]:
mm_tokens += "<vila/video>"
return mm_tokens + prompt
inputs = []
if modality == "image":
inputs = default_image_loader(prompts, media, image_data_format)
elif modality == "video":
inputs = default_video_loader(prompts, media, image_data_format,
num_frames)
else:
raise ValueError(f"Unsupported modality: {modality}")
inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
for input in inputs:
input["prompt"] = add_media_token(input["prompt"],
input["multi_modal_data"])
return inputs
def prepare_llava_next(args, inputs):
processor = AutoProcessor.from_pretrained(args.model_dir)
# Single-image inference chat template. For multi-image template,
# see https://huggingface.co/docs/transformers/en/model_doc/llava_next#multi-image-inference.
def apply_template(prompt, multimodal_data):
conversation = [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image"
},
],
},
]
return processor.apply_chat_template(
conversation,
add_generation_prompt=True,
)
for input in inputs:
input["prompt"] = apply_template(input["prompt"],
input["multi_modal_data"])
return inputs
def prepare_qwen2_vl(args, inputs):
processor = AutoProcessor.from_pretrained(args.model_dir)
def apply_template(prompt, multimodal_data):
content = [{
"type": media_type
} for media_type, items in multimodal_data.items()
for _ in items] + [{
"type": "text",
"text": prompt
}]
conversation = [{"role": "user", "content": content}]
return processor.apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True,
)
for input in inputs:
input["prompt"] = apply_template(input["prompt"],
input["multi_modal_data"])
return inputs
MODEL_TYPE_MAP = {
"llava_llama": prepare_vila,
"llava_next": prepare_llava_next,
"qwen2_vl": prepare_qwen2_vl,
"qwen2_5_vl": prepare_qwen2_vl,
}
def add_multimodal_args(parser):
parser.add_argument("--model_type",
type=str,
choices=MODEL_TYPE_MAP.keys(),
choices=INPUT_FORMATTER_MAP.keys(),
help="Model type.")
parser.add_argument("--modality",
type=str,
@ -150,50 +91,16 @@ def main():
llm, sampling_params = setup_llm(args)
image_format = "pt" # ["pt", "pil"]
if args.modality == "image":
prompts = args.prompt if args.prompt else example_image_prompts
images = args.media if args.media else example_images
if len(images) > len(prompts) and len(prompts) == 1:
# 1 prompt + N media
images = [images]
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"image": [
load_image(i, format=image_format, device="cuda")
for i in image
] if isinstance(image, list) else
[load_image(image, format=image_format, device="cuda")]
}
} for prompt, image in zip(prompts, images)]
elif args.modality == "video":
prompts = args.prompt if args.prompt else example_video_prompts
videos = args.media if args.media else example_videos
if len(videos) > len(prompts) and len(prompts) == 1:
# 1 prompt + N media
videos = [videos]
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"video": [
load_video(
i, args.num_frames, format=image_format, device="cuda")
for i in video
] if isinstance(video, list) else [
load_video(video,
args.num_frames,
format=image_format,
device="cuda")
]
}
} for prompt, video in zip(prompts, videos)]
if args.model_type is not None:
model_type = args.model_type
else:
raise ValueError(f"Unsupported modality: {args.modality}")
model_type = json.load(
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
assert model_type in INPUT_FORMATTER_MAP, f"Unsupported model_type: {model_type}"
model_type = json.load(open(os.path.join(llm._hf_model_dir,
'config.json')))['model_type']
assert model_type in MODEL_TYPE_MAP, f"Unsupported model_type: {model_type}"
inputs = MODEL_TYPE_MAP[model_type](args, inputs)
inputs = prepare_multimodal_inputs(args.model_dir, model_type,
args.modality, args.prompt, args.media,
image_format, args.num_frames)
outputs = llm.generate(inputs, sampling_params)

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets>=2.14.4
datasets==3.1.0
nemo-toolkit[all]==2.0.0rc1
rouge_score
transformers_stream_generator==0.0.4

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.16.0
datasets==3.1.0
evaluate
rouge_score
transformers>=4.40.1

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.dev0
datasets~=2.16.0
datasets==3.1.0
evaluate
rouge_score
transformers>=4.45.0

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.16.0
datasets==3.1.0
evaluate
rouge_score
transformers-stream-generator

View File

@ -5,7 +5,7 @@ flax>=0.8.2
jax~=0.4.23
orbax-checkpoint==0.5.7
transformers>=4.40.0
datasets~=2.14.5
datasets==3.1.0
evaluate
rouge_score
sentencepiece

View File

@ -1,6 +1,6 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
datasets~=2.14.5
datasets==3.1.0
rouge_score
sentencepiece>=0.1.99
evaluate

View File

@ -110,7 +110,8 @@ def main(args):
dataset = load_dataset(dataset_name,
dataset_revision,
cache_dir=args.dataset_cache_dir,
split=dataset_split)
split=dataset_split,
trust_remote_code=True)
dataset = dataset.shuffle(args.random_seed)
max_batch_size = args.batch_size

View File

@ -1,7 +1,7 @@
-c ../constraints.txt
tensorrt_llm>=0.0.0.dev0
tiktoken
datasets
datasets==3.1.0
kaldialign
openai-whisper
librosa

View File

@ -564,7 +564,8 @@ if __name__ == '__main__':
normalizer = EnglishTextNormalizer()
dataset = load_dataset(args.dataset,
args.dataset_name,
split=args.dataset_split)
split=args.dataset_split,
trust_remote_code=True)
if args.enable_warmup:
results, total_duration = decode_dataset(
model,

View File

@ -1,5 +1,4 @@
-r requirements.txt
datasets==2.19.2
einops
graphviz
mypy

View File

@ -31,6 +31,8 @@ pydantic>=2.9.1
pillow==10.3.0
wheel<=0.45.1
optimum
# evaluate needs datasets>=2.0.0, which can pull in datasets>3.1.0; versions above 3.1.0 are not stable: https://github.com/huggingface/datasets/issues/7467
datasets==3.1.0
evaluate
mpmath>=1.3.0
click

View File

@ -437,10 +437,7 @@ class Qwen2VLModelBase(PreTrainedModel):
inputs_embeds=input_embeds,
return_context_logits=return_context_logits,
mrope_config=mrope_config)
logger.debug(
f"output_ids: {(output_prob if output_prob.dim() == 2 else output_prob.unsqueeze(0)).argmax(dim=1).tolist()}"
)
logger.info(f'output shape: {output_prob.shape}')
logger.debug(f'output shape: {output_prob.shape}')
return output_prob

View File

@ -10,6 +10,7 @@ from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup,
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
# isort: off
from tensorrt_llm.bench.benchmark.utils.general import (
@ -21,7 +22,8 @@ from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
initialize_tokenizer)
initialize_tokenizer,
update_metadata_for_multimodal)
from tensorrt_llm.llmapi import LLM, CapacitySchedulerPolicy
from tensorrt_llm.logger import logger
from tensorrt_llm.sampling_params import SamplingParams
@ -92,6 +94,20 @@ from tensorrt_llm.sampling_params import SamplingParams
required=False,
help="Pass in a dataset file for parsing instead of stdin.",
)
@optgroup.option(
"--modality",
type=click.Choice(["image", "video"]),
default=None,
help="Modality of the multimodal requests.",
)
@optgroup.option(
"--max_input_len",
type=int,
default=4096,
help=
"Maximum input sequence length to use for multimodal models. This is used only when --modality "
"is specified since the actual number of vision tokens is unknown before the model is run.",
)
@optgroup.option(
"--num_requests",
type=int,
@ -194,6 +210,9 @@ def throughput_command(
engine_dir: Path = params.pop("engine_dir")
concurrency: int = params.pop("concurrency")
backend: str = params.get("backend")
modality: str = params.pop("modality")
max_input_len: int = params.pop("max_input_len")
model_type = get_model_config(model, checkpoint_path).model_type
# Reporting options
report_json: Path = params.pop("report_json")
@ -209,14 +228,23 @@ def throughput_command(
# Dataset Loading and Preparation
with open(dataset_path, "r") as dataset:
metadata, requests = create_dataset_from_stream(
tokenizer, dataset, num_requests=num_requests)
tokenizer,
dataset,
num_requests=num_requests,
model_dir=checkpoint_path,
model_type=model_type,
modality=modality,
max_input_seq_len_for_multimodal=max_input_len)
metadata.dataset_path = dataset_path
params["target_input_len"] = params.get(
"target_input_len") or metadata.avg_isl
params["target_output_len"] = params.get(
"target_output_len") or metadata.avg_osl
if modality is None:
# Log dataset info
# NOTE: This table is only accurate for non-multimodal models.
# The accurate table for multimodal models will be logged after the benchmark is done.
logger.info(metadata.get_summary_for_print())
# Engine configuration parsing
@ -294,8 +322,12 @@ def throughput_command(
warmup_dataset = generate_warmup_dataset(requests, warmup)
logger.info("Running warmup.")
asyncio.run(
async_benchmark(llm, sampling_params, warmup_dataset, False,
concurrency))
async_benchmark(llm,
sampling_params,
warmup_dataset,
False,
concurrency,
modality=modality))
# WAR: IterationResult is a singleton tied to the executor.
# Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
# we must reset it to ensure it attaches to the correct event loop.
@ -304,10 +336,19 @@ def throughput_command(
with iteration_writer.capture():
statistics = asyncio.run(
async_benchmark(llm, sampling_params, requests, streaming,
concurrency, iteration_writer.full_address))
async_benchmark(llm,
sampling_params,
requests,
streaming,
concurrency,
iteration_writer.full_address,
modality=modality))
logger.info(f"Benchmark done. Reporting results...")
if modality is not None:
# For multimodal models, we need to update the metadata with the correct input lengths
metadata = update_metadata_for_multimodal(metadata, statistics)
report_utility = ReportUtility(statistics, metadata, runtime_config,
logger, kwargs, streaming)
if report_json:

View File

@ -23,7 +23,8 @@ class LlmManager:
llm: LLM,
outbox: asyncio.Queue[PerfItemTuple],
streaming: bool,
concurrency: int = -1) -> None:
concurrency: int = -1,
modality: Optional[str] = None) -> None:
self.llm = llm
self._inbox: asyncio.Queue[Tuple[InferenceRequest,
SamplingParams]] = asyncio.Queue()
@ -38,6 +39,7 @@ class LlmManager:
concurrency) if concurrency > 0 else None
self.streaming = streaming
self.request_seen = asyncio.Event()
self.modality = modality
async def process_request(self, request: InferenceRequest,
sampling_params: SamplingParams):
@ -50,7 +52,7 @@ class LlmManager:
time_on_first_token = None
# Schedule the request in the LLM API (asynchronously)
output: RequestOutput = self.llm.generate_async(
request.input_ids,
request.input_ids if self.modality is None else request.prompt,
sampling_params=sampling_params,
streaming=self.streaming)
if self.streaming:
@ -70,7 +72,7 @@ class LlmManager:
start_timestamp=request_start_timestamp,
end_timestamp=response_end_timestamp,
request_id=response.request_id,
num_input_tokens=len(request.input_ids),
num_input_tokens=len(output.prompt_token_ids),
response_is_final=response.finished,
error=False,
tokens=tokens,
@ -201,6 +203,7 @@ async def async_benchmark(
streaming: bool,
concurrency: int = -1,
iteration_log_addr: str = None,
modality: Optional[str] = None,
) -> StatsKeeper:
outbox = asyncio.Queue()
statistics = StatsKeeper()
@ -208,7 +211,11 @@ async def async_benchmark(
try:
logger.info("Starting benchmarking async task.")
backend = LlmManager(llm, outbox, streaming, concurrency=concurrency)
backend = LlmManager(llm,
outbox,
streaming,
concurrency=concurrency,
modality=modality)
backend.run(iteration_addr=iteration_log_addr)
enqueue_task = asyncio.create_task(

View File

@ -116,6 +116,7 @@ class ModelConfig(BaseModel):
setting calculation.
"""
name: str
model_type: str
param_count: int
num_hidden_layers: int = Field(validation_alias=AliasChoices(
"num_hidden_layers",

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from pathlib import Path
from typing import List, Optional
from typing import Any, List, Optional, Union
from pydantic import (AliasChoices, BaseModel, Field, computed_field,
model_validator)
@ -17,7 +17,7 @@ class BenchmarkEnvironment(BaseModel):
class InferenceRequest(BaseModel):
task_id: int
prompt: Optional[str] = None
prompt: Optional[Union[str, Any]] = None
output_tokens: int
input_ids: Optional[List[int]] = Field(
alias=AliasChoices("input_ids", "logits"))

View File

@ -1,12 +1,36 @@
import json
from functools import partial
from typing import List, TextIO, Tuple
from typing import Any, Dict, List, TextIO, Tuple
from transformers import AutoTokenizer, PreTrainedTokenizer
from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata,
InferenceRequest)
from tensorrt_llm.bench.dataclasses.statistics import PercentileStats
from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
default_video_loader)
def prepare_multimodal_inputs(model_dir: str,
model_type: str,
modality: str,
prompts: List[str],
media: List[str],
image_data_format: str = "pt",
num_frames: int = 8) -> List[Dict[str, Any]]:
inputs = []
if modality == "image":
inputs = default_image_loader(prompts, media, image_data_format)
elif modality == "video":
inputs = default_video_loader(prompts, media, image_data_format,
num_frames)
else:
raise ValueError(f"Unsupported modality: {modality}")
inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
return inputs
def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer:
@ -36,6 +60,10 @@ def create_dataset_from_stream(
max_input_length: int = 0,
max_output_length: int = 0,
num_requests: int = 0,
model_dir: str = None,
model_type: str = None,
modality: str = None,
max_input_seq_len_for_multimodal: int = 4096,
) -> Tuple[DatasetMetadata, List[InferenceRequest]]:
"""Generate metadata and a list of requests to drive benchmarking.
@ -83,13 +111,30 @@ def create_dataset_from_stream(
# Each line should be a complete JSON dictionary with no indentation
# or newline characters.
data = json.loads(line)
if modality is not None:
# Multimodal data
assert modality in [
"image", "video"
], f"Modality must be one of ['image', 'video'] but got {modality}."
prompt = data.get("prompt") # cannot be None
media_paths = data.get("media_paths", None)
inputs = prepare_multimodal_inputs(
model_dir,
model_type,
modality,
prompts=[prompt],
media=media_paths) # list of dicts
logits = None # cannot tokenize multi-modal data, handled by preprocessor
prompt = inputs[0]
else:
logits = data.get("input_ids", data.get("logits", None))
prompt = data.get("prompt", None)
task_id = data["task_id"]
osl = data["output_tokens"]
# If the request comes in with logits, just use the provided.
# Otherwise we need to tokenize it.
logits = tokenize(prompt)["input_ids"] if logits is None else logits
task_id = data["task_id"]
osl = data["output_tokens"]
request = InferenceRequest(
task_id=task_id,
@ -97,8 +142,13 @@ def create_dataset_from_stream(
output_tokens=output_limiter(osl),
input_ids=logits,
)
all_isl.append(len(logits))
all_osl.append(osl)
if modality is not None:
cur_isl = max_input_seq_len_for_multimodal # NOTE: actual sequence length is unknown until the model is run
all_isl.append(cur_isl)
all_seq_len.append(cur_isl + osl)
else:
all_isl.append(len(logits))
all_seq_len.append(len(logits) + osl)
dataset.append(request)
@ -115,3 +165,31 @@ def create_dataset_from_stream(
)
return metadata, dataset
def update_metadata_for_multimodal(metadata, statistics) -> DatasetMetadata:
"""Update the metadata from benchmark statistics. Only used for multimodal models.
Args:
metadata (DatasetMetadata): The metadata to update.
statistics (StatsKeeper): The statistics to update the metadata with.
Returns:
DatasetMetadata: The updated metadata.
"""
all_isl = []
all_osl = []
all_seq_len = []
for request in statistics.requests.values():
all_isl.append(request.num_input_tokens)
all_osl.append(request.num_total_output_tokens)
all_seq_len.append(request.num_input_tokens +
request.num_total_output_tokens)
isl_stats = PercentileStats.from_iterable(all_isl)
osl_stats = PercentileStats.from_iterable(all_osl)
seq_len_stats = PercentileStats.from_iterable(all_seq_len)
metadata.isl_stats = isl_stats
metadata.osl_stats = osl_stats
metadata.seq_len_stats = seq_len_stats
return metadata

View File

@ -36,7 +36,10 @@ class CnnDailymail(Evaluator):
system_prompt: Optional[str] = None):
super().__init__(apply_chat_template=apply_chat_template,
system_prompt=system_prompt)
self.data = datasets.load_dataset(dataset_path, "3.0.0", split="test")
self.data = datasets.load_dataset(dataset_path,
"3.0.0",
split="test",
trust_remote_code=True)
self.data = self.data.shuffle(random_seed)
if num_samples is None:
self.num_samples = self.data.num_rows

View File

@ -1,10 +1,15 @@
from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
from .registry import (ExtraProcessedInputs, InputProcessor,
create_input_processor, register_input_processor)
from .utils import load_image, load_video
from .utils import (INPUT_FORMATTER_MAP, default_image_loader,
default_video_loader, format_llava_next_input,
format_qwen2_vl_input, format_vila_input, load_image,
load_video)
__all__ = [
"PromptInputs", "prompt_inputs", "TextPrompt", "TokensPrompt",
"InputProcessor", "create_input_processor", "register_input_processor",
"ExtraProcessedInputs", "load_image", "load_video"
"ExtraProcessedInputs", "load_image", "load_video", "INPUT_FORMATTER_MAP",
"default_image_loader", "default_video_loader", "format_vila_input",
"format_llava_next_input", "format_qwen2_vl_input"
]

View File

@ -6,6 +6,7 @@ import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import AutoProcessor
def load_image(image: str,
@ -67,3 +68,158 @@ def load_video(
device=device) if format == "pt" else frames[index]
for index in indices if index in frames
]
"""
VLM input preparation.
"""
def format_vila_input(model_dir, inputs):
"""
This function formats the input for the VILA/NVILA VL model.
Arguments:
model_dir: The directory of the model to load any preprocessor.
inputs: The list of inputs to format.
Returns:
A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
"""
def add_media_token(prompt, multi_modal_data):
mm_tokens = ""
if "image" in multi_modal_data:
for _ in multi_modal_data["image"]:
mm_tokens += "<image>"
elif "video" in multi_modal_data:
for _ in multi_modal_data["video"]:
mm_tokens += "<vila/video>"
return mm_tokens + prompt
for input in inputs:
input["prompt"] = add_media_token(input["prompt"],
input["multi_modal_data"])
return inputs
def format_llava_next_input(model_dir, inputs):
"""
This function formats the input for the Llava Next VL model.
Arguments:
model_dir: The directory of the model to load any preprocessor.
inputs: The list of inputs to format.
Returns:
A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
"""
processor = AutoProcessor.from_pretrained(model_dir)
# Single-image inference chat template. For multi-image template,
# see https://huggingface.co/docs/transformers/en/model_doc/llava_next#multi-image-inference.
def apply_template(prompt, multimodal_data):
conversation = [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image"
},
],
},
]
return processor.apply_chat_template(
conversation,
add_generation_prompt=True,
)
for input in inputs:
input["prompt"] = apply_template(input["prompt"],
input["multi_modal_data"])
return inputs
def format_qwen2_vl_input(model_dir, inputs):
"""
This function formats the input for the Qwen2/Qwen2.5 VL model.
Arguments:
model_dir: The directory of the model to load any preprocessor.
inputs: The list of inputs to format.
Returns:
A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
"""
processor = AutoProcessor.from_pretrained(model_dir)
def apply_template(prompt, multimodal_data):
content = [{
"type": media_type
} for media_type, items in multimodal_data.items()
for _ in items] + [{
"type": "text",
"text": prompt
}]
conversation = [{"role": "user", "content": content}]
return processor.apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True,
)
for input in inputs:
input["prompt"] = apply_template(input["prompt"],
input["multi_modal_data"])
return inputs
def default_image_loader(prompts, images, image_data_format="pt"):
if len(images) > len(prompts) and len(prompts) == 1:
# 1 prompt + N media
images = [images]
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"image": [
load_image(i, format=image_data_format, device="cuda")
for i in image
] if isinstance(image, list) else
[load_image(image, format=image_data_format, device="cuda")]
}
} for prompt, image in zip(prompts, images)]
return inputs
def default_video_loader(prompts, videos, image_data_format="pt", num_frames=8):
if len(videos) > len(prompts) and len(prompts) == 1:
# 1 prompt + N media
videos = [videos]
inputs = [{
"prompt": prompt,
"multi_modal_data": {
"video": [
load_video(
i, num_frames, format=image_data_format, device="cuda")
for i in video
] if isinstance(video, list) else [
load_video(
video, num_frames, format=image_data_format, device="cuda")
]
}
} for prompt, video in zip(prompts, videos)]
return inputs
INPUT_FORMATTER_MAP = {
"llava_llama": format_vila_input,
"llava_next": format_llava_next_input,
"qwen2_vl": format_qwen2_vl_input,
"qwen2_5_vl": format_qwen2_vl_input,
}

View File

@ -306,6 +306,7 @@ def load_calib_dataset(dataset_name_or_dir: str,
dataset = load_dataset(dataset_name_or_dir,
name=config_name,
split=split,
trust_remote_code=trust_remote_code,
**kwargs)
return dataset[key]

View File

@ -384,20 +384,26 @@ def get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
dataset = load_dataset(
"json",
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
split="train")
split="train",
trust_remote_code=True)
dataset = dataset["text"][:calib_size]
elif "scienceqa" in dataset_name_or_dir.lower(
) or "science_qa" in dataset_name_or_dir.lower():
if os.path.isdir(dataset_name_or_dir):
dataset = load_dataset(dataset_name_or_dir, split="train")
dataset = load_dataset(dataset_name_or_dir,
split="train",
trust_remote_code=True)
else:
dataset = load_dataset("derek-thomas/ScienceQA", split="train")
dataset = load_dataset("derek-thomas/ScienceQA",
split="train",
trust_remote_code=True)
dataset = dataset.select(range(calib_size))
elif "cnn_dailymail" in dataset_name_or_dir:
dataset = load_dataset(
dataset_name_or_dir,
name="3.0.0",
split="train",
trust_remote_code=True,
)
dataset = dataset["article"][:calib_size]
elif os.path.isdir(dataset_name_or_dir):
@ -405,7 +411,9 @@ def get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
"assuming the calibration data are in the train split and text column."
)
dataset = load_dataset(dataset_name_or_dir, split="train")
dataset = load_dataset(dataset_name_or_dir,
split="train",
trust_remote_code=True)
dataset = dataset["text"][:calib_size]
else:
raise NotImplementedError(
@ -993,22 +1001,29 @@ def get_nemo_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
dataset = load_dataset(
"json",
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
split="train")
split="train",
trust_remote_code=True)
text_column = "text"
elif "wikitext" in dataset_name_or_dir:
dataset = load_dataset(dataset_name_or_dir,
"wikitext-103-v1",
split="train")
split="train",
trust_remote_code=True)
text_column = "text"
elif "cnn_dailymail" in dataset_name_or_dir:
dataset = load_dataset(dataset_name_or_dir, name="3.0.0", split="train")
dataset = load_dataset(dataset_name_or_dir,
name="3.0.0",
split="train",
trust_remote_code=True)
text_column = "article"
elif os.path.isdir(dataset_name_or_dir):
logger.info(
f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
"assuming the calibration data are in the train split and text column."
)
dataset = load_dataset(dataset_name_or_dir, split="train")
dataset = load_dataset(dataset_name_or_dir,
split="train",
trust_remote_code=True)
text_column = "text"
else:
raise NotImplementedError(

View File

@ -366,8 +366,11 @@ def test_tokenizer_decode_incrementally(tokenizer_dir: str, threshold: float):
num_samples = 100
cnn_dailymail = datasets.load_dataset(cnn_dailymail_path,
name='3.0.0',
split='train')
alpaca_chinese = datasets.load_dataset(alpaca_chinese_path, split='train')
split='train',
trust_remote_code=True)
alpaca_chinese = datasets.load_dataset(alpaca_chinese_path,
split='train',
trust_remote_code=True)
dataset = cnn_dailymail['article'][:num_samples // 2] + alpaca_chinese[
'output_zh'][:num_samples // 2]