TensorRT-LLMs/examples/multimodal/build_visual_engine.py
Kaiyu Xie c89653021e
Update TensorRT-LLM (20240116) (#891)
* Update TensorRT-LLM

---------

Co-authored-by: Eddie-Wang1120 <81598289+Eddie-Wang1120@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2024-01-16 20:03:11 +08:00

209 lines
7.7 KiB
Python

import argparse
import os
from time import time
import tensorrt as trt
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor
def export_visual_wrapper_onnx(visual_wrapper, image, output_dir):
    """Trace ``visual_wrapper`` with ``image`` and save it as ONNX.

    The file is written to ``<output_dir>/visual_encoder.onnx`` with opset 17.
    Only dimension 0 of the single input (named ``input``) is exported as a
    dynamic ``batch`` axis; the output is named ``output``.

    Args:
        visual_wrapper: torch module to export.
        image: example input tensor used for tracing.
        output_dir: directory that receives the ONNX file.
    """
    logger.log(trt.Logger.INFO, "Exporting onnx")
    onnx_path = f'{output_dir}/visual_encoder.onnx'
    batch_only = {0: 'batch'}
    torch.onnx.export(
        visual_wrapper,
        image,
        onnx_path,
        opset_version=17,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': batch_only},
    )
def build_trt_engine(part_id,
                     img_height,
                     img_width,
                     output_dir,
                     minBS=1,
                     optBS=2,
                     maxBS=4):
    """Build an FP16 TensorRT engine from a previously exported ONNX file.

    Reads ``<output_dir>/<part>.onnx`` and writes
    ``<output_dir>/<part>_fp16.engine``, where ``<part>`` is
    ``visual_encoder`` for ``part_id == 0`` and ``Qformer`` for
    ``part_id == 1``. Dimension 0 of every network input is made dynamic and
    covered by a single optimization profile (``minBS``/``optBS``/``maxBS``).

    Args:
        part_id: 0 for the visual encoder, 1 for the BLIP-2 Q-Former.
        img_height: preprocessed image height; sets the visual encoder input
            shape (only logged for the Q-Former).
        img_width: preprocessed image width; same usage as ``img_height``.
        output_dir: directory holding the ONNX file and receiving the engine.
        minBS: minimum batch size of the optimization profile.
        optBS: batch size TensorRT optimizes for.
        maxBS: maximum batch size of the optimization profile.

    Raises:
        RuntimeError: if ``part_id`` is invalid, the ONNX file fails to
            parse, or TensorRT fails to build the engine.
    """
    # Validate before touching any file so a bad id cannot parse the wrong
    # ONNX model first.
    if part_id == 0:
        part_name = 'visual_encoder'
    elif part_id == 1:
        part_name = 'Qformer'
    else:
        raise RuntimeError("Invalid part id")
    onnx_file = '%s/%s.onnx' % (output_dir, part_name)
    engine_file = '%s/%s_fp16.engine' % (output_dir, part_name)
    logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name)

    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    profile = builder.create_optimization_profile()
    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)
    parser = trt.OnnxParser(network, logger)

    with open(onnx_file, 'rb') as model:
        # The path argument lets the parser locate externally stored weights.
        if not parser.parse(model.read(), onnx_file):
            logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file)
            for error in range(parser.num_errors):
                # get_error() returns an IParserError; logger.log needs str.
                logger.log(trt.Logger.ERROR, str(parser.get_error(error)))
            # Continuing would operate on an empty/partial network.
            raise RuntimeError("Failed parsing %s" % onnx_file)
    logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file)

    logger.log(trt.Logger.INFO,
               f"Processed image dims {img_height}x{img_width}")

    def _set_dynamic_batch(index, dims):
        # Make dim 0 of network input `index` dynamic (-1) and register its
        # min/opt/max shapes; `dims` are the fixed trailing dimensions.
        tensor = network.get_input(index)
        tensor.shape = [-1] + dims
        profile.set_shape(tensor.name, [minBS] + dims, [optBS] + dims,
                          [maxBS] + dims)

    if part_id == 0:  # Feature extractor
        _set_dynamic_batch(0, [3, img_height, img_width])
    else:  # BLIP Qformer: query_tokens, image_embeds, image_atts
        # NOTE(review): trailing dims are hard-coded to the shapes this script
        # exports for BLIP-2; confirm before reusing with other checkpoints.
        _set_dynamic_batch(0, [32, 768])
        _set_dynamic_batch(1, [257, 1408])
        _set_dynamic_batch(2, [257])
    config.add_optimization_profile(profile)

    t0 = time()
    engine_string = builder.build_serialized_network(network, config)
    t1 = time()
    if engine_string is None:
        raise RuntimeError("Failed building %s" % (engine_file))
    logger.log(trt.Logger.INFO,
               "Succeeded building %s in %d s" % (engine_file, t1 - t0))
    with open(engine_file, 'wb') as f:
        f.write(engine_string)
