Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
feat: adding multimodal (only image for now) support in trtllm-bench (#3490)
* feat: adding multimodal (only image for now) support in trtllm-bench
* fix: add in load_dataset() calls to maintain the v2.19.2 behavior
* re-adding prompt_token_ids and using that for prompt_len
* updating the datasets version in examples as well
* api changes are not needed
* moving datasets requirement and removing a missed api change
* addressing review comments
* refactoring the quickstart example

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
This commit is contained in:
parent 26ebd95302
commit ff3b741045
@@ -1,12 +1,16 @@
import logging
import random
import re
import tempfile
from typing import Optional

import click
from datasets import load_dataset
from PIL import Image
from pydantic import BaseModel, model_validator
from utils.utils import dataset_dump, get_norm_dist_lengths, print_dataset
from utils.utils import (get_norm_dist_lengths, multimodal_dataset_dump,
                         print_multimodal_dataset, print_text_dataset,
                         text_dataset_dump)


def validate_output_len_dist(ctx, param, value):
@@ -31,8 +35,10 @@ class DatasetConfig(BaseModel):
    """Split of the dataset. Typical values: train, validation, test. Setting to None will include all splits."""
    split: Optional[str]
    """The dataset dictionary used for the input sentence."""
    input_key: str
    input_key: Optional[str] = None
    """The dataset dictionary key used for the prompt of the input sentence. Must not be set when prompt is set."""
    image_key: Optional[str] = None
    """The dataset dictionary key used for the images."""
    prompt_key: Optional[str] = None
    """The prompt sentence to be added to the input sentence. Must not be set when prompt_key is set."""
    prompt: Optional[str] = None
@@ -75,6 +81,20 @@ class DatasetConfig(BaseModel):
                f"{req.keys()}")
        return req[self.input_key]

    def get_images(self, req):
        """Get the images from the given request."""
        image_keys = [self.image_key
                      ] + [f"{self.image_key}_{i}" for i in range(1, 8)]
        assert any(key in req for key in image_keys), (
            f"Dataset {self.name} does not have key '{self.image_key}'. "
            "Please set --dataset-image-key to one of the available keys: "
            f"{req.keys()}")
        images = []
        for key in image_keys:
            if key in req and req[key] is not None:
                images.append(req[key])
        return images

    def get_output(self, req):
        """Get the output sentence from the given request."""
        if self.output_key is None:
@@ -105,7 +125,8 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
        dataset = iter(
            load_dataset(*dataset_config.query,
                         split=dataset_config.split,
                         streaming=True))
                         streaming=True,
                         trust_remote_code=True))
    except ValueError as e:
        if "Config" in e:
            e += "\n Please add the config name to the dataset config yaml."
@@ -130,9 +151,12 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
              required=True,
              help=f"Split of the dataset to use.")
@click.option("--dataset-input-key",
              required=True,
              type=str,
              help=f"The dataset dictionary key for input.")
@click.option("--dataset-image-key",
              type=str,
              default="image",
              help=f"The dataset dictionary key for images.")
@click.option("--dataset-prompt-key",
              type=str,
              default=None,
@ -181,8 +205,39 @@ def dataset(root_args, **kwargs):
|
||||
output_lens = []
|
||||
task_ids = []
|
||||
req_cnt = 0
|
||||
modality = None
|
||||
multimodal_texts = []
|
||||
multimodal_image_paths = []
|
||||
for req in load_dataset_from_hf(dataset_config):
|
||||
# input
|
||||
if any(key in req for key in ['image', 'image_1', 'video']):
|
||||
# multimodal input
|
||||
if 'video' in req and req['video'] is not None:
|
||||
assert "Not supported yet"
|
||||
assert kwargs['output_len_dist'] is not None, (
|
||||
"Output length distribution must be set for multimodal requests."
|
||||
)
|
||||
modality = 'image'
|
||||
text = dataset_config.get_prompt(req)
|
||||
images = dataset_config.get_images(req)
|
||||
image_paths = []
|
||||
for image in images:
|
||||
if image is not None:
|
||||
if isinstance(image, str):
|
||||
image_paths.append(image)
|
||||
elif isinstance(image, Image.Image):
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".jpg", delete=False) as tmp_file:
|
||||
logging.debug(f"Saving image to {tmp_file.name}")
|
||||
image = image.convert("RGB")
|
||||
image.save(tmp_file, "JPEG")
|
||||
filepath = tmp_file.name
|
||||
image_paths.append(filepath)
|
||||
else:
|
||||
raise ValueError(f"Invalid image path: {image}")
|
||||
multimodal_texts.append(text)
|
||||
multimodal_image_paths.append(image_paths)
|
||||
else:
|
||||
# text input
|
||||
prompt = dataset_config.get_prompt(
|
||||
req) + ' ' + dataset_config.get_input(req)
|
||||
logging.debug(f"Input sequence: {prompt}")
|
||||
@ -195,7 +250,9 @@ def dataset(root_args, **kwargs):
|
||||
# output if fetch from golden
|
||||
if kwargs['output_len_dist'] is None:
|
||||
output_lens.append(
|
||||
len(root_args.tokenizer.encode(dataset_config.get_output(req))))
|
||||
len(
|
||||
root_args.tokenizer.encode(
|
||||
dataset_config.get_output(req))))
|
||||
|
||||
# lora task id
|
||||
task_id = root_args.task_id
|
||||
@ -208,21 +265,44 @@ def dataset(root_args, **kwargs):
|
||||
if kwargs['num_requests'] and req_cnt >= kwargs['num_requests']:
|
||||
break
|
||||
|
||||
if kwargs['num_requests'] and len(input_ids) < kwargs['num_requests']:
|
||||
if kwargs['num_requests'] and (len(input_ids) if modality is None else len(
|
||||
multimodal_texts)) < kwargs['num_requests']:
|
||||
logging.warning(
|
||||
"Number of requests is smaller than the num-requests user set.")
|
||||
f"Number of requests={len(input_ids) if modality is None else len(multimodal_texts)} is"
|
||||
f" smaller than the num-requests user set={kwargs['num_requests']}."
|
||||
)
|
||||
|
||||
# output if randomized
|
||||
if kwargs['output_len_dist'] is not None:
|
||||
osl_mean, osl_stdev = kwargs['output_len_dist']
|
||||
output_lens = get_norm_dist_lengths(osl_mean, osl_stdev, len(input_ids),
|
||||
output_lens = get_norm_dist_lengths(
|
||||
osl_mean, osl_stdev,
|
||||
len(input_ids) if modality is None else len(multimodal_texts),
|
||||
root_args.random_seed)
|
||||
|
||||
logging.debug(f"Input lengths: {[len(i) for i in input_ids]}")
|
||||
logging.debug(f"Output lengths: {output_lens}")
|
||||
if modality is not None:
|
||||
logging.debug(f"Modality: {modality}")
|
||||
|
||||
if modality is not None:
|
||||
if not root_args.std_out:
|
||||
dataset_dump(
|
||||
multimodal_dataset_dump(
|
||||
multimodal_texts, multimodal_image_paths, output_lens, task_ids,
|
||||
{
|
||||
"workload_type": "dataset",
|
||||
"tokenizer": root_args.tokenizer.__class__.__name__,
|
||||
"num_requests": len(task_ids),
|
||||
"max_output_len": max(output_lens)
|
||||
}, root_args.output)
|
||||
else:
|
||||
print_multimodal_dataset(
|
||||
multimodal_texts,
|
||||
multimodal_image_paths,
|
||||
output_lens,
|
||||
)
|
||||
else:
|
||||
if not root_args.std_out:
|
||||
text_dataset_dump(
|
||||
input_lens, input_ids, output_lens, task_ids, {
|
||||
"workload_type": "dataset",
|
||||
"tokenizer": root_args.tokenizer.__class__.__name__,
|
||||
@ -231,7 +311,7 @@ def dataset(root_args, **kwargs):
|
||||
"max_output_len": max(output_lens)
|
||||
}, root_args.output)
|
||||
else:
|
||||
print_dataset(
|
||||
print_text_dataset(
|
||||
input_ids,
|
||||
output_lens,
|
||||
)
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import random
|
||||
|
||||
import click
|
||||
from utils.utils import (dataset_dump, gen_random_tokens, get_norm_dist_lengths,
|
||||
get_unif_dist_lengths, print_dataset)
|
||||
from utils.utils import (gen_random_tokens, get_norm_dist_lengths,
|
||||
get_unif_dist_lengths, print_text_dataset,
|
||||
text_dataset_dump)
|
||||
|
||||
|
||||
@click.command()
|
||||
@ -57,7 +58,7 @@ def token_norm_dist(root_args, **kwargs):
|
||||
task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]
|
||||
|
||||
if not root_args.std_out:
|
||||
dataset_dump(
|
||||
text_dataset_dump(
|
||||
input_lens, input_ids, output_lens, task_ids, {
|
||||
"workload_type": "token-norm-dist",
|
||||
"input_mean": kwargs['input_mean'],
|
||||
@ -70,7 +71,7 @@ def token_norm_dist(root_args, **kwargs):
|
||||
"max_output_len": max_output_len
|
||||
}, root_args.output)
|
||||
else:
|
||||
print_dataset(
|
||||
print_text_dataset(
|
||||
input_ids,
|
||||
output_lens,
|
||||
)
|
||||
@ -127,7 +128,7 @@ def token_unif_dist(root_args, **kwargs):
|
||||
task_ids = [random.randint(min_id, max_id) for _ in range(num_reqs)]
|
||||
|
||||
if not root_args.std_out:
|
||||
dataset_dump(
|
||||
text_dataset_dump(
|
||||
input_lens, input_ids, output_lens, task_ids, {
|
||||
"workload_type": "token-unif-dist",
|
||||
"input_min": kwargs['input_min'],
|
||||
@ -140,7 +141,7 @@ def token_unif_dist(root_args, **kwargs):
|
||||
"max_output_len": max_output_len
|
||||
}, root_args.output)
|
||||
else:
|
||||
print_dataset(
|
||||
print_text_dataset(
|
||||
input_ids,
|
||||
output_lens,
|
||||
)
|
||||
|
||||
@@ -2,22 +2,29 @@ import json
import math
import os
import random
from typing import List
from typing import List, Union

import numpy as np
from pydantic import BaseModel


class Sample(BaseModel):
class TextSample(BaseModel):
    input_len: int
    input_ids: List[int]
    output_len: int
    task_id: int


class MultimodalSample(BaseModel):
    task_id: int
    prompt: str
    media_paths: List[str]
    output_len: int


class Workload(BaseModel):
    metadata: dict
    samples: List[Sample] = []
    samples: List[Union[TextSample, MultimodalSample]] = []

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
@@ -33,12 +40,12 @@ class Workload(BaseModel):
        self.metadata.setdefault('workload_name', workload_name)


def dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
def text_dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
                      output_file):
    samples = []
    for i in range(len(input_ids)):
        samples.append(
            Sample(input_len=input_lens[i],
            TextSample(input_len=input_lens[i],
                       input_ids=input_ids[i],
                       output_len=output_lens[i],
                       task_id=task_ids[i]))
@@ -48,7 +55,22 @@ def dataset_dump(input_lens, input_ids, output_lens, task_ids, metadata,
        json.dump(workload.model_dump(), f)


def print_dataset(input_ids, output_lens):
def multimodal_dataset_dump(multimodal_texts, multimodal_image_paths,
                            output_lens, task_ids, metadata, output_file):
    samples = []
    for i in range(len(multimodal_texts)):
        samples.append(
            MultimodalSample(task_id=task_ids[i],
                             prompt=multimodal_texts[i],
                             media_paths=multimodal_image_paths[i],
                             output_len=output_lens[i]))
    workload = Workload(metadata=metadata, samples=samples)
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'w') as f:
        json.dump(workload.model_dump(), f)


def print_text_dataset(input_ids, output_lens):
    for i, input_tokens in enumerate(input_ids):
        d = {
            "task_id": i,
@@ -58,6 +80,19 @@ def print_dataset(input_ids, output_lens):
        print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))


def print_multimodal_dataset(multimodal_texts, multimodal_image_paths,
                             output_lens):
    for i, (text, image_paths) in enumerate(
            zip(multimodal_texts, multimodal_image_paths)):
        d = {
            "task_id": i,
            "prompt": text,
            "media_paths": image_paths,
            "output_tokens": output_lens[i]
        }
        print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))


def get_list_of_delays(delay_dist, mean_time_bet_reqs, num_reqs, random_seed):
    if delay_dist == "constant":
        delays = [mean_time_bet_reqs] * num_reqs

@@ -475,6 +475,115 @@ Total Latency (ms): 18563.6825

```

#### Running multi-modal models in the PyTorch Workflow

To benchmark multi-modal models with the PyTorch workflow, you can follow a similar approach to the one above.

First, prepare the dataset:
```
python ./benchmarks/cpp/prepare_dataset.py \
    --tokenizer Qwen/Qwen2-VL-2B-Instruct \
    --stdout \
    dataset \
    --dataset-name lmms-lab/MMMU \
    --dataset-split test \
    --dataset-image-key image \
    --dataset-prompt-key question \
    --num-requests 10 \
    --output-len-dist 128,5 > mm_data.jsonl
```
This downloads the media files to the `/tmp` directory and prepares the dataset with their paths. Note that the `prompt` fields are plain text rather than tokenized ids, because the `prompt` and the media (image/video) are processed by a multimodal preprocessor.

Sample dataset for multimodal:
```
{"task_id":0,"prompt":"Brahma Industries sells vinyl replacement windows to home improvement retailers nationwide. The national sales manager believes that if they invest an additional $25,000 in advertising, they would increase sales volume by 10,000 units. <image 1> What is the total contribution margin?","media_paths":["/tmp/tmp9so41y3r.jpg"],"output_tokens":126}
{"task_id":1,"prompt":"Let us compute for the missing amounts under work in process inventory, what is the cost of goods manufactured? <image 1>","media_paths":["/tmp/tmpowsrb_f4.jpg"],"output_tokens":119}
{"task_id":2,"prompt":"Tsuji is reviewing the price of a 3-month Japanese yen/U.S. dollar currency futures contract, using the currency and interest rate data shown below. Because the 3-month Japanese interest rate has just increased to .50%, Itsuji recognizes that an arbitrage opportunity exists nd decides to borrow $1 million U.S. dollars to purchase Japanese yen. Calculate the yen arbitrage profit from Itsuji's strategy, using the following data: <image 1> ","media_paths":["/tmp/tmpxhdvasex.jpg"],"output_tokens":126}
...
```
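
If you want to sanity-check the prepared file before benchmarking, a minimal sketch along these lines can help; it assumes only the `task_id`, `prompt`, `media_paths`, and `output_tokens` fields shown in the sample above and a local `mm_data.jsonl`:
```
import json
import os

# Minimal sanity check for the prepared multimodal dataset (illustrative only).
# It assumes the JSONL layout shown in the sample above.
with open("mm_data.jsonl") as f:
    for line in f:
        record = json.loads(line)
        missing = [p for p in record["media_paths"] if not os.path.exists(p)]
        if missing:
            print(f"task_id={record['task_id']}: missing media {missing}")
        else:
            print(f"task_id={record['task_id']}: "
                  f"{len(record['media_paths'])} media file(s), "
                  f"output_tokens={record['output_tokens']}")
```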

Run the benchmark:
```
trtllm-bench --model Qwen/Qwen2-VL-2B-Instruct \
    throughput \
    --dataset mm_data.jsonl \
    --backend pytorch \
    --num_requests 10 \
    --max_batch_size 4 \
    --modality image
```

Sample output:
```
===========================================================
= REQUEST DETAILS
===========================================================
Number of requests: 10
Number of concurrent requests: 5.3019
Average Input Length (tokens): 411.6000
Average Output Length (tokens): 128.7000
===========================================================
= WORLD + RUNTIME INFORMATION
===========================================================
TP Size: 1
PP Size: 1
EP Size: None
Max Runtime Batch Size: 4
Max Runtime Tokens: 12288
Scheduling Policy: GUARANTEED_NO_EVICT
KV Memory Percentage: 90.00%
Issue Rate (req/sec): 1.4117E+17

===========================================================
= PERFORMANCE OVERVIEW
===========================================================
Request Throughput (req/sec): 1.4439
Total Output Throughput (tokens/sec): 185.8351
Per User Output Throughput (tokens/sec/user): 38.1959
Per GPU Output Throughput (tokens/sec/gpu): 185.8351
Total Token Throughput (tokens/sec): 780.1607
Total Latency (ms): 6925.4963
Average request latency (ms): 3671.8441

-- Request Latency Breakdown (ms) -----------------------

[Latency] P50 : 3936.3022
[Latency] P90 : 5514.4701
[Latency] P95 : 5514.4701
[Latency] P99 : 5514.4701
[Latency] MINIMUM: 2397.1047
[Latency] MAXIMUM: 5514.4701
[Latency] AVERAGE: 3671.8441

===========================================================
= DATASET DETAILS
===========================================================
Dataset Path: /workspaces/tensorrt_llm/mm_data.jsonl
Number of Sequences: 10

-- Percentiles statistics ---------------------------------

Input Output Seq. Length
-----------------------------------------------------------
MIN: 167.0000 119.0000 300.0000
MAX: 1059.0000 137.0000 1178.0000
AVG: 411.6000 128.7000 540.3000
P50: 299.0000 128.0000 427.0000
P90: 1059.0000 137.0000 1178.0000
P95: 1059.0000 137.0000 1178.0000
P99: 1059.0000 137.0000 1178.0000
===========================================================
```

**Notes and Limitations**:
- Only image datasets are supported for now.
- `--output-len-dist` is a required argument for multimodal datasets.
- The tokenizer is unused during the prepare step, but it is still a required argument.
- Since the images are converted to tokens only when the model is run, `trtllm-bench` uses a large default value for the maximum input sequence length when setting up the execution settings.
  You can modify this behavior by passing a value that suits your use case via the `--max_input_len` flag, as shown in the example after this list.
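
For example, to raise the assumed input length for the multimodal run above, the flag can be added to the same throughput command (the value 8192 is only illustrative):
```
trtllm-bench --model Qwen/Qwen2-VL-2B-Instruct \
    throughput \
    --dataset mm_data.jsonl \
    --backend pytorch \
    --modality image \
    --max_input_len 8192
```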

#### Quantization in the PyTorch Flow

In order to run a quantized run with `trtllm-bench` utilizing the PyTorch flow, you will need to use a pre-quantized

@ -17,7 +17,7 @@ def prepare_text_inputs(model_name, batch_size=8):
|
||||
f"HF_DATASETS_OFFLINE inside function: {datasets.config.HF_DATASETS_OFFLINE}"
|
||||
)
|
||||
if model_name == "BertForQuestionAnswering" or model_name == "RobertaForQuestionAnswering":
|
||||
squad_dataset = load_dataset("squad_v2")
|
||||
squad_dataset = load_dataset("squad_v2", trust_remote_code=True)
|
||||
val_dataset = squad_dataset["validation"]
|
||||
samples = val_dataset.select(range(batch_size))
|
||||
|
||||
@ -27,7 +27,8 @@ def prepare_text_inputs(model_name, batch_size=8):
|
||||
}
|
||||
return qa_real_test_inputs
|
||||
elif model_name == "BertForSequenceClassification" or model_name == "RobertaForSequenceClassification":
|
||||
yelp_dataset = load_dataset("fancyzhx/yelp_polarity")
|
||||
yelp_dataset = load_dataset("fancyzhx/yelp_polarity",
|
||||
trust_remote_code=True)
|
||||
val_dataset = yelp_dataset["test"]
|
||||
samples = val_dataset.select(range(batch_size))
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets==2.14.6
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
evaluate
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
rouge_score
|
||||
SentencePiece~=0.1.99
|
||||
evaluate
|
||||
|
||||
@ -11,4 +11,4 @@ sentencepiece>=0.1.99
|
||||
h5py~=3.12.1
|
||||
rouge_score
|
||||
nltk
|
||||
datasets==2.14.6
|
||||
datasets==3.1.0
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
protobuf
|
||||
rouge_score
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
SentencePiece>=0.1.99
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
transformers>=4.43.0
|
||||
datasets==2.14.6
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
|
||||
@ -312,7 +312,8 @@ def main(args):
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset_openweb = load_dataset("stas/openwebtext-10k",
|
||||
cache_dir=args.dataset_path)
|
||||
cache_dir=args.dataset_path,
|
||||
trust_remote_code=True)
|
||||
long_texts = get_long_texts(dataset_openweb) # generator
|
||||
|
||||
# get datapoints
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
evaluate
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
transformers>=4.39.0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
sentencepiece
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
evaluate
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.15.0
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
protobuf
|
||||
rouge_score
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
protobuf
|
||||
rouge_score
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
protobuf
|
||||
rouge_score
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
tiktoken==0.6.0
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.6
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
datasets~=2.14.6
|
||||
datasets==3.1.0
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
transformers>=4.31.0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
rouge_score
|
||||
evaluate
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets==2.14.6
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
sentencepiece==0.2.0
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets==2.14.5
|
||||
datasets==3.1.0
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
evaluate
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
SentencePiece>=0.1.99
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.16.1
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../../../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets==2.14.6
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
|
||||
@ -108,6 +108,7 @@ def load_dataset(args) -> datasets.Dataset:
|
||||
'timeout': aiohttp.ClientTimeout(total=3600)
|
||||
}
|
||||
},
|
||||
trust_remote_code=True,
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
@ -2,6 +2,6 @@
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
nemo-toolkit[all]==2.0.0rc1
|
||||
megatron-core @ git+https://github.com/NVIDIA/Megatron-LM@core_r0.8.0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
|
||||
@ -20,7 +20,7 @@ def create_trtllm_magpie_calibration_dataset(output_dir: str,
|
||||
calib_size: int = 512) -> None:
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset(DATASET, split="train")
|
||||
dataset = load_dataset(DATASET, split="train", trust_remote_code=True)
|
||||
|
||||
def transform(conversation):
|
||||
value = '\n'.join(turn['value']
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
einops~=0.7.0
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
rouge_score
|
||||
sentencepiece~=0.1.99
|
||||
evaluate
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from quickstart_advanced import add_llm_args, setup_llm
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from tensorrt_llm.inputs import load_image, load_video
|
||||
from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
|
||||
default_video_loader)
|
||||
|
||||
example_images = [
|
||||
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
|
||||
@ -27,92 +28,32 @@ example_video_prompts = [
|
||||
]
|
||||
|
||||
|
||||
def prepare_vila(args, inputs):
|
||||
def prepare_multimodal_inputs(model_dir: str,
|
||||
model_type: str,
|
||||
modality: str,
|
||||
prompts: List[str],
|
||||
media: List[str],
|
||||
image_data_format: str = "pt",
|
||||
num_frames: int = 8) -> List[Dict[str, Any]]:
|
||||
|
||||
def add_media_token(prompt, multi_modal_data):
|
||||
mm_tokens = ""
|
||||
if "image" in multi_modal_data:
|
||||
for _ in multi_modal_data["image"]:
|
||||
mm_tokens += "<image>"
|
||||
elif "video" in multi_modal_data:
|
||||
for _ in multi_modal_data["video"]:
|
||||
mm_tokens += "<vila/video>"
|
||||
return mm_tokens + prompt
|
||||
inputs = []
|
||||
if modality == "image":
|
||||
inputs = default_image_loader(prompts, media, image_data_format)
|
||||
elif modality == "video":
|
||||
inputs = default_video_loader(prompts, media, image_data_format,
|
||||
num_frames)
|
||||
else:
|
||||
raise ValueError(f"Unsupported modality: {modality}")
|
||||
|
||||
inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
|
||||
|
||||
for input in inputs:
|
||||
input["prompt"] = add_media_token(input["prompt"],
|
||||
input["multi_modal_data"])
|
||||
return inputs
|
||||
|
||||
|
||||
def prepare_llava_next(args, inputs):
|
||||
processor = AutoProcessor.from_pretrained(args.model_dir)
|
||||
|
||||
# Single-image inference chat template. For multi-image template,
|
||||
# see https://huggingface.co/docs/transformers/en/model_doc/llava_next#multi-image-inference.
|
||||
def apply_template(prompt, multimodal_data):
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image"
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
return processor.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
for input in inputs:
|
||||
input["prompt"] = apply_template(input["prompt"],
|
||||
input["multi_modal_data"])
|
||||
return inputs
|
||||
|
||||
|
||||
def prepare_qwen2_vl(args, inputs):
|
||||
processor = AutoProcessor.from_pretrained(args.model_dir)
|
||||
|
||||
def apply_template(prompt, multimodal_data):
|
||||
content = [{
|
||||
"type": media_type
|
||||
} for media_type, items in multimodal_data.items()
|
||||
for _ in items] + [{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}]
|
||||
|
||||
conversation = [{"role": "user", "content": content}]
|
||||
return processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
for input in inputs:
|
||||
input["prompt"] = apply_template(input["prompt"],
|
||||
input["multi_modal_data"])
|
||||
return inputs
|
||||
|
||||
|
||||
MODEL_TYPE_MAP = {
|
||||
"llava_llama": prepare_vila,
|
||||
"llava_next": prepare_llava_next,
|
||||
"qwen2_vl": prepare_qwen2_vl,
|
||||
"qwen2_5_vl": prepare_qwen2_vl,
|
||||
}
|
||||
|
||||
|
||||
def add_multimodal_args(parser):
|
||||
parser.add_argument("--model_type",
|
||||
type=str,
|
||||
choices=MODEL_TYPE_MAP.keys(),
|
||||
choices=INPUT_FORMATTER_MAP.keys(),
|
||||
help="Model type.")
|
||||
parser.add_argument("--modality",
|
||||
type=str,
|
||||
@ -150,50 +91,16 @@ def main():
|
||||
llm, sampling_params = setup_llm(args)
|
||||
|
||||
image_format = "pt" # ["pt", "pil"]
|
||||
if args.modality == "image":
|
||||
prompts = args.prompt if args.prompt else example_image_prompts
|
||||
images = args.media if args.media else example_images
|
||||
if len(images) > len(prompts) and len(prompts) == 1:
|
||||
# 1 prompt + N media
|
||||
images = [images]
|
||||
inputs = [{
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"image": [
|
||||
load_image(i, format=image_format, device="cuda")
|
||||
for i in image
|
||||
] if isinstance(image, list) else
|
||||
[load_image(image, format=image_format, device="cuda")]
|
||||
}
|
||||
} for prompt, image in zip(prompts, images)]
|
||||
elif args.modality == "video":
|
||||
prompts = args.prompt if args.prompt else example_video_prompts
|
||||
videos = args.media if args.media else example_videos
|
||||
if len(videos) > len(prompts) and len(prompts) == 1:
|
||||
# 1 prompt + N media
|
||||
videos = [videos]
|
||||
inputs = [{
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"video": [
|
||||
load_video(
|
||||
i, args.num_frames, format=image_format, device="cuda")
|
||||
for i in video
|
||||
] if isinstance(video, list) else [
|
||||
load_video(video,
|
||||
args.num_frames,
|
||||
format=image_format,
|
||||
device="cuda")
|
||||
]
|
||||
}
|
||||
} for prompt, video in zip(prompts, videos)]
|
||||
if args.model_type is not None:
|
||||
model_type = args.model_type
|
||||
else:
|
||||
raise ValueError(f"Unsupported modality: {args.modality}")
|
||||
model_type = json.load(
|
||||
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
|
||||
assert model_type in INPUT_FORMATTER_MAP, f"Unsupported model_type: {model_type}"
|
||||
|
||||
model_type = json.load(open(os.path.join(llm._hf_model_dir,
|
||||
'config.json')))['model_type']
|
||||
assert model_type in MODEL_TYPE_MAP, f"Unsupported model_type: {model_type}"
|
||||
inputs = MODEL_TYPE_MAP[model_type](args, inputs)
|
||||
inputs = prepare_multimodal_inputs(args.model_dir, model_type,
|
||||
args.modality, args.prompt, args.media,
|
||||
image_format, args.num_frames)
|
||||
|
||||
outputs = llm.generate(inputs, sampling_params)
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets>=2.14.4
|
||||
datasets==3.1.0
|
||||
nemo-toolkit[all]==2.0.0rc1
|
||||
rouge_score
|
||||
transformers_stream_generator==0.0.4
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.16.0
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
transformers>=4.40.1
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.dev0
|
||||
datasets~=2.16.0
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
transformers>=4.45.0
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.16.0
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
transformers-stream-generator
|
||||
|
||||
@ -5,7 +5,7 @@ flax>=0.8.2
|
||||
jax~=0.4.23
|
||||
orbax-checkpoint==0.5.7
|
||||
transformers>=4.40.0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
rouge_score
|
||||
sentencepiece
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
datasets~=2.14.5
|
||||
datasets==3.1.0
|
||||
rouge_score
|
||||
sentencepiece>=0.1.99
|
||||
evaluate
|
||||
|
||||
@ -110,7 +110,8 @@ def main(args):
|
||||
dataset = load_dataset(dataset_name,
|
||||
dataset_revision,
|
||||
cache_dir=args.dataset_cache_dir,
|
||||
split=dataset_split)
|
||||
split=dataset_split,
|
||||
trust_remote_code=True)
|
||||
dataset = dataset.shuffle(args.random_seed)
|
||||
|
||||
max_batch_size = args.batch_size
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
-c ../constraints.txt
|
||||
tensorrt_llm>=0.0.0.dev0
|
||||
tiktoken
|
||||
datasets
|
||||
datasets==3.1.0
|
||||
kaldialign
|
||||
openai-whisper
|
||||
librosa
|
||||
|
||||
@ -564,7 +564,8 @@ if __name__ == '__main__':
|
||||
normalizer = EnglishTextNormalizer()
|
||||
dataset = load_dataset(args.dataset,
|
||||
args.dataset_name,
|
||||
split=args.dataset_split)
|
||||
split=args.dataset_split,
|
||||
trust_remote_code=True)
|
||||
if args.enable_warmup:
|
||||
results, total_duration = decode_dataset(
|
||||
model,
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
-r requirements.txt
|
||||
datasets==2.19.2
|
||||
einops
|
||||
graphviz
|
||||
mypy
|
||||
|
||||
@ -31,6 +31,8 @@ pydantic>=2.9.1
|
||||
pillow==10.3.0
|
||||
wheel<=0.45.1
|
||||
optimum
|
||||
# evaluate needs datasets>=2.0.0 which triggers datasets>3.1.0 which is not stable: https://github.com/huggingface/datasets/issues/7467
|
||||
datasets==3.1.0
|
||||
evaluate
|
||||
mpmath>=1.3.0
|
||||
click
|
||||
|
||||
@ -437,10 +437,7 @@ class Qwen2VLModelBase(PreTrainedModel):
|
||||
inputs_embeds=input_embeds,
|
||||
return_context_logits=return_context_logits,
|
||||
mrope_config=mrope_config)
|
||||
logger.debug(
|
||||
f"output_ids: {(output_prob if output_prob.dim() == 2 else output_prob.unsqueeze(0)).argmax(dim=1).tolist()}"
|
||||
)
|
||||
logger.info(f'output shape: {output_prob.shape}')
|
||||
logger.debug(f'output shape: {output_prob.shape}')
|
||||
return output_prob
|
||||
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from click_option_group import (MutuallyExclusiveOptionGroup, OptionGroup,
|
||||
|
||||
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
|
||||
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
|
||||
from tensorrt_llm.bench.build.build import get_model_config
|
||||
|
||||
# isort: off
|
||||
from tensorrt_llm.bench.benchmark.utils.general import (
|
||||
@ -21,7 +22,8 @@ from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig
|
||||
from tensorrt_llm.bench.dataclasses.general import BenchmarkEnvironment
|
||||
from tensorrt_llm.bench.dataclasses.reporting import ReportUtility
|
||||
from tensorrt_llm.bench.utils.data import (create_dataset_from_stream,
|
||||
initialize_tokenizer)
|
||||
initialize_tokenizer,
|
||||
update_metadata_for_multimodal)
|
||||
from tensorrt_llm.llmapi import LLM, CapacitySchedulerPolicy
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.sampling_params import SamplingParams
|
||||
@ -92,6 +94,20 @@ from tensorrt_llm.sampling_params import SamplingParams
|
||||
required=False,
|
||||
help="Pass in a dataset file for parsing instead of stdin.",
|
||||
)
|
||||
@optgroup.option(
|
||||
"--modality",
|
||||
type=click.Choice(["image", "video"]),
|
||||
default=None,
|
||||
help="Modality of the multimodal requests.",
|
||||
)
|
||||
@optgroup.option(
|
||||
"--max_input_len",
|
||||
type=int,
|
||||
default=4096,
|
||||
help=
|
||||
"Maximum input sequence length to use for multimodal models. This is used only when --modality "
|
||||
"is specified since the actual number of vision tokens is unknown before the model is run.",
|
||||
)
|
||||
@optgroup.option(
|
||||
"--num_requests",
|
||||
type=int,
|
||||
@ -194,6 +210,9 @@ def throughput_command(
|
||||
engine_dir: Path = params.pop("engine_dir")
|
||||
concurrency: int = params.pop("concurrency")
|
||||
backend: str = params.get("backend")
|
||||
modality: str = params.pop("modality")
|
||||
max_input_len: int = params.pop("max_input_len")
|
||||
model_type = get_model_config(model, checkpoint_path).model_type
|
||||
|
||||
# Reporting options
|
||||
report_json: Path = params.pop("report_json")
|
||||
@ -209,14 +228,23 @@ def throughput_command(
|
||||
# Dataset Loading and Preparation
|
||||
with open(dataset_path, "r") as dataset:
|
||||
metadata, requests = create_dataset_from_stream(
|
||||
tokenizer, dataset, num_requests=num_requests)
|
||||
tokenizer,
|
||||
dataset,
|
||||
num_requests=num_requests,
|
||||
model_dir=checkpoint_path,
|
||||
model_type=model_type,
|
||||
modality=modality,
|
||||
max_input_seq_len_for_multimodal=max_input_len)
|
||||
metadata.dataset_path = dataset_path
|
||||
params["target_input_len"] = params.get(
|
||||
"target_input_len") or metadata.avg_isl
|
||||
params["target_output_len"] = params.get(
|
||||
"target_output_len") or metadata.avg_osl
|
||||
|
||||
if modality is None:
|
||||
# Log dataset info
|
||||
# NOTE: This table is only accurate for non-multimodal models.
|
||||
# The accurate table for multimodal models will be logged after the benchmark is done.
|
||||
logger.info(metadata.get_summary_for_print())
|
||||
|
||||
# Engine configuration parsing
|
||||
@ -294,8 +322,12 @@ def throughput_command(
|
||||
warmup_dataset = generate_warmup_dataset(requests, warmup)
|
||||
logger.info("Running warmup.")
|
||||
asyncio.run(
|
||||
async_benchmark(llm, sampling_params, warmup_dataset, False,
|
||||
concurrency))
|
||||
async_benchmark(llm,
|
||||
sampling_params,
|
||||
warmup_dataset,
|
||||
False,
|
||||
concurrency,
|
||||
modality=modality))
|
||||
# WAR: IterationResult is a singleton tied to the executor.
|
||||
# Since the benchmark calls asyncio.run() multiple times (e.g., during warmup),
|
||||
# we must reset it to ensure it attaches to the correct event loop.
|
||||
@ -304,10 +336,19 @@ def throughput_command(
|
||||
|
||||
with iteration_writer.capture():
|
||||
statistics = asyncio.run(
|
||||
async_benchmark(llm, sampling_params, requests, streaming,
|
||||
concurrency, iteration_writer.full_address))
|
||||
async_benchmark(llm,
|
||||
sampling_params,
|
||||
requests,
|
||||
streaming,
|
||||
concurrency,
|
||||
iteration_writer.full_address,
|
||||
modality=modality))
|
||||
|
||||
logger.info(f"Benchmark done. Reporting results...")
|
||||
if modality is not None:
|
||||
# For multimodal models, we need to update the metadata with the correct input lengths
|
||||
metadata = update_metadata_for_multimodal(metadata, statistics)
|
||||
|
||||
report_utility = ReportUtility(statistics, metadata, runtime_config,
|
||||
logger, kwargs, streaming)
|
||||
if report_json:
|
||||
|
||||
@ -23,7 +23,8 @@ class LlmManager:
|
||||
llm: LLM,
|
||||
outbox: asyncio.Queue[PerfItemTuple],
|
||||
streaming: bool,
|
||||
concurrency: int = -1) -> None:
|
||||
concurrency: int = -1,
|
||||
modality: Optional[str] = None) -> None:
|
||||
self.llm = llm
|
||||
self._inbox: asyncio.Queue[Tuple[InferenceRequest,
|
||||
SamplingParams]] = asyncio.Queue()
|
||||
@ -38,6 +39,7 @@ class LlmManager:
|
||||
concurrency) if concurrency > 0 else None
|
||||
self.streaming = streaming
|
||||
self.request_seen = asyncio.Event()
|
||||
self.modality = modality
|
||||
|
||||
async def process_request(self, request: InferenceRequest,
|
||||
sampling_params: SamplingParams):
|
||||
@ -50,7 +52,7 @@ class LlmManager:
|
||||
time_on_first_token = None
|
||||
# Schedule the request in the LLM API (asynchronously)
|
||||
output: RequestOutput = self.llm.generate_async(
|
||||
request.input_ids,
|
||||
request.input_ids if self.modality is None else request.prompt,
|
||||
sampling_params=sampling_params,
|
||||
streaming=self.streaming)
|
||||
if self.streaming:
|
||||
@ -70,7 +72,7 @@ class LlmManager:
|
||||
start_timestamp=request_start_timestamp,
|
||||
end_timestamp=response_end_timestamp,
|
||||
request_id=response.request_id,
|
||||
num_input_tokens=len(request.input_ids),
|
||||
num_input_tokens=len(output.prompt_token_ids),
|
||||
response_is_final=response.finished,
|
||||
error=False,
|
||||
tokens=tokens,
|
||||
@ -201,6 +203,7 @@ async def async_benchmark(
|
||||
streaming: bool,
|
||||
concurrency: int = -1,
|
||||
iteration_log_addr: str = None,
|
||||
modality: Optional[str] = None,
|
||||
) -> StatsKeeper:
|
||||
outbox = asyncio.Queue()
|
||||
statistics = StatsKeeper()
|
||||
@ -208,7 +211,11 @@ async def async_benchmark(
|
||||
|
||||
try:
|
||||
logger.info("Starting benchmarking async task.")
|
||||
backend = LlmManager(llm, outbox, streaming, concurrency=concurrency)
|
||||
backend = LlmManager(llm,
|
||||
outbox,
|
||||
streaming,
|
||||
concurrency=concurrency,
|
||||
modality=modality)
|
||||
backend.run(iteration_addr=iteration_log_addr)
|
||||
|
||||
enqueue_task = asyncio.create_task(
|
||||
|
||||
@ -116,6 +116,7 @@ class ModelConfig(BaseModel):
|
||||
setting calculation.
|
||||
"""
|
||||
name: str
|
||||
model_type: str
|
||||
param_count: int
|
||||
num_hidden_layers: int = Field(validation_alias=AliasChoices(
|
||||
"num_hidden_layers",
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from pydantic import (AliasChoices, BaseModel, Field, computed_field,
|
||||
model_validator)
|
||||
@ -17,7 +17,7 @@ class BenchmarkEnvironment(BaseModel):
|
||||
|
||||
class InferenceRequest(BaseModel):
|
||||
task_id: int
|
||||
prompt: Optional[str] = None
|
||||
prompt: Optional[Union[str, Any]] = None
|
||||
output_tokens: int
|
||||
input_ids: Optional[List[int]] = Field(
|
||||
alias=AliasChoices("input_ids", "logits"))
|
||||
|
||||
@ -1,12 +1,36 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import List, TextIO, Tuple
|
||||
from typing import Any, Dict, List, TextIO, Tuple
|
||||
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata,
|
||||
InferenceRequest)
|
||||
from tensorrt_llm.bench.dataclasses.statistics import PercentileStats
|
||||
from tensorrt_llm.inputs import (INPUT_FORMATTER_MAP, default_image_loader,
|
||||
default_video_loader)
|
||||
|
||||
|
||||
def prepare_multimodal_inputs(model_dir: str,
|
||||
model_type: str,
|
||||
modality: str,
|
||||
prompts: List[str],
|
||||
media: List[str],
|
||||
image_data_format: str = "pt",
|
||||
num_frames: int = 8) -> List[Dict[str, Any]]:
|
||||
|
||||
inputs = []
|
||||
if modality == "image":
|
||||
inputs = default_image_loader(prompts, media, image_data_format)
|
||||
elif modality == "video":
|
||||
inputs = default_video_loader(prompts, media, image_data_format,
|
||||
num_frames)
|
||||
else:
|
||||
raise ValueError(f"Unsupported modality: {modality}")
|
||||
|
||||
inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def initialize_tokenizer(model_name: str) -> PreTrainedTokenizer:
|
||||
@ -36,6 +60,10 @@ def create_dataset_from_stream(
|
||||
max_input_length: int = 0,
|
||||
max_output_length: int = 0,
|
||||
num_requests: int = 0,
|
||||
model_dir: str = None,
|
||||
model_type: str = None,
|
||||
modality: str = None,
|
||||
max_input_seq_len_for_multimodal: int = 4096,
|
||||
) -> Tuple[DatasetMetadata, List[InferenceRequest]]:
|
||||
"""Generate metadata and a list of requests to drive benchmarking.
|
||||
|
||||
@ -83,13 +111,30 @@ def create_dataset_from_stream(
|
||||
# Each line should be a complete JSON dictionary with no indentation
|
||||
# or newline characters.
|
||||
data = json.loads(line)
|
||||
if modality is not None:
|
||||
# Multimodal data
|
||||
assert modality in [
|
||||
"image", "video"
|
||||
], f"Modality must be one of ['image', 'video'] but got {modality}."
|
||||
|
||||
prompt = data.get("prompt") # cannot be None
|
||||
media_paths = data.get("media_paths", None)
|
||||
inputs = prepare_multimodal_inputs(
|
||||
model_dir,
|
||||
model_type,
|
||||
modality,
|
||||
prompts=[prompt],
|
||||
media=media_paths) # list of dicts
|
||||
logits = None # cannot tokenize multi-modal data, handled by preprocessor
|
||||
prompt = inputs[0]
|
||||
else:
|
||||
logits = data.get("input_ids", data.get("logits", None))
|
||||
prompt = data.get("prompt", None)
|
||||
task_id = data["task_id"]
|
||||
osl = data["output_tokens"]
|
||||
# If the request comes in with logits, just use the provided.
|
||||
# Otherwise we need to tokenize it.
|
||||
logits = tokenize(prompt)["input_ids"] if logits is None else logits
|
||||
task_id = data["task_id"]
|
||||
osl = data["output_tokens"]
|
||||
|
||||
request = InferenceRequest(
|
||||
task_id=task_id,
|
||||
@ -97,8 +142,13 @@ def create_dataset_from_stream(
|
||||
output_tokens=output_limiter(osl),
|
||||
input_ids=logits,
|
||||
)
|
||||
all_isl.append(len(logits))
|
||||
all_osl.append(osl)
|
||||
if modality is not None:
|
||||
cur_isl = max_input_seq_len_for_multimodal # NOTE: actual sequence length is unknown until the model is run
|
||||
all_isl.append(cur_isl)
|
||||
all_seq_len.append(cur_isl + osl)
|
||||
else:
|
||||
all_isl.append(len(logits))
|
||||
all_seq_len.append(len(logits) + osl)
|
||||
dataset.append(request)
|
||||
|
||||
@ -115,3 +165,31 @@ def create_dataset_from_stream(
|
||||
)
|
||||
|
||||
return metadata, dataset
|
||||
|
||||
|
||||
def update_metadata_for_multimodal(metadata, statistics) -> DatasetMetadata:
|
||||
"""Update the metadata from benchmark statistics. Only used for multimodal models.
|
||||
|
||||
Args:
|
||||
metadata (DatasetMetadata): The metadata to update.
|
||||
statistics (StatsKeeper): The statistics to update the metadata with.
|
||||
|
||||
Returns:
|
||||
DatasetMetadata: The updated metadata.
|
||||
"""
|
||||
all_isl = []
|
||||
all_osl = []
|
||||
all_seq_len = []
|
||||
for request in statistics.requests.values():
|
||||
all_isl.append(request.num_input_tokens)
|
||||
all_osl.append(request.num_total_output_tokens)
|
||||
all_seq_len.append(request.num_input_tokens +
|
||||
request.num_total_output_tokens)
|
||||
isl_stats = PercentileStats.from_iterable(all_isl)
|
||||
osl_stats = PercentileStats.from_iterable(all_osl)
|
||||
seq_len_stats = PercentileStats.from_iterable(all_seq_len)
|
||||
metadata.isl_stats = isl_stats
|
||||
metadata.osl_stats = osl_stats
|
||||
metadata.seq_len_stats = seq_len_stats
|
||||
|
||||
return metadata
|
||||
|
||||
@ -36,7 +36,10 @@ class CnnDailymail(Evaluator):
|
||||
system_prompt: Optional[str] = None):
|
||||
super().__init__(apply_chat_template=apply_chat_template,
|
||||
system_prompt=system_prompt)
|
||||
self.data = datasets.load_dataset(dataset_path, "3.0.0", split="test")
|
||||
self.data = datasets.load_dataset(dataset_path,
|
||||
"3.0.0",
|
||||
split="test",
|
||||
trust_remote_code=True)
|
||||
self.data = self.data.shuffle(random_seed)
|
||||
if num_samples is None:
|
||||
self.num_samples = self.data.num_rows
|
||||
|
||||
@ -1,10 +1,15 @@
|
||||
from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
|
||||
from .registry import (ExtraProcessedInputs, InputProcessor,
|
||||
create_input_processor, register_input_processor)
|
||||
from .utils import load_image, load_video
|
||||
from .utils import (INPUT_FORMATTER_MAP, default_image_loader,
|
||||
default_video_loader, format_llava_next_input,
|
||||
format_qwen2_vl_input, format_vila_input, load_image,
|
||||
load_video)
|
||||
|
||||
__all__ = [
|
||||
"PromptInputs", "prompt_inputs", "TextPrompt", "TokensPrompt",
|
||||
"InputProcessor", "create_input_processor", "register_input_processor",
|
||||
"ExtraProcessedInputs", "load_image", "load_video"
|
||||
"ExtraProcessedInputs", "load_image", "load_video", "INPUT_FORMATTER_MAP",
|
||||
"default_image_loader", "default_video_loader", "format_vila_input",
|
||||
"format_llava_next_input", "format_qwen2_vl_input"
|
||||
]
|
||||
|
||||
@ -6,6 +6,7 @@ import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import ToTensor
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
||||
def load_image(image: str,
|
||||
@ -67,3 +68,158 @@ def load_video(
|
||||
device=device) if format == "pt" else frames[index]
|
||||
for index in indices if index in frames
|
||||
]
|
||||
|
||||
|
||||
"""
|
||||
VLM input preparation.
|
||||
"""
|
||||
|
||||
|
||||
def format_vila_input(model_dir, inputs):
|
||||
"""
|
||||
This function formats the input for the VILA/NVILA VL model.
|
||||
|
||||
Arguments:
|
||||
model_dir: The directory of the model to load any preprocessor.
|
||||
inputs: The list of inputs to format.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
|
||||
"""
|
||||
|
||||
def add_media_token(prompt, multi_modal_data):
|
||||
mm_tokens = ""
|
||||
if "image" in multi_modal_data:
|
||||
for _ in multi_modal_data["image"]:
|
||||
mm_tokens += "<image>"
|
||||
elif "video" in multi_modal_data:
|
||||
for _ in multi_modal_data["video"]:
|
||||
mm_tokens += "<vila/video>"
|
||||
return mm_tokens + prompt
|
||||
|
||||
for input in inputs:
|
||||
input["prompt"] = add_media_token(input["prompt"],
|
||||
input["multi_modal_data"])
|
||||
return inputs
|
||||
|
||||
|
||||
def format_llava_next_input(model_dir, inputs):
|
||||
"""
|
||||
This function formats the input for the Llava Next VL model.
|
||||
|
||||
Arguments:
|
||||
model_dir: The directory of the model to load any preprocessor.
|
||||
inputs: The list of inputs to format.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
|
||||
"""
|
||||
processor = AutoProcessor.from_pretrained(model_dir)
|
||||
|
||||
# Single-image inference chat template. For multi-image template,
|
||||
# see https://huggingface.co/docs/transformers/en/model_doc/llava_next#multi-image-inference.
|
||||
def apply_template(prompt, multimodal_data):
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image"
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
return processor.apply_chat_template(
|
||||
conversation,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
for input in inputs:
|
||||
input["prompt"] = apply_template(input["prompt"],
|
||||
input["multi_modal_data"])
|
||||
return inputs
|
||||
|
||||
|
||||
def format_qwen2_vl_input(model_dir, inputs):
|
||||
"""
|
||||
This function formats the input for the Qwen2/Qwen2.5 VL model.
|
||||
|
||||
Arguments:
|
||||
model_dir: The directory of the model to load any preprocessor.
|
||||
inputs: The list of inputs to format.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries where "prompt" data is modified to a TextPrompt that combines text prompt and multimodal data.
|
||||
"""
|
||||
processor = AutoProcessor.from_pretrained(model_dir)
|
||||
|
||||
def apply_template(prompt, multimodal_data):
|
||||
content = [{
|
||||
"type": media_type
|
||||
} for media_type, items in multimodal_data.items()
|
||||
for _ in items] + [{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}]
|
||||
|
||||
conversation = [{"role": "user", "content": content}]
|
||||
# print(conversation)
|
||||
return processor.apply_chat_template(
|
||||
conversation,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
for input in inputs:
|
||||
input["prompt"] = apply_template(input["prompt"],
|
||||
input["multi_modal_data"])
|
||||
return inputs
|
||||
|
||||
|
||||
def default_image_loader(prompts, images, image_data_format="pt"):
|
||||
if len(images) > len(prompts) and len(prompts) == 1:
|
||||
# 1 prompt + N media
|
||||
images = [images]
|
||||
inputs = [{
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"image": [
|
||||
load_image(i, format=image_data_format, device="cuda")
|
||||
for i in image
|
||||
] if isinstance(image, list) else
|
||||
[load_image(image, format=image_data_format, device="cuda")]
|
||||
}
|
||||
} for prompt, image in zip(prompts, images)]
|
||||
return inputs
|
||||
|
||||
|
||||
def default_video_loader(prompts, videos, image_data_format="pt", num_frames=8):
|
||||
if len(videos) > len(prompts) and len(prompts) == 1:
|
||||
# 1 prompt + N media
|
||||
videos = [videos]
|
||||
inputs = [{
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"video": [
|
||||
load_video(
|
||||
i, num_frames, format=image_data_format, device="cuda")
|
||||
for i in video
|
||||
] if isinstance(video, list) else [
|
||||
load_video(
|
||||
video, num_frames, format=image_data_format, device="cuda")
|
||||
]
|
||||
}
|
||||
} for prompt, video in zip(prompts, videos)]
|
||||
return inputs
|
||||
|
||||
|
||||
INPUT_FORMATTER_MAP = {
|
||||
"llava_llama": format_vila_input,
|
||||
"llava_next": format_llava_next_input,
|
||||
"qwen2_vl": format_qwen2_vl_input,
|
||||
"qwen2_5_vl": format_qwen2_vl_input,
|
||||
}
|
||||
|
||||
@ -306,6 +306,7 @@ def load_calib_dataset(dataset_name_or_dir: str,
|
||||
dataset = load_dataset(dataset_name_or_dir,
|
||||
name=config_name,
|
||||
split=split,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs)
|
||||
return dataset[key]
|
||||
|
||||
|
||||
@ -384,20 +384,26 @@ def get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
|
||||
split="train")
|
||||
split="train",
|
||||
trust_remote_code=True)
|
||||
dataset = dataset["text"][:calib_size]
|
||||
elif "scienceqa" in dataset_name_or_dir.lower(
|
||||
) or "science_qa" in dataset_name_or_dir.lower():
|
||||
if os.path.isdir(dataset_name_or_dir):
|
||||
dataset = load_dataset(dataset_name_or_dir, split="train")
|
||||
dataset = load_dataset(dataset_name_or_dir,
|
||||
split="train",
|
||||
trust_remote_code=True)
|
||||
else:
|
||||
dataset = load_dataset("derek-thomas/ScienceQA", split="train")
|
||||
dataset = load_dataset("derek-thomas/ScienceQA",
|
||||
split="train",
|
||||
trust_remote_code=True)
|
||||
dataset = dataset.select(range(calib_size))
|
||||
elif "cnn_dailymail" in dataset_name_or_dir:
|
||||
dataset = load_dataset(
|
||||
dataset_name_or_dir,
|
||||
name="3.0.0",
|
||||
split="train",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
dataset = dataset["article"][:calib_size]
|
||||
elif os.path.isdir(dataset_name_or_dir):
|
||||
@ -405,7 +411,9 @@ def get_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
|
||||
f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
|
||||
"assuming the calibration data are in the train split and text column."
|
||||
)
|
||||
dataset = load_dataset(dataset_name_or_dir, split="train")
|
||||
dataset = load_dataset(dataset_name_or_dir,
|
||||
split="train",
|
||||
trust_remote_code=True)
|
||||
dataset = dataset["text"][:calib_size]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@ -993,22 +1001,29 @@ def get_nemo_calib_dataloader(dataset_name_or_dir="cnn_dailymail",
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
|
||||
split="train")
|
||||
split="train",
|
||||
trust_remote_code=True)
|
||||
text_column = "text"
|
||||
elif "wikitext" in dataset_name_or_dir:
|
||||
dataset = load_dataset(dataset_name_or_dir,
|
||||
"wikitext-103-v1",
|
||||
split="train")
|
||||
split="train",
|
||||
trust_remote_code=True)
|
||||
text_column = "text"
|
||||
elif "cnn_dailymail" in dataset_name_or_dir:
|
||||
dataset = load_dataset(dataset_name_or_dir, name="3.0.0", split="train")
|
||||
dataset = load_dataset(dataset_name_or_dir,
|
||||
name="3.0.0",
|
||||
split="train",
|
||||
trust_remote_code=True)
|
||||
text_column = "article"
|
||||
elif os.path.isdir(dataset_name_or_dir):
|
||||
logger.info(
|
||||
f"Recognized local dataset repo {dataset_name_or_dir} for calibration; "
|
||||
"assuming the calibration data are in the train split and text column."
|
||||
)
|
||||
dataset = load_dataset(dataset_name_or_dir, split="train")
|
||||
dataset = load_dataset(dataset_name_or_dir,
|
||||
split="train",
|
||||
trust_remote_code=True)
|
||||
text_column = "text"
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
|
||||
@ -366,8 +366,11 @@ def test_tokenizer_decode_incrementally(tokenizer_dir: str, threshold: float):
|
||||
num_samples = 100
|
||||
cnn_dailymail = datasets.load_dataset(cnn_dailymail_path,
|
||||
name='3.0.0',
|
||||
split='train')
|
||||
alpaca_chinese = datasets.load_dataset(alpaca_chinese_path, split='train')
|
||||
split='train',
|
||||
trust_remote_code=True)
|
||||
alpaca_chinese = datasets.load_dataset(alpaca_chinese_path,
|
||||
split='train',
|
||||
trust_remote_code=True)
|
||||
dataset = cnn_dailymail['article'][:num_samples // 2] + alpaca_chinese[
|
||||
'output_zh'][:num_samples // 2]
|
||||
|
||||
|
||||