def build_blip_engine(args):
    """Export BLIP-2 visual encoder and Q-Former to ONNX and build TRT engines.

    Loads ``Salesforce/blip2-<args.model_name>`` in FP16 and feeds a dummy
    image through the processor to obtain example inputs. It then:

    1. exports ``model.vision_model`` and builds the visual encoder engine
       (``part_id`` 0);
    2. exports the Q-Former followed by ``model.language_projection`` to
       ``<output_dir>/Qformer.onnx`` and builds its engine (``part_id`` 1).

    Args:
        args: parsed CLI namespace; reads ``model_name`` and ``output_dir``.
    """
    model_type = 'Salesforce/blip2-' + args.model_name
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    processor = Blip2Processor.from_pretrained(model_type)
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_type, torch_dtype=torch.float16)
    model.to(device)

    # The image is only a tracing example; its content does not matter.
    raw_image = Image.new('RGB', [10, 10])  # dummy image
    prompt = "Question: what is this? Answer:"
    inputs = processor(raw_image, prompt,
                       return_tensors="pt").to(device, torch.float16)
    image = inputs['pixel_values']

    visual_wrapper = model.vision_model
    # image_embeds is only an example input for the Q-Former export, so skip
    # building (and holding on to) an autograd graph for this forward pass.
    with torch.no_grad():
        image_embeds = visual_wrapper(image)[0]
    export_visual_wrapper_onnx(visual_wrapper, image, args.output_dir)
    build_trt_engine(0, image.shape[2], image.shape[3], args.output_dir)

    class QformerWrapper(torch.nn.Module):
        """Run the Q-Former and apply the language projection to its output."""

        def __init__(self, Qformer, projector):
            super().__init__()
            self.model = Qformer
            self.projector = projector

        def forward(self, query_tokens, image_embeds, image_atts):
            query_output = self.model(query_embeds=query_tokens,
                                      encoder_hidden_states=image_embeds,
                                      encoder_attention_mask=image_atts,
                                      return_dict=True)
            return self.projector(query_output.last_hidden_state)

    q_wrapper = QformerWrapper(model.qformer, model.language_projection)
    # All-ones mask: every image-embedding position is attended to.
    image_atts = torch.ones(image_embeds.size()[:-1],
                            dtype=torch.long).to(image.device)
    # Broadcast the learned query tokens across the example batch.
    query_tokens = model.query_tokens.expand(image_embeds.shape[0], -1, -1)

    input_names = ['query_tokens', 'image_embeds', 'image_atts']
    torch.onnx.export(
        q_wrapper, (query_tokens, image_embeds, image_atts),
        f'{args.output_dir}/Qformer.onnx',
        opset_version=17,
        input_names=input_names,
        output_names=['query_output'],
        # Only the batch dimension of each input is dynamic.
        dynamic_axes={name: {0: 'batch'} for name in input_names})
    build_trt_engine(1, image.shape[2], image.shape[3], args.output_dir)
def build_llava_engine(args):
    """Export LLaVA's vision tower + mm projector to ONNX and build its engine.

    Args:
        args: parsed CLI namespace; reads ``model_path`` and ``output_dir``.
    """
    # Import these here to avoid installing llava when running blip models only
    from llava.mm_utils import get_model_name_from_path
    from llava.model.builder import load_pretrained_model

    # Resolve the device locally: the module-level `device` only exists when
    # this file runs as a script, so relying on it breaks importing callers.
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

    model_name = get_model_name_from_path(args.model_path)
    _, model, image_processor, _ = load_pretrained_model(
        args.model_path, None, model_name)

    # The image is only a tracing example; its content does not matter.
    image = Image.new('RGB', [10, 10])  # dummy image
    image = image_processor(image, return_tensors='pt')['pixel_values']
    image = image.half().to(device)

    visual_wrapper = torch.nn.Sequential(model.get_vision_tower(),
                                         model.get_model().mm_projector)
    export_visual_wrapper_onnx(visual_wrapper, image, args.output_dir)
    build_trt_engine(0, image.shape[2], image.shape[3], args.output_dir)
if __name__ == '__main__':
    # Module-level globals: `logger` is used by the build helpers above, and
    # `device` is kept for any code that reads it as a global.
    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    logger = trt.Logger(trt.Logger.ERROR)

    parser = argparse.ArgumentParser()
    # Required: the output path and the engine dispatch both depend on it.
    parser.add_argument('--model_name',
                        type=str,
                        required=True,
                        help="Model name")
    parser.add_argument('--model_path',
                        type=str,
                        default=None,
                        help="Huggingface repo or local directory with weights")
    parser.add_argument('--output_dir',
                        type=str,
                        default='visual_engines',
                        help="Directory where visual TRT engines are saved")
    args = parser.parse_args()

    # Pick the builder before creating directories so an invalid model name
    # does not leave an empty output directory behind.
    if args.model_name in ['opt-2.7b', 'flan-t5-xl']:
        build_engine = build_blip_engine
    elif 'llava' in args.model_name:
        if args.model_path is None:
            parser.error("--model_path is required for llava models")
        build_engine = build_llava_engine
    else:
        raise RuntimeError(f"Invalid model name {args.model_name}")

    args.output_dir = os.path.join(args.output_dir, args.model_name)
    os.makedirs(args.output_dir, exist_ok=True)
    build_engine(args)