mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* blossom-ci.yml: run vulnerability scan on blossom * open source efb18c1256f8c9c3d47b7d0c740b83e5d5ebe0ec --------- Co-authored-by: niukuo <6831097+niukuo@users.noreply.github.com> Co-authored-by: pei0033 <59505847+pei0033@users.noreply.github.com> Co-authored-by: Kyungmin Lee <30465912+lkm2835@users.noreply.github.com> Co-authored-by: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
1998 lines
93 KiB
Python
import json
import os
import sys
from io import BytesIO

import requests

# isort: off
import torch
import numpy as np
# isort: on
import math
from typing import Optional, Tuple

import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors import safe_open
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM, AutoProcessor,
                          AutoTokenizer)

from .. import profiler
from .._utils import (mpi_rank, str_dtype_to_torch, str_dtype_to_trt,
                      supports_inflight_batching, torch_dtype_to_trt,
                      trt_dtype_to_torch)
from ..functional import RopeEmbeddingUtils, RotaryScalingType
from ..layers import MropeParams
from ..logger import logger
from .enc_dec_model_runner import EncDecModelRunner
from .model_runner import ModelRunner
from .session import Session, TensorInfo

try:
    import tensorrt_llm.bindings  # NOQA
    PYTHON_BINDINGS = True
except ImportError:
    PYTHON_BINDINGS = False

if PYTHON_BINDINGS:
    from .model_runner_cpp import ModelRunnerCpp

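# The helper classes below post-process vision-encoder outputs for LLaVA-NeXT and
# LLaVA-OneVision, and MultimodalModelRunner ties a TensorRT vision engine to a
# decoder-only or encoder-decoder TensorRT-LLM engine for end-to-end generation.
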
class LlavaNextUtils:
|
||
# https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
|
||
|
||
@staticmethod
|
||
def select_best_resolution(original_size, possible_resolutions):
|
||
"""
|
||
Selects the best resolution from a list of possible resolutions based on the original size.
|
||
|
||
Args:
|
||
original_size (tuple): The original size of the image in the format (width, height).
|
||
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
||
|
||
Returns:
|
||
tuple: The best fit resolution in the format (width, height).
|
||
"""
|
||
original_width, original_height = original_size
|
||
best_fit = None
|
||
max_effective_resolution = 0
|
||
min_wasted_resolution = float('inf')
|
||
|
||
for width, height in possible_resolutions:
|
||
scale = min(width / original_width, height / original_height)
|
||
downscaled_width, downscaled_height = int(
|
||
original_width * scale), int(original_height * scale)
|
||
effective_resolution = min(downscaled_width * downscaled_height,
|
||
original_width * original_height)
|
||
wasted_resolution = (width * height) - effective_resolution
|
||
|
||
if effective_resolution > max_effective_resolution or (
|
||
effective_resolution == max_effective_resolution
|
||
and wasted_resolution < min_wasted_resolution):
|
||
max_effective_resolution = effective_resolution
|
||
min_wasted_resolution = wasted_resolution
|
||
best_fit = (width, height)
|
||
|
||
return best_fit
|
||
|
||
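    # Example (illustrative) for select_best_resolution above: an 800x600 image
    # checked against the default LLaVA-NeXT pinpoints [[336, 672], [672, 336],
    # [672, 672], [1008, 336], [336, 1008]] selects (672, 672), the candidate that
    # keeps the most effective resolution with the least wasted area.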
@staticmethod
|
||
def get_anyres_image_grid_shape(image_size,
|
||
patch_size,
|
||
image_grid_pinpoints=None):
|
||
"""
|
||
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
||
|
||
Args:
|
||
image_size (tuple): The size of the input image in the format (width, height).
|
||
patch_size (int): The size of each image patch.
|
||
|
||
Returns:
|
||
tuple: The shape of the image patch grid in the format (width, height).
|
||
"""
|
||
if image_grid_pinpoints is None:
|
||
image_grid_pinpoints = [[336, 672], [672, 336], [672, 672],
|
||
[1008, 336], [336, 1008]]
|
||
width, height = LlavaNextUtils.select_best_resolution(
|
||
image_size, image_grid_pinpoints)
|
||
return width // patch_size, height // patch_size
|
||
|
||
@staticmethod
|
||
def unpad_image(tensor, original_size):
|
||
"""
|
||
Unpads a PyTorch tensor of a padded and resized image.
|
||
|
||
Args:
|
||
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
|
||
original_size (tuple): The original size of the image (width, height).
|
||
|
||
Returns:
|
||
torch.Tensor: The unpadded image tensor.
|
||
"""
|
||
original_width, original_height = original_size
|
||
current_height, current_width = tensor.shape[1:]
|
||
|
||
original_aspect_ratio = original_width / original_height
|
||
current_aspect_ratio = current_width / current_height
|
||
|
||
if original_aspect_ratio > current_aspect_ratio:
|
||
scale_factor = current_width / original_width
|
||
new_height = int(original_height * scale_factor)
|
||
padding = (current_height - new_height) // 2
|
||
unpadded_tensor = tensor[:, padding:current_height - padding, :]
|
||
else:
|
||
scale_factor = current_height / original_height
|
||
new_width = int(original_width * scale_factor)
|
||
padding = (current_width - new_width) // 2
|
||
unpadded_tensor = tensor[:, :, padding:current_width - padding]
|
||
|
||
return unpadded_tensor
|
||
|
||
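    # rearrange_image_features: the first grid is the base (global) view; the
    # remaining grids are high-resolution tiles that are re-assembled into a 2D
    # feature map, unpadded back to the original aspect ratio, suffixed with an
    # `image_newline` embedding per row, and flattened into a token sequence.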
@staticmethod
|
||
def rearrange_image_features(image_feature, image_newline, image_size):
|
||
"""
|
||
Combine PyTorch feature grids from image patches.
|
||
|
||
Args:
|
||
image_feature (torch.Tensor): The feature grids, assumed to be in NxCxHxW format.
|
||
image_newline (torch.Tensor): The newline embedding.
|
||
image_size (tuple): Size of the original image (width, height).
|
||
"""
|
||
CLIP_IMAGE_SIZE = 336
|
||
CLIP_PATCH_SIZE = 14
|
||
NUM_PATCHES_PER_SIDE = CLIP_IMAGE_SIZE // CLIP_PATCH_SIZE
|
||
if image_feature.shape[0] == 1:
|
||
return torch.cat((image_feature, image_newline[None]), dim=0)
|
||
|
||
base_image_feature = image_feature[0]
|
||
image_feature = image_feature[1:]
|
||
height = width = NUM_PATCHES_PER_SIDE
|
||
assert height * width == base_image_feature.shape[0]
|
||
|
||
num_patch_width, num_patch_height = LlavaNextUtils.get_anyres_image_grid_shape(
|
||
image_size, CLIP_IMAGE_SIZE)
|
||
image_feature = image_feature.view(num_patch_height, num_patch_width,
|
||
height, width, -1)
|
||
|
||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||
image_feature = LlavaNextUtils.unpad_image(image_feature, image_size)
|
||
image_feature = torch.cat(
|
||
(image_feature, image_newline[:, None, None].expand(
|
||
*image_feature.shape[:-1], 1)),
|
||
dim=-1)
|
||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
||
return image_feature
|
||
|
||
|
||
class LlavaOnevisionUtils:
|
||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_onevision/modeling_llava_onevision.py
|
||
|
||
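    # pack_image_features mirrors the HF LLaVA-OneVision implementation: tiles are
    # laid out on the any-resolution grid, unpadded, downscaled bilinearly when the
    # area exceeds roughly MAX_NUM_PATCHES base tiles, and suffixed with the
    # `image_newline` embedding before all images are stacked together.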
@staticmethod
|
||
def pack_image_features(image_features, image_sizes, image_newline):
|
||
"""
|
||
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
|
||
|
||
Args:
|
||
image_features (`torch.Tensor` of shape `(num_images, num_patches, image_length, embed_dim)`)
|
||
Image feature tensor, each contains all the visual feature of all patches.
|
||
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
|
||
Actual image size of each images (H, W).
|
||
image_newline (`torch.Tensor` of shape `(embed_dim)`)
|
||
New line embedding vector.
|
||
Returns:
|
||
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
|
||
"""
|
||
|
||
IMAGE_SIZE = 384
|
||
PATCH_SIZE = 14
|
||
MAX_NUM_PATCHES = 9
|
||
|
||
new_image_features = []
|
||
for image_idx, image_feature in enumerate(image_features):
|
||
if image_feature.shape[0] > 1:
|
||
base_image_feature = image_feature[0]
|
||
image_feature = image_feature[1:]
|
||
height = width = IMAGE_SIZE // PATCH_SIZE
|
||
if height * width != base_image_feature.shape[0]:
|
||
raise ValueError(
|
||
"The number of patches is not consistent with the image size."
|
||
)
|
||
|
||
IMAGE_GRID_PINPOINTS = [[384, 384], [384, 768], [384, 1152],
|
||
[384, 1536], [384, 1920], [384, 2304],
|
||
[768, 384], [768, 768], [768, 1152],
|
||
[768, 1536], [768, 1920], [768, 2304],
|
||
[1152, 384], [1152, 768], [1152, 1152],
|
||
[1152, 1536],
|
||
[1152, 1920], [1152, 2304], [1536, 384],
|
||
[1536, 768], [1536, 1152], [1536, 1536],
|
||
[1536, 1920], [1536, 2304], [1920, 384],
|
||
[1920, 768], [1920, 1152], [1920, 1536],
|
||
[1920, 1920], [1920, 2304], [2304, 384],
|
||
[2304, 768], [2304, 1152], [2304, 1536],
|
||
[2304, 1920], [2304, 2304]]
|
||
num_patch_width, num_patch_height = LlavaNextUtils.get_anyres_image_grid_shape(
|
||
image_sizes[image_idx][[1, 0]].tolist(), IMAGE_SIZE,
|
||
IMAGE_GRID_PINPOINTS)
|
||
image_feature = image_feature.view(num_patch_height,
|
||
num_patch_width, height,
|
||
width, -1)
|
||
image_feature = image_feature.permute(4, 0, 2, 1,
|
||
3).contiguous()
|
||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||
image_feature = LlavaNextUtils.unpad_image(
|
||
image_feature, image_sizes[image_idx][[1, 0]])
|
||
|
||
channels, curr_height, curr_width = image_feature.shape
|
||
ratio = math.sqrt(curr_height * curr_width /
|
||
(MAX_NUM_PATCHES * height**2))
|
||
if ratio > 1.1:
|
||
image_feature = image_feature[None]
|
||
image_feature = nn.functional.interpolate(
|
||
image_feature,
|
||
[int(curr_height // ratio),
|
||
int(curr_width // ratio)],
|
||
mode="bilinear")[0]
|
||
|
||
image_feature = torch.cat(
|
||
(
|
||
image_feature,
|
||
image_newline[:, None, None].expand(
|
||
*image_feature.shape[:-1], 1).to(
|
||
image_feature.device, image_feature.dtype),
|
||
),
|
||
dim=-1,
|
||
)
|
||
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
||
image_feature = torch.cat((base_image_feature, image_feature),
|
||
dim=0)
|
||
else:
|
||
image_feature = image_feature[0]
|
||
if image_newline is not None:
|
||
image_feature = torch.cat(
|
||
(image_feature, image_newline[None].to(image_feature)),
|
||
dim=0)
|
||
new_image_features.append(image_feature)
|
||
image_features = torch.stack(new_image_features)
|
||
return image_features
|
||
|
||
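    # apply_pooling halves the per-frame spatial feature grid with bilinear
    # interpolation so video inputs contribute fewer visual tokens per frame.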
@staticmethod
|
||
def apply_pooling(image_features):
|
||
IMAGE_SIZE = 384
|
||
PATCH_SIZE = 14
|
||
height = width = IMAGE_SIZE // PATCH_SIZE
|
||
batch_frames, seq_len, dim = image_features.shape
|
||
image_features = image_features.view(batch_frames, height, width, -1)
|
||
image_features = image_features.permute(0, 3, 1, 2).contiguous()
|
||
|
||
height, width = image_features.shape[2:]
|
||
scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]
|
||
image_features = nn.functional.interpolate(image_features,
|
||
size=scaled_shape,
|
||
mode="bilinear")
|
||
|
||
image_features = image_features.permute(0, 2, 3, 1)
|
||
image_features = image_features.view(batch_frames, -1, dim)
|
||
return image_features
|
||
|
||
|
||
class MultimodalModelRunner:
|
||
|
||
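    # Usage sketch (attribute names assumed from how self.args is used below):
    #   args carries hf_model_dir, visual_engine_dir, llm_engine_dir, batch_size,
    #   sampling options, etc.
    #   runner = MultimodalModelRunner(args)
    #   text = runner.generate(pre_prompt, post_prompt, image,
    #                          decoder_input_ids=None, max_new_tokens=64)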
def __init__(self, args):
|
||
self.args = args
|
||
|
||
self.runtime_rank = mpi_rank()
|
||
device_id = self.runtime_rank % torch.cuda.device_count()
|
||
torch.cuda.set_device(device_id)
|
||
self.device = "cuda:%d" % (device_id)
|
||
|
||
self.stream = torch.cuda.Stream(torch.cuda.current_device())
|
||
torch.cuda.set_stream(self.stream)
|
||
|
||
# parse model type from visual engine config
|
||
with open(os.path.join(self.args.visual_engine_dir, "config.json"),
|
||
"r") as f:
|
||
config = json.load(f)
|
||
self.model_type = config['builder_config']['model_type']
|
||
self.vision_precision = config['builder_config']['precision']
|
||
if self.model_type == 'pix2struct':
|
||
self.vision_precision = 'float16'
|
||
self.decoder_llm = not (
|
||
't5' in self.model_type
|
||
or self.model_type in ['nougat', 'pix2struct']
|
||
        )  # BLIP2-T5, pix2struct and Nougat use encoder-decoder models as LLMs
|
||
|
||
if self.model_type == 'video-neva':
|
||
self.num_frames = config['builder_config'].get('num_frames', None)
|
||
if self.model_type == "llava_next":
|
||
self.llm_name = AutoConfig.from_pretrained(
|
||
self.args.hf_model_dir).text_config._name_or_path
|
||
if self.model_type == "qwen2_vl":
|
||
hf_config = AutoConfig.from_pretrained(self.args.hf_model_dir)
|
||
self.vision_start_token_id = hf_config.vision_start_token_id
|
||
self.vision_end_token_id = hf_config.vision_end_token_id
|
||
self.vision_token_id = hf_config.vision_token_id
|
||
self.image_token_id = hf_config.image_token_id
|
||
self.video_token_id = hf_config.video_token_id
|
||
self.spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||
self.max_position_embeddings = hf_config.max_position_embeddings
|
||
self.hidden_size = hf_config.hidden_size
|
||
self.num_attention_heads = hf_config.num_attention_heads
|
||
self.rope_theta = hf_config.rope_theta
|
||
if self.model_type == 'llava_onevision':
|
||
self.num_frames = self.args.video_num_frames
|
||
if self.num_frames is None:
|
||
self.num_frames = 8
|
||
assert self.args.video_path is None or self.args.image_path is None
|
||
|
||
if self.model_type == "mllama":
|
||
self.vision_input_names = [
|
||
"pixel_values",
|
||
"aspect_ratio_ids",
|
||
"aspect_ratio_mask",
|
||
]
|
||
self.vision_output_names = [
|
||
"output",
|
||
]
|
||
else:
|
||
self.vision_input_names = ["input"]
|
||
self.vision_output_names = ["output"]
|
||
|
||
if self.decoder_llm:
|
||
if not supports_inflight_batching(self.args.llm_engine_dir):
|
||
                logger.warning(
                    "The given engine does not support in-flight batching, falling back to the Python session."
                )
|
||
self.args.use_py_session = True
|
||
|
||
if not PYTHON_BINDINGS and not self.args.use_py_session:
|
||
                logger.warning(
                    "Python bindings for the C++ session are unavailable, falling back to the Python session."
                )
|
||
self.args.use_py_session = True
|
||
|
||
args.debug_mode = False
|
||
if args.debug_mode and not self.args.use_py_session:
|
||
                logger.warning(
                    "Debug mode is not supported in the C++ session for now, falling back to the Python session."
                )
|
||
self.args.use_py_session = True
|
||
|
||
self.use_py_session = self.args.use_py_session
|
||
if self.model_type == 'qwen2_vl':
|
||
if self.args.use_py_session:
|
||
                    logger.warning(
                        "Qwen2-VL only supports the C++ session for now, falling back to the C++ session."
                    )
|
||
self.args.use_py_session = False
|
||
|
||
else:
|
||
self.use_py_session = True
|
||
|
||
self.init_image_encoder()
|
||
self.init_tokenizer()
|
||
self.init_processor()
|
||
self.init_llm()
|
||
|
||
def init_tokenizer(self):
|
||
if self.model_type == 'nougat':
|
||
from transformers import NougatTokenizerFast
|
||
self.tokenizer = NougatTokenizerFast.from_pretrained(
|
||
self.args.hf_model_dir)
|
||
elif self.model_type == 'neva' or self.model_type == 'video-neva':
|
||
from sentencepiece import SentencePieceProcessor
|
||
|
||
sp = SentencePieceProcessor(
|
||
os.path.join(self.args.hf_model_dir, 'tokenizer.model'))
|
||
|
||
class return_obj:
|
||
|
||
def __init__(self, input_ids):
|
||
self.input_ids = input_ids
|
||
|
||
                def __getitem__(self, name):
                    if name == "input_ids":
                        return self.input_ids
                    else:
                        raise AttributeError(
                            f"'return_obj' has no item '{name}'")
|
||
|
||
# sentencepiece does not follow the same interface as HF
|
||
class HFTokenizerInterface():
|
||
|
||
def encode(self, x, return_tensors=None, **kwargs):
|
||
out = sp.encode(x)
|
||
if return_tensors == "pt":
|
||
out = torch.tensor(out)
|
||
return return_obj(out)
|
||
|
||
def __call__(self, x, return_tensors=None, **kwargs):
|
||
return self.encode(x, return_tensors, **kwargs)
|
||
|
||
def decode(self, x, **kwargs):
|
||
return sp.decode(x.tolist())
|
||
|
||
def batch_decode(self, x, **kwargs):
|
||
return self.decode(x, **kwargs)
|
||
|
||
self.tokenizer = HFTokenizerInterface()
|
||
self.tokenizer.eos_token_id = sp.eos_id()
|
||
self.tokenizer.bos_token_id = sp.bos_id()
|
||
self.tokenizer.pad_token_id = sp.pad_id()
|
||
elif self.model_type == 'vila':
|
||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||
self.args.hf_model_dir + "/llm",
|
||
use_fast=False,
|
||
use_legacy=False)
|
||
else:
|
||
use_fast = self.model_type in ["phi-3-vision", "internvl"]
|
||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||
self.args.hf_model_dir, use_fast=use_fast, use_legacy=False)
|
||
|
||
self.tokenizer.padding_side = "right"
|
||
|
||
def init_processor(self):
|
||
from torchvision import transforms
|
||
|
||
if 'blip2' in self.model_type:
|
||
from transformers import Blip2Processor
|
||
self.processor = Blip2Processor.from_pretrained(
|
||
self.args.hf_model_dir)
|
||
|
||
elif 'nougat' in self.model_type:
|
||
from transformers import NougatProcessor
|
||
self.processor = NougatProcessor.from_pretrained(
|
||
self.args.hf_model_dir)
|
||
|
||
elif 'cogvlm' in self.model_type:
|
||
image_size = 490
|
||
self.transform = transforms.Compose([
|
||
transforms.Resize(
|
||
(image_size, image_size),
|
||
interpolation=transforms.InterpolationMode.BICUBIC),
|
||
transforms.ToTensor(),
|
||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
|
||
(0.26862954, 0.26130258, 0.27577711)),
|
||
transforms.ConvertImageDtype(torch.bfloat16),
|
||
])
|
||
|
||
elif 'phi-3-vision' in self.model_type:
|
||
self.processor = AutoProcessor.from_pretrained(
|
||
self.args.hf_model_dir, trust_remote_code=True)
|
||
|
||
elif 'internvl' in self.model_type:
|
||
from transformers import CLIPImageProcessor
|
||
self.processor = CLIPImageProcessor.from_pretrained(
|
||
'OpenGVLab/InternViT-300M-448px'
|
||
) # You can change the InternViT model type according to your InternVL type
|
||
|
||
elif self.model_type == "pix2struct":
|
||
self.processor = AutoProcessor.from_pretrained(
|
||
self.args.hf_model_dir)
|
||
|
||
elif self.model_type == "neva":
|
||
image_size = 384
|
||
self.transform = transforms.Compose([
|
||
transforms.Resize(
|
||
(image_size, image_size),
|
||
interpolation=transforms.InterpolationMode.BICUBIC),
|
||
transforms.ToTensor(),
|
||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||
transforms.ConvertImageDtype(torch.float32),
|
||
])
|
||
|
||
elif self.model_type == "video-neva":
|
||
pass
|
||
|
||
elif self.model_type == "llava_next":
|
||
self.processor = AutoProcessor.from_pretrained(
|
||
self.args.hf_model_dir, trust_remote_code=True)
|
||
|
||
elif self.model_type in ['llava', 'vila', 'fuyu', 'kosmos-2']:
|
||
if self.model_type == "vila":
|
||
sys.path.append(self.args.hf_model_dir + "/../VILA")
|
||
from llava.mm_utils import process_images
|
||
from llava.model import LlavaLlamaConfig # noqa
|
||
from transformers import AutoModel
|
||
model = AutoModel.from_pretrained(
|
||
self.args.hf_model_dir,
|
||
device_map='auto',
|
||
trust_remote_code=True,
|
||
)
|
||
vision_tower = model.get_vision_tower()
|
||
vision_tower.image_processor
|
||
|
||
def processor(raw_image):
|
||
return process_images(raw_image,
|
||
vision_tower.image_processor,
|
||
model.config).to(model.device,
|
||
dtype=torch.float16)
|
||
|
||
self.processor = processor
|
||
|
||
else:
|
||
self.processor = AutoProcessor.from_pretrained(
|
||
self.args.hf_model_dir)
|
||
|
||
elif self.model_type in ['mllama']:
|
||
self.processor = AutoProcessor.from_pretrained(
|
||
self.args.hf_model_dir)
|
||
|
||
def init_image_encoder(self):
|
||
if self.model_type == "phi-3-vision":
|
||
model = AutoModelForCausalLM.from_pretrained(
|
||
self.args.hf_model_dir,
|
||
torch_dtype=torch.float16,
|
||
trust_remote_code=True,
|
||
device_map='cpu')
|
||
self.vision_model = model.model.vision_embed_tokens.to(
|
||
self.device).eval()
|
||
|
||
# Test run vision_model.get_img_features to pre-allocate memory for flash attention
|
||
processor = AutoProcessor.from_pretrained(self.args.hf_model_dir,
|
||
trust_remote_code=True)
|
||
image = processor(text="<|image_1|>",
|
||
images=Image.new('RGB', [10, 10]),
|
||
return_tensors="pt")['pixel_values']
|
||
image = image.flatten(0, 1)
|
||
image = torch.rand(image.shape,
|
||
dtype=str_dtype_to_torch(self.vision_precision),
|
||
device=self.device)
|
||
self.vision_model.get_img_features(image)
|
||
return
|
||
|
||
vision_encoder_path = os.path.join(self.args.visual_engine_dir,
|
||
self.args.visual_engine_name)
|
||
logger.info(f'Loading engine from {vision_encoder_path}')
|
||
with open(vision_encoder_path, 'rb') as f:
|
||
engine_buffer = f.read()
|
||
logger.info(f'Creating session from engine {vision_encoder_path}')
|
||
self.visual_encoder_session = Session.from_serialized_engine(
|
||
engine_buffer)
|
||
if self.model_type in ["llava_next", "llava_onevision"]:
|
||
self.image_newlines = {}
|
||
image_newlines_path = os.path.join(self.args.visual_engine_dir,
|
||
'image_newlines.safetensors')
|
||
with safe_open(image_newlines_path,
|
||
framework="pt",
|
||
device=self.device) as f:
|
||
for k in f.keys():
|
||
self.image_newlines[k] = f.get_tensor(k)
|
||
|
||
def init_llm(self):
|
||
if self.decoder_llm:
|
||
cross_kv_cache_fraction = None
|
||
if self.model_type == 'mllama':
|
||
cross_kv_cache_fraction = self.args.cross_kv_cache_fraction
|
||
if self.use_py_session:
|
||
logger.info(f'Running LLM with Python runner')
|
||
self.model = ModelRunner.from_dir(
|
||
self.args.llm_engine_dir,
|
||
rank=tensorrt_llm.mpi_rank(),
|
||
debug_mode=False,
|
||
stream=self.stream,
|
||
enable_context_fmha_fp32_acc=self.args.
|
||
enable_context_fmha_fp32_acc)
|
||
self.model_config = self.model.session._model_config
|
||
else:
|
||
logger.info(f'Running LLM with C++ runner')
|
||
self.model = ModelRunnerCpp.from_dir(
|
||
self.args.llm_engine_dir,
|
||
rank=tensorrt_llm.mpi_rank(),
|
||
debug_mode=False,
|
||
enable_chunked_context=self.args.enable_chunked_context,
|
||
enable_context_fmha_fp32_acc=self.args.
|
||
enable_context_fmha_fp32_acc,
|
||
kv_cache_free_gpu_memory_fraction=self.args.
|
||
kv_cache_free_gpu_memory_fraction,
|
||
cross_kv_cache_fraction=cross_kv_cache_fraction)
|
||
self.model_config = self.model.model_config
|
||
self.runtime_mapping = self.model.mapping
|
||
else:
|
||
self.model = EncDecModelRunner.from_engine(
|
||
os.path.basename(self.args.hf_model_dir),
|
||
self.args.llm_engine_dir,
|
||
skip_encoder=self.model_type in ['nougat', 'pix2struct'],
|
||
debug_mode=False,
|
||
stream=self.stream,
|
||
enable_context_fmha_fp32_acc=self.args.
|
||
enable_context_fmha_fp32_acc)
|
||
if self.model_type in ['nougat', 'pix2struct']:
|
||
self.model_config = self.model.decoder_model_config
|
||
self.runtime_mapping = self.model.decoder_runtime_mapping
|
||
else:
|
||
self.model_config = self.model.encoder_model_config
|
||
self.runtime_mapping = self.model.encoder_runtime_mapping
|
||
|
||
def video_preprocess(self, video_path):
|
||
from decord import VideoReader
|
||
if isinstance(video_path, str):
|
||
vr = VideoReader(video_path)
|
||
num_frames = self.num_frames
|
||
if num_frames == -1:
|
||
frames = [
|
||
Image.fromarray(frame.asnumpy()[:, :, ::-1]).convert('RGB')
|
||
for frame in vr
|
||
]
|
||
else:
|
||
                # slice the video into self.num_frames equally spaced frames;
                # if self.num_frames is greater than the number of frames in the video, we will repeat the last frame
|
||
num_frames = min(num_frames, len(vr))
|
||
indices = np.linspace(0, len(vr) - 1, num=num_frames, dtype=int)
|
||
frames = [
|
||
Image.fromarray(
|
||
vr[idx].asnumpy()[:, :, ::-1]).convert('RGB')
|
||
for idx in indices
|
||
]
|
||
if len(frames) < num_frames:
|
||
frames += [frames[-1]] * (num_frames - len(frames))
|
||
else:
|
||
frames = self.video_path
|
||
|
||
from transformers import CLIPImageProcessor
|
||
processor = CLIPImageProcessor.from_pretrained(
|
||
"openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16)
|
||
frames = processor.preprocess(frames,
|
||
return_tensors="pt")['pixel_values']
|
||
# make dtype consistent with vision encoder
|
||
media_tensors = frames.to(str_dtype_to_torch(
|
||
self.vision_precision)) # [num_frames, 3, H, W]
|
||
return media_tensors.unsqueeze(0) #[1, num_frames, 3, H, W]
|
||
|
||
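    # preprocess converts the model-specific processor outputs into input_ids,
    # input_lengths, prompt-tuning args and visual features; each branch below
    # mirrors how the corresponding HF model lays out its image placeholder tokens.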
def preprocess(self, warmup, pre_prompt, post_prompt, image,
|
||
other_vision_inputs):
|
||
if self.model_type == 'kosmos-2':
|
||
input_ids = image['input_ids'].clone()
|
||
image_mask = image["image_embeds_position_mask"]
|
||
image = image['pixel_values']
|
||
input_ids += image_mask * (self.model_config.vocab_size - 4)
|
||
input_ids = input_ids.expand(self.args.batch_size,
|
||
*input_ids.shape[1:])
|
||
length = input_ids.shape[1]
|
||
elif self.model_type == 'phi-3-vision':
|
||
input = image
|
||
image = input['pixel_values']
|
||
image = image.flatten(0, 1)
|
||
elif self.model_type == 'llava_next':
|
||
input = image
|
||
image = input['pixel_values']
|
||
image = image[0]
|
||
image_size = input['image_sizes'][0].cpu()
|
||
elif self.model_type == "qwen2_vl":
|
||
input = image
|
||
image = input['image']
|
||
input_ids = input['input_ids']
|
||
other_vision_inputs['image_grid_thw'].shape[0]
|
||
attention_mask = other_vision_inputs['attention_mask_llm']
|
||
other_vision_inputs.pop('attention_mask_llm')
|
||
image_grid_thw = other_vision_inputs['image_grid_thw']
|
||
other_vision_inputs.pop('image_grid_thw')
|
||
elif self.model_type == 'llava_onevision':
|
||
input = image
|
||
if self.args.video_path is None:
|
||
image = input['pixel_values']
|
||
image = image[0].repeat(self.args.batch_size, 1, 1, 1)
|
||
image_size = input['image_sizes'][0]
|
||
image_size = image_size.repeat(self.args.batch_size, 1).cpu()
|
||
else:
|
||
image = input['pixel_values_videos']
|
||
_, _, c, h, w = image.shape
|
||
image = image.repeat(self.args.batch_size, 1, 1, 1, 1)
|
||
image = image.view(-1, c, h, w)
|
||
|
||
if not warmup:
|
||
profiler.start("Vision")
|
||
|
||
if image is not None:
|
||
if self.model_type == "phi-3-vision":
|
||
visual_features = self.vision_model.get_img_features(
|
||
image).reshape(1, image.shape[0], -1,
|
||
self.vision_model.image_dim_out)
|
||
visual_atts = None
|
||
else:
|
||
visual_features, visual_atts = self.get_visual_features(
|
||
torch.stack(image['image_patches'], dim=0) if
|
||
self.model_type == 'fuyu' else image, other_vision_inputs)
|
||
else:
|
||
visual_features, visual_atts = None, None
|
||
|
||
if not warmup:
|
||
profiler.stop("Vision")
|
||
|
||
if self.model_type == 'fuyu':
|
||
visual_features = visual_features.squeeze()
|
||
input_ids = image['input_ids'].to(torch.int32)
|
||
image_patches_indices = image['image_patches_indices'].to(
|
||
torch.int32)
|
||
|
||
input_ids = input_ids.expand(self.args.batch_size,
|
||
*input_ids.shape[1:])
|
||
image_patches_indices = image_patches_indices.expand(
|
||
self.args.batch_size, *image_patches_indices.shape[1:])
|
||
|
||
input_ids = self.ptuning_setup_fuyu(input_ids,
|
||
image_patches_indices)
|
||
input_ids = torch.stack(input_ids, dim=0).to('cpu')
|
||
length = input_ids.shape[1]
|
||
elif self.model_type == 'qwen2_vl':
|
||
length = input_ids.shape[1]
|
||
input_lengths = torch.IntTensor([length] * self.args.batch_size).to(
|
||
torch.int32)
|
||
input_ids, ptuning_args, mrope_args = self.setup_fake_prompts_qwen2vl(
|
||
visual_features, input_ids, image_grid_thw, attention_mask,
|
||
input_lengths)
|
||
return input_ids, input_lengths, ptuning_args, visual_features, mrope_args
|
||
|
||
elif self.model_type == 'kosmos-2':
|
||
visual_features = visual_features.squeeze()
|
||
elif self.model_type == 'vila':
|
||
input_ids = self.tokenizer_image_token(
|
||
self.args.batch_size, pre_prompt[0] + post_prompt[0],
|
||
self.tokenizer)
|
||
batch_split_prompts = self.split_prompt_by_images(input_ids)
|
||
first_batch_split_prompts = batch_split_prompts[0]
|
||
# compute prompt length + visual length
|
||
length = sum([ids.shape[1] for ids in first_batch_split_prompts])
|
||
if self.args.batch_size == 1 and len(image) > 1:
|
||
# mode 1: multiple image as a whole, flatten visual dims
|
||
length += visual_atts.shape[0] * visual_atts.shape[1]
|
||
else:
|
||
# mode 2: multiple images individually (replicate prompt for each image)
|
||
length += visual_atts.shape[1]
|
||
|
||
input_lengths = torch.IntTensor([length] * self.args.batch_size).to(
|
||
torch.int32)
|
||
input_ids, ptuning_args = self.setup_fake_prompts_vila(
|
||
self.args.batch_size, visual_features,
|
||
first_batch_split_prompts, input_lengths)
|
||
return input_ids, input_lengths, ptuning_args, visual_features
|
||
elif self.model_type == 'phi-3-vision':
|
||
image_sizes = input["image_sizes"]
|
||
visual_features = self.vision_model.hd_feature_transform(
|
||
visual_features, image_sizes)
|
||
input_ids = input["input_ids"].clone()
|
||
input_ids = input_ids.expand(self.args.batch_size,
|
||
*input_ids.shape[1:])
|
||
num_img_tokens = [visual_features.shape[0]]
|
||
input_ids = self.ptuning_setup_phi3(visual_features, input_ids,
|
||
num_img_tokens)
|
||
visual_features = visual_features.unsqueeze(0).repeat(
|
||
self.args.batch_size, 1, 1)
|
||
length = input_ids.shape[1]
|
||
elif self.model_type == 'llava_next':
|
||
visual_features = LlavaNextUtils.rearrange_image_features(
|
||
visual_features, self.image_newlines["image_newline"],
|
||
image_size)
|
||
input_ids = self.ptuning_setup_llava_next(visual_features,
|
||
pre_prompt, post_prompt)
|
||
length = input_ids.shape[1]
|
||
elif self.model_type == 'mllama':
|
||
pre_input_ids = self.tokenizer(pre_prompt,
|
||
return_tensors="pt",
|
||
padding=True).input_ids
|
||
length = pre_input_ids.shape[1]
|
||
post_input_ids = None
|
||
elif self.model_type == 'llava_onevision':
|
||
if self.args.video_path is None:
|
||
visual_features = torch.split(visual_features,
|
||
visual_features.shape[0] //
|
||
self.args.batch_size,
|
||
dim=0)
|
||
visual_features = LlavaOnevisionUtils.pack_image_features(
|
||
visual_features,
|
||
image_size,
|
||
image_newline=self.image_newlines["image_newline"],
|
||
)
|
||
else:
|
||
visual_features = LlavaOnevisionUtils.apply_pooling(
|
||
visual_features)
|
||
visual_features = visual_features.reshape(
|
||
self.args.batch_size,
|
||
self.num_frames * visual_features.shape[1], -1)
|
||
image_newline = self.image_newlines["image_newline"][
|
||
None, None, :].repeat(self.args.batch_size, 1,
|
||
1).to(visual_features.device)
|
||
visual_features = torch.cat((visual_features, image_newline),
|
||
dim=1)
|
||
|
||
pre_input_ids = self.tokenizer(pre_prompt,
|
||
return_tensors="pt",
|
||
padding=True).input_ids
|
||
post_input_ids = self.tokenizer(post_prompt,
|
||
return_tensors="pt",
|
||
padding=True).input_ids
|
||
length = pre_input_ids.shape[1] + visual_features.shape[
|
||
1] + post_input_ids.shape[1]
|
||
else:
|
||
pre_input_ids = self.tokenizer(pre_prompt,
|
||
return_tensors="pt",
|
||
padding=True).input_ids
|
||
if post_prompt[0] is not None:
|
||
post_input_ids = self.tokenizer(post_prompt,
|
||
return_tensors="pt",
|
||
padding=True).input_ids
|
||
if self.model_type == 'video-neva':
|
||
length = pre_input_ids.shape[1] + post_input_ids.shape[
|
||
1] + visual_atts.shape[2] * visual_atts.shape[1]
|
||
elif self.model_type == 'internvl':
|
||
length = pre_input_ids.shape[1] + post_input_ids.shape[
|
||
1] + visual_atts.shape[0] * visual_atts.shape[1]
|
||
else:
|
||
length = pre_input_ids.shape[1] + post_input_ids.shape[
|
||
1] + visual_atts.shape[1]
|
||
else:
|
||
post_input_ids = None
|
||
length = pre_input_ids.shape[1] + visual_atts.shape[1]
|
||
|
||
input_lengths = torch.IntTensor([length] * self.args.batch_size).to(
|
||
torch.int32)
|
||
|
||
if self.model_type in [
|
||
'fuyu', 'kosmos-2', 'phi-3-vision', 'llava_next'
|
||
]:
|
||
return input_ids, input_lengths, [visual_features], visual_features
|
||
|
||
input_ids, ptuning_args = self.setup_fake_prompts(
|
||
visual_features, pre_input_ids, post_input_ids, input_lengths)
|
||
|
||
return input_ids, input_lengths, ptuning_args, visual_features
|
||
|
||
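    # tokenizer_image_token handles VILA-style prompts containing "<image>"
    # placeholders: each placeholder becomes `image_token_index` (-200 by default)
    # and is then zeroed so split_prompt_by_images can locate the image positions.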
@staticmethod
|
||
def tokenizer_image_token(batch_size,
|
||
prompt,
|
||
tokenizer,
|
||
image_token_index=-200):
|
||
prompt_chunks = [
|
||
tokenizer(chunk).input_ids for chunk in prompt.split("<image>")
|
||
]
|
||
|
||
def insert_separator(X, sep):
|
||
return [
|
||
ele for sublist in zip(X, [sep] * len(X)) for ele in sublist
|
||
][:-1]
|
||
|
||
input_ids = []
|
||
offset = 0
|
||
if (len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0
|
||
and prompt_chunks[0][0] == tokenizer.bos_token_id):
|
||
offset = 1
|
||
input_ids.append(prompt_chunks[0][0])
|
||
|
||
for x in insert_separator(prompt_chunks,
|
||
[image_token_index] * (offset + 1)):
|
||
input_ids.extend(x[offset:])
|
||
|
||
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
||
input_ids[input_ids == image_token_index] = 0
|
||
input_ids = input_ids.unsqueeze(0).expand(batch_size, -1)
|
||
return input_ids
|
||
|
||
def split_prompt_by_images(self, tensor):
|
||
batch_splits = []
|
||
for batch in tensor:
|
||
# Find indices where value is zero (<image>)
|
||
zero_indices = (batch == 0).nonzero(as_tuple=False).squeeze(0)
|
||
# Add starting point for slicing
|
||
start_idx = 0
|
||
splits = []
|
||
for idx in zero_indices:
|
||
if start_idx != idx: # Ensure not slicing zero-length tensors
|
||
splits.append(batch[start_idx:idx].unsqueeze(0))
|
||
start_idx = idx + 1 # Move start index past the zero
|
||
if start_idx < len(
|
||
batch): # Handle last segment if it's not zero-ending
|
||
splits.append(batch[start_idx:].unsqueeze(0))
|
||
# Remove empty tensors resulting from consecutive zeros
|
||
splits = [split for split in splits if split.numel() > 0]
|
||
batch_splits.append(splits)
|
||
|
||
return batch_splits
|
||
|
||
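    # CogVLM packs its 1225 vision tokens (a 35x35 patch grid for the 490x490
    # input used in init_processor) into a single position: ids [2:1227] all share
    # position 2, and the text that follows continues from position 3.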
def prepare_position_ids_for_cogvlm(self, input_ids):
|
||
batch_size = len(input_ids)
|
||
position_ids = torch.arange(input_ids.shape[1])
|
||
position_ids[2:1227] = 2
|
||
position_ids[1227:] = torch.arange(3, input_ids.shape[1] + 1 - 1225)
|
||
|
||
position_ids = position_ids.to(torch.int32).to('cuda')
|
||
input_position_ids = []
|
||
for i in range(batch_size):
|
||
input_position_ids.append(position_ids)
|
||
|
||
return input_position_ids
|
||
|
||
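    # generate runs the vision encoder, splices the visual embeddings into the
    # prompt via prompt tuning (or cross attention for mllama / encoder-decoder
    # models), decodes with the TensorRT-LLM runner, and detokenizes on rank 0.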
def generate(self,
|
||
pre_prompt,
|
||
post_prompt,
|
||
image,
|
||
decoder_input_ids,
|
||
max_new_tokens,
|
||
warmup=False,
|
||
other_vision_inputs={},
|
||
other_decoder_inputs={}):
|
||
if not warmup:
|
||
profiler.start("Generate")
|
||
if 'qwen2_vl' in self.model_type:
|
||
input_ids, input_lengths, ptuning_args, visual_features, mrope_args = self.preprocess(
|
||
warmup, pre_prompt, post_prompt, image, other_vision_inputs)
|
||
mrope_params = MropeParams(
|
||
mrope_rotary_sin_cos=mrope_args[0],
|
||
mrope_position_deltas=mrope_args[1],
|
||
)
|
||
else:
|
||
input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
|
||
warmup, pre_prompt, post_prompt, image, other_vision_inputs)
|
||
if warmup: return None
|
||
|
||
# use prompt tuning to pass multimodal features
|
||
# model.generate() expects the following params (see layers/embedding.py):
|
||
# args[0]: prompt embedding table, [batch_size, multimodal_len, hidden_size], later flattened to [batch_size * multimodal_len, hidden_size]
|
||
        # args[1]: prompt task ids, [batch_size]. in the multimodal case, arange(batch_size), i.e. in VILA batching mode 2, each image is treated separately in the batch instead of being concatenated together (although the prompt embedding table has to be concatenated)
        # args[2]: prompt task vocab size, [1]. assuming all tables have the same length, which in the multimodal case equals multimodal_len
|
||
profiler.start("LLM")
|
||
if self.decoder_llm and self.model_type != "mllama":
|
||
end_id = self.tokenizer.eos_token_id
|
||
if 'opt' in self.model_type and 'blip2' in self.model_type:
|
||
# For BLIP2-OPT, model outputs a "\n" at the end.
|
||
# we avoid it by using newline as the end token
|
||
end_id = self.tokenizer.encode("\n",
|
||
add_special_tokens=False)[0]
|
||
|
||
if self.model_type == 'cogvlm':
|
||
input_position_ids = self.prepare_position_ids_for_cogvlm(
|
||
input_ids)
|
||
|
||
batch_size = len(input_ids)
|
||
prompt_tasks = ",".join(
|
||
np.arange(batch_size, dtype=np.int32).astype(str))
|
||
prompt_table = torch.stack([ptuning_args[0]])
|
||
prompt_table = prompt_table.view(batch_size, -1,
|
||
prompt_table.shape[-1])
|
||
|
||
output_ids = self.model.generate(
|
||
input_ids,
|
||
input_position_ids=input_position_ids
|
||
if self.model_type == 'cogvlm' else None,
|
||
mrope_params=mrope_params
|
||
if self.model_type == 'qwen2_vl' else None,
|
||
sampling_config=None,
|
||
prompt_table=prompt_table,
|
||
prompt_tasks=prompt_tasks,
|
||
max_new_tokens=max_new_tokens,
|
||
end_id=end_id,
|
||
pad_id=self.tokenizer.pad_token_id
|
||
if self.tokenizer.pad_token_id is not None else
|
||
self.tokenizer.all_special_ids[0],
|
||
top_k=self.args.top_k,
|
||
top_p=self.args.top_p,
|
||
temperature=self.args.temperature,
|
||
repetition_penalty=self.args.repetition_penalty,
|
||
num_beams=self.args.num_beams,
|
||
output_sequence_lengths=False,
|
||
return_dict=False)
|
||
elif self.model_type == "mllama":
|
||
# When image is passed:
|
||
# the shape of visual_features is [bs, 1, 4, 1025, hidden_size]
|
||
# the shape of cross_attention_mask is [bs, decode_input_len, 1, 4]
|
||
# When image is None, create dummy visual_features and cross_attention_mask
|
||
if visual_features is None:
|
||
visual_features = torch.zeros([
|
||
self.args.batch_size, 1, 4, 1,
|
||
self.model_config.hidden_size * self.runtime_mapping.tp_size
|
||
],
|
||
dtype=self.model.dtype,
|
||
device=self.device)
|
||
dummy_cross_attention_mask = torch.zeros(
|
||
[self.args.batch_size, input_ids.shape[1], 1, 4],
|
||
dtype=bool,
|
||
device=self.device)
|
||
skip_cross_attn_blocks = torch.ones([1],
|
||
dtype=torch.bool,
|
||
device='cpu')
|
||
else:
|
||
skip_cross_attn_blocks = torch.zeros([1],
|
||
dtype=torch.bool,
|
||
device='cpu')
|
||
|
||
visual_features = visual_features.to(self.model.dtype).chunk(
|
||
self.args.batch_size, dim=0)
|
||
encoder_input_features = []
|
||
cross_attention_masks = []
|
||
encoder_output_lengths = []
|
||
for batch_idx in range(self.args.batch_size):
|
||
visual_feature = visual_features[batch_idx]
|
||
num_vision_tokens = visual_feature.shape[3]
|
||
visual_feature = visual_feature.reshape(
|
||
[-1, visual_feature.shape[-1]])
|
||
encoder_max_input_length = visual_feature.shape[0]
|
||
encoder_input_lengths = torch.IntTensor(
|
||
[encoder_max_input_length]).to(visual_feature.device)
|
||
|
||
# prepare cross_attention_mask of context phase
|
||
if 'cross_attention_mask' in other_decoder_inputs:
|
||
cross_attention_mask = other_decoder_inputs[
|
||
'cross_attention_mask'][batch_idx]
|
||
else:
|
||
cross_attention_mask = dummy_cross_attention_mask[batch_idx]
|
||
text_total_length, *_ = cross_attention_mask.shape
|
||
cross_attention_mask = cross_attention_mask.repeat_interleave(
|
||
num_vision_tokens, dim=2)
|
||
cross_attention_mask = cross_attention_mask.view(
|
||
text_total_length, -1)
|
||
cross_attention_mask = cross_attention_mask.unsqueeze(1)
|
||
cross_attention_mask = cross_attention_mask.to(
|
||
visual_feature.device).to(torch.bool).reshape(
|
||
[-1, cross_attention_mask.shape[-1]])
|
||
|
||
# prepare cross_attention_mask for generation phase and concat them
|
||
tmp_mask = [cross_attention_mask] + [
|
||
cross_attention_mask[-1:, :] for _ in range(max_new_tokens)
|
||
]
|
||
cross_attention_mask = torch.concat(tmp_mask)
|
||
|
||
encoder_input_features.append(visual_feature)
|
||
cross_attention_masks.append(cross_attention_mask)
|
||
encoder_output_lengths.append(encoder_input_lengths)
|
||
|
||
outputs = self.model.generate(
|
||
batch_input_ids=input_ids,
|
||
encoder_input_ids=None,
|
||
encoder_input_features=encoder_input_features,
|
||
encoder_output_lengths=encoder_output_lengths,
|
||
cross_attention_masks=cross_attention_masks,
|
||
max_new_tokens=max_new_tokens,
|
||
# max_attention_window_size=args.max_attention_window_size,
|
||
# sink_token_length=args.sink_token_length,
|
||
end_id=self.tokenizer.eos_token_id,
|
||
pad_id=self.tokenizer.pad_token_id,
|
||
temperature=self.args.temperature,
|
||
top_k=self.args.top_k,
|
||
top_p=self.args.top_p,
|
||
num_beams=self.args.num_beams,
|
||
# length_penalty=args.length_penalty,
|
||
# early_stopping=args.early_stopping,
|
||
repetition_penalty=self.args.repetition_penalty,
|
||
# presence_penalty=args.presence_penalty,
|
||
# frequency_penalty=args.frequency_penalty,
|
||
# stop_words_list=stop_words_list,
|
||
# bad_words_list=bad_words_list,
|
||
# output_cum_log_probs=(args.output_cum_log_probs_npy != None),
|
||
# output_log_probs=(args.output_log_probs_npy != None),
|
||
# random_seed=args.random_seed,
|
||
# lora_uids=args.lora_task_uids,
|
||
# prompt_table=args.prompt_table_path,
|
||
# prompt_tasks=args.prompt_tasks,
|
||
# streaming=args.streaming,
|
||
output_sequence_lengths=True,
|
||
# no_repeat_ngram_size=self.args.no_repeat_ngram_size,
|
||
return_dict=True,
|
||
# medusa_choices=args.medusa_choices,
|
||
# return_all_generated_tokens=args.return_all_generated_tokens,
|
||
# input_token_extra_ids=input_token_extra_ids,
|
||
encoder_max_input_length=encoder_max_input_length,
|
||
skip_cross_attn_blocks=skip_cross_attn_blocks,
|
||
)
|
||
if mpi_rank() == 0:
|
||
output_ids = outputs["output_ids"]
|
||
else:
|
||
if self.model_type in ['nougat', 'pix2struct']:
|
||
# Trim encoder input_ids to match visual features shape
|
||
ids_shape = (self.args.batch_size, visual_features.shape[1])
|
||
if self.model_type == 'nougat':
|
||
input_ids = torch.zeros(ids_shape, dtype=torch.int32)
|
||
elif self.model_type == 'pix2struct':
|
||
input_ids = torch.ones(ids_shape, dtype=torch.int32)
|
||
|
||
output_ids = self.model.generate(
|
||
input_ids,
|
||
decoder_input_ids,
|
||
max_new_tokens,
|
||
num_beams=self.args.num_beams,
|
||
bos_token_id=self.tokenizer.bos_token_id,
|
||
pad_token_id=self.tokenizer.pad_token_id,
|
||
eos_token_id=self.tokenizer.eos_token_id,
|
||
debug_mode=False,
|
||
prompt_embedding_table=ptuning_args[0],
|
||
prompt_tasks=ptuning_args[1],
|
||
prompt_vocab_size=ptuning_args[2])
|
||
|
||
# Reset input_lengths to match decoder_input_ids
|
||
input_lengths = torch.ones(input_lengths.shape,
|
||
dtype=input_lengths.dtype)
|
||
profiler.stop("LLM")
|
||
|
||
if mpi_rank() == 0:
|
||
# Extract a list of tensors of shape beam_width x output_ids.
|
||
output_beams_list = [
|
||
self.tokenizer.batch_decode(
|
||
output_ids[batch_idx, :, input_lengths[batch_idx]:],
|
||
skip_special_tokens=True)
|
||
for batch_idx in range(self.args.batch_size)
|
||
]
|
||
|
||
stripped_text = [[
|
||
output_beams_list[batch_idx][beam_idx].strip()
|
||
for beam_idx in range(self.args.num_beams)
|
||
] for batch_idx in range(self.args.batch_size)]
|
||
profiler.stop("Generate")
|
||
return stripped_text
|
||
else:
|
||
profiler.stop("Generate")
|
||
return None
|
||
|
||
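    # get_visual_features runs the serialized vision-encoder TensorRT engine:
    # output shapes are inferred from the inputs, output buffers are allocated on
    # the same device, and an all-ones attention mask is returned with the embeddings.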
def get_visual_features(self, image, other_vision_inputs):
|
||
visual_features = {
|
||
self.vision_input_names[0]:
|
||
image.to(str_dtype_to_torch(self.vision_precision)),
|
||
}
|
||
if self.model_type == "qwen2_vl":
|
||
other_vision_inputs['attention_mask'] = other_vision_inputs[
|
||
'attention_mask'].to(str_dtype_to_torch(self.vision_precision))
|
||
for key, tensor in other_vision_inputs.items():
|
||
visual_features.update({key: tensor})
|
||
|
||
tensor_info = [
|
||
TensorInfo(self.vision_input_names[0],
|
||
str_dtype_to_trt(self.vision_precision), image.shape),
|
||
]
|
||
for key, tensor in other_vision_inputs.items():
|
||
tensor_info.append(
|
||
TensorInfo(key, torch_dtype_to_trt(tensor.dtype), tensor.shape))
|
||
|
||
visual_output_info = self.visual_encoder_session.infer_shapes(
|
||
tensor_info)
|
||
self.visual_encoder_session.set_shapes(visual_features)
|
||
visual_outputs = {
|
||
t.name: torch.empty(tuple(t.shape),
|
||
dtype=trt_dtype_to_torch(t.dtype),
|
||
device=image.device)
|
||
for t in visual_output_info
|
||
}
|
||
|
||
ok = self.visual_encoder_session.run(visual_features, visual_outputs,
|
||
self.stream.cuda_stream)
|
||
assert ok, "Runtime execution failed for vision encoder session"
|
||
self.stream.synchronize()
|
||
|
||
image_embeds = visual_outputs[self.vision_output_names[0]]
|
||
image_atts = torch.ones(image_embeds.size()[:-1],
|
||
dtype=torch.long).to(image.device)
|
||
|
||
return image_embeds, image_atts
|
||
|
||
def setup_fake_prompts_vila(self, batch_size, visual_features,
|
||
split_input_ids, input_lengths):
|
||
# visual_features (num_images, feature_len, token_embed)
|
||
        # Assemble fake prompt ids which actually point to the image embeddings
|
||
fake_prompt_counter = self.model_config.vocab_size
|
||
if batch_size == 1:
|
||
# only check for multi-image inference (mode 1)
|
||
assert len(visual_features) <= len(
|
||
split_input_ids
|
||
), "Unexpected number of visual features. Please check #<image> in prompt and the #image files."
|
||
|
||
input_ids = []
|
||
if batch_size == 1:
|
||
# mode 1: multiple image as a whole, concat all prompts together, <pre><image1><inter><image2>...<post>
|
||
input_ids = [split_input_ids[0]]
|
||
for idx, visual_feature in enumerate(visual_features):
|
||
fake_prompt_id = torch.arange(
|
||
fake_prompt_counter,
|
||
fake_prompt_counter + visual_feature.shape[0])
|
||
fake_prompt_counter += visual_feature.shape[0]
|
||
fake_prompt_id = fake_prompt_id.unsqueeze(0)
|
||
input_ids.append(fake_prompt_id)
|
||
# in case no inter or post prompt
|
||
if len(split_input_ids) > idx + 1:
|
||
input_ids.append(split_input_ids[idx + 1])
|
||
|
||
elif batch_size > 1:
|
||
# mode 2: each image have individual prompt, <pre><image><post>
|
||
for idx, visual_feature in enumerate(visual_features):
|
||
input_ids.append(split_input_ids[0])
|
||
fake_prompt_id = torch.arange(
|
||
fake_prompt_counter,
|
||
fake_prompt_counter + visual_feature.shape[0])
|
||
fake_prompt_id = fake_prompt_id.unsqueeze(0)
|
||
input_ids.append(fake_prompt_id)
|
||
if len(split_input_ids) > 1:
|
||
input_ids.append(split_input_ids[1])
|
||
|
||
input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)
|
||
input_ids = input_ids.reshape(batch_size, -1)
|
||
|
||
if self.decoder_llm or self.runtime_mapping.is_first_pp_rank():
|
||
ptuning_args = self.ptuning_setup(visual_features, input_ids,
|
||
input_lengths)
|
||
else:
|
||
ptuning_args = [None, None, None]
|
||
|
||
return input_ids, ptuning_args
|
||
|
||
def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids,
|
||
input_lengths):
|
||
        # Assemble fake prompt ids which actually point to the image embeddings
|
||
if hasattr(self, 'num_frames') and (visual_features.shape[1]
|
||
== self.num_frames):
|
||
visual_features = visual_features.view(visual_features.shape[0], -1,
|
||
visual_features.shape[-1])
|
||
|
||
if visual_features is not None:
|
||
if self.use_py_session:
|
||
# Non-IFB Mode(used in python session): All requests in a batch have their prompt_table concatenated in
|
||
# a shape of (bs*vision_embedding_len, vision_hidden). So only one fake_prompt_id is needed for the
|
||
# entire batch, with values from 0 to bs * vision_embedding_len-1.
|
||
fake_prompt_id = torch.arange(
|
||
self.model_config.vocab_size, self.model_config.vocab_size +
|
||
visual_features.shape[0] * visual_features.shape[1])
|
||
fake_prompt_id = fake_prompt_id.reshape(
|
||
visual_features.shape[0], visual_features.shape[1])
|
||
else:
|
||
# IFB Mode(used in c++ session): Each request's prompt_table is independent and requires a fake_prompt_id
|
||
# for each request, with values ranging from 0 to vision_embedding_len-1.
|
||
fake_prompt_id = torch.arange(
|
||
self.model_config.vocab_size,
|
||
self.model_config.vocab_size + visual_features.shape[1])
|
||
fake_prompt_id = fake_prompt_id.repeat(visual_features.shape[0],
|
||
1)
|
||
|
||
if 'internvl' in self.model_type:
|
||
fake_prompt_id = fake_prompt_id.reshape(1, -1)
|
||
|
||
if 'cogvlm' in self.model_type:
|
||
input_ids = torch.cat(
|
||
[pre_input_ids[:, 0:1], fake_prompt_id, pre_input_ids[:, 1:]],
|
||
dim=1).contiguous().to(torch.int32)
|
||
elif self.model_type == 'mllama':
|
||
input_ids = pre_input_ids.contiguous().to(torch.int32)
|
||
else:
|
||
if post_input_ids is not None:
|
||
input_ids = [pre_input_ids, fake_prompt_id, post_input_ids]
|
||
else:
|
||
input_ids = [fake_prompt_id, pre_input_ids]
|
||
input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)
|
||
|
||
if (self.decoder_llm or self.runtime_mapping.is_first_pp_rank()
|
||
) and self.model_type != "mllama":
|
||
ptuning_args = self.ptuning_setup(visual_features, input_ids,
|
||
input_lengths)
|
||
else:
|
||
ptuning_args = [None, None, None]
|
||
|
||
return input_ids, ptuning_args
|
||
|
||
def get_rope_index(
|
||
self,
|
||
input_ids: torch.LongTensor,
|
||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||
attention_mask: Optional[torch.Tensor] = None,
|
||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||
"""
|
||
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
|
||
|
||
Explanation:
|
||
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
|
||
|
||
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
|
||
Examples:
|
||
input_ids: [T T T T T], here T is for text.
|
||
temporal position_ids: [0, 1, 2, 3, 4]
|
||
height position_ids: [0, 1, 2, 3, 4]
|
||
width position_ids: [0, 1, 2, 3, 4]
|
||
|
||
            For vision and text embedding sequences, we calculate 3D rotary position embedding for the vision part
            and 1D rotary position embedding for the text part.
|
||
Examples:
|
||
Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
|
||
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
|
||
vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
|
||
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
|
||
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
|
||
text temporal position_ids: [3, 4, 5, 6, 7]
|
||
text height position_ids: [3, 4, 5, 6, 7]
|
||
text width position_ids: [3, 4, 5, 6, 7]
|
||
Here we calculate the text start position_ids as the max vision position_ids plus 1.
|
||
|
||
Args:
|
||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
||
it.
|
||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||
The temporal, height and width of feature shape of each image in LLM.
|
||
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||
The temporal, height and width of feature shape of each video in LLM.
|
||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||
|
||
- 1 for tokens that are **not masked**,
|
||
- 0 for tokens that are **masked**.
|
||
|
||
Returns:
|
||
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
|
||
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
|
||
"""
|
||
spatial_merge_size = self.spatial_merge_size
|
||
image_token_id = self.image_token_id
|
||
video_token_id = self.video_token_id
|
||
vision_start_token_id = self.vision_start_token_id
|
||
mrope_position_deltas = []
|
||
if image_grid_thw is not None or video_grid_thw is not None:
|
||
total_input_ids = input_ids
|
||
position_ids = torch.ones(3,
|
||
input_ids.shape[0],
|
||
input_ids.shape[1],
|
||
dtype=input_ids.dtype,
|
||
device=input_ids.device)
|
||
image_index, video_index = 0, 0
|
||
for i, input_ids in enumerate(total_input_ids):
|
||
if attention_mask is not None:
|
||
input_ids = input_ids[attention_mask[i] == 1]
|
||
image_nums, video_nums = 0, 0
|
||
vision_start_indices = torch.argwhere(
|
||
input_ids == vision_start_token_id).squeeze(1)
|
||
vision_tokens = input_ids[vision_start_indices + 1]
|
||
image_nums = (vision_tokens == image_token_id).sum()
|
||
video_nums = (vision_tokens == video_token_id).sum()
|
||
input_tokens = input_ids.tolist()
|
||
llm_pos_ids_list: list = []
|
||
st = 0
|
||
remain_images, remain_videos = image_nums, video_nums
|
||
for _ in range(image_nums + video_nums):
|
||
if image_token_id in input_tokens and remain_images > 0:
|
||
ed_image = input_tokens.index(image_token_id, st)
|
||
else:
|
||
ed_image = len(input_tokens) + 1
|
||
if video_token_id in input_tokens and remain_videos > 0:
|
||
ed_video = input_tokens.index(video_token_id, st)
|
||
else:
|
||
ed_video = len(input_tokens) + 1
|
||
if ed_image < ed_video:
|
||
t, h, w = (
|
||
image_grid_thw[image_index][0],
|
||
image_grid_thw[image_index][1],
|
||
image_grid_thw[image_index][2],
|
||
)
|
||
image_index += 1
|
||
remain_images -= 1
|
||
ed = ed_image
|
||
else:
|
||
t, h, w = (
|
||
video_grid_thw[video_index][0],
|
||
video_grid_thw[video_index][1],
|
||
video_grid_thw[video_index][2],
|
||
)
|
||
video_index += 1
|
||
remain_videos -= 1
|
||
ed = ed_video
|
||
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||
t.item(),
|
||
h.item() // spatial_merge_size,
|
||
w.item() // spatial_merge_size,
|
||
)
|
||
text_len = ed - st
|
||
|
||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||
llm_pos_ids_list) > 0 else 0
|
||
llm_pos_ids_list.append(
|
||
torch.arange(text_len).view(1, -1).expand(3, -1) +
|
||
st_idx)
|
||
|
||
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
|
||
-1, llm_grid_h * llm_grid_w).flatten()
|
||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
|
||
llm_grid_t, -1, llm_grid_w).flatten()
|
||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
|
||
llm_grid_t, llm_grid_h, -1).flatten()
|
||
llm_pos_ids_list.append(
|
||
torch.stack([t_index, h_index, w_index]) + text_len +
|
||
st_idx)
|
||
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||
|
||
if st < len(input_tokens):
|
||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||
llm_pos_ids_list) > 0 else 0
|
||
text_len = len(input_tokens) - st
|
||
llm_pos_ids_list.append(
|
||
torch.arange(text_len).view(1, -1).expand(3, -1) +
|
||
st_idx)
|
||
|
||
llm_positions = torch.cat(llm_pos_ids_list,
|
||
dim=1).reshape(3, -1)
|
||
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
|
||
position_ids.device)
|
||
mrope_position_deltas.append(llm_positions.max() + 1 -
|
||
len(total_input_ids[i]))
|
||
mrope_position_deltas = torch.tensor(
|
||
mrope_position_deltas, device=input_ids.device).unsqueeze(1)
|
||
return position_ids, mrope_position_deltas
|
||
else:
|
||
if attention_mask is not None:
|
||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(
|
||
input_ids.device)
|
||
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
|
||
-1, keepdim=True)[0]
|
||
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[
|
||
-1]
|
||
else:
|
||
position_ids = (torch.arange(input_ids.shape[1],
|
||
device=input_ids.device).view(
|
||
1, 1, -1).expand(
|
||
3, input_ids.shape[0], -1))
|
||
mrope_position_deltas = torch.zeros(
|
||
[input_ids.shape[0], 1],
|
||
device=input_ids.device,
|
||
dtype=input_ids.dtype,
|
||
)
|
||
|
||
return position_ids, mrope_position_deltas
|
||
|
||
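    # Qwen2-VL: image/video placeholder tokens are rewritten to out-of-vocabulary
    # ids (>= vocab_size) so the prompt-tuning table supplies the visual embeddings,
    # and the 3D (t, h, w) mRoPE position ids are packed into the cos/sin tensor
    # expected by the attention plugin.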
def setup_fake_prompts_qwen2vl(self, visual_features, input_ids,
|
||
vision_grid_thws, attention_mask,
|
||
input_lengths):
|
||
|
||
visual_features = torch.unsqueeze(visual_features, 0)
|
||
|
||
#generate mrope_params
|
||
mrope_position_ids, mrope_position_deltas = self.get_rope_index(
|
||
input_ids,
|
||
image_grid_thw=vision_grid_thws,
|
||
video_grid_thw=None,
|
||
attention_mask=attention_mask,
|
||
)
|
||
|
||
mask = (input_ids == self.image_token_id) | (
|
||
input_ids == self.vision_token_id) | (input_ids
|
||
== self.video_token_id)
|
||
indices = torch.nonzero(mask, as_tuple=False)
|
||
value = self.model_config.vocab_size
|
||
for idx in indices:
|
||
input_ids[tuple(idx)] = value
|
||
value += 1
|
||
|
||
if self.decoder_llm or self.runtime_mapping.is_first_pp_rank():
|
||
ptuning_args = self.ptuning_setup(visual_features, input_ids,
|
||
input_lengths)
|
||
else:
|
||
ptuning_args = [None, None, None]
|
||
|
||
        mrope_position_ids = mrope_position_ids.transpose(1, 0)
|
||
max_position_embeddings = int(self.max_position_embeddings)
|
||
rotary_embedding_dim = int(self.hidden_size / self.num_attention_heads)
|
||
mrope_position_ids_padding = torch.zeros(mrope_position_ids.shape[:-1] +
|
||
(max_position_embeddings, ),
|
||
dtype=torch.int32)
|
||
mrope_position_ids_padding[:, :, :mrope_position_ids.
|
||
shape[-1]] = mrope_position_ids
|
||
|
||
rotary_embedding_base = float(self.rope_theta)
|
||
rotary_embedding_scale = float(1.0)
|
||
rotary_embedding_scale_type = RotaryScalingType.mrope
|
||
rotary_embedding_scaling = None
|
||
inv_freq, rotary_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
|
||
max_position_embeddings, rotary_embedding_dim,
|
||
rotary_embedding_base, rotary_embedding_scale,
|
||
rotary_embedding_scale_type, rotary_embedding_scaling)
|
||
rotary_cos_sin = rotary_cos_sin.reshape(max_position_embeddings,
|
||
int(rotary_embedding_dim / 2),
|
||
2)
|
||
rotary_cos_sin = torch.from_numpy(rotary_cos_sin)
|
||
cos_ori = rotary_cos_sin[:, :, 0]
|
||
sin_ori = rotary_cos_sin[:, :, 1]
|
||
cos = cos_ori[mrope_position_ids_padding]
|
||
sin = sin_ori[mrope_position_ids_padding]
|
||
|
||
mrope_section = [16, 24, 24]
|
||
unsqueeze_dim = -1
|
||
cos = torch.cat([
|
||
m[:, i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))
|
||
],
|
||
dim=-1).unsqueeze(unsqueeze_dim)
|
||
sin = torch.cat([
|
||
m[:, i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))
|
||
],
|
||
dim=-1).unsqueeze(unsqueeze_dim)
|
||
concat_cos_sin = np.concatenate((cos, sin), axis=-1)
|
||
concat_cos_sin = concat_cos_sin.reshape(concat_cos_sin.shape[0], -1)
|
||
concat_cos_sin = torch.from_numpy(concat_cos_sin)
|
||
|
||
mrope_args = [concat_cos_sin, mrope_position_deltas]
|
||
|
||
return input_ids, ptuning_args, mrope_args
|
||
|
||
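
    # NOTE (added, illustrative only): the mrope_args produced above are a
    # host-side precomputation. concat_cos_sin has one row per position, laid
    # out as [cos(section-interleaved) | sin(section-interleaved)] with the
    # Qwen2-VL section split [16, 24, 24] over the t/h/w axes, and
    # mrope_position_deltas tells the decoder how to offset positions generated
    # after the multimodal prompt.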

    def ptuning_setup_fuyu(self, input_ids, image_patches_indices):
        res_input_ids = []
        for cur_input_ids, cur_image_patches_indices in zip(
                input_ids, image_patches_indices):
            # Truncate input_ids to the length of image_patches_indices
            cur_image_patches_indices = cur_image_patches_indices[:len(
                cur_input_ids)]
            # Get ids of the image_patches
            non_zero_mask = cur_image_patches_indices != -1
            # Replace input_ids with image_patches_indices values (where the patches are placed)
            cur_input_ids = cur_input_ids.masked_scatter(
                non_zero_mask,
                cur_image_patches_indices[non_zero_mask] +
                self.model_config.vocab_size,
            )
            res_input_ids.append(cur_input_ids)
        return res_input_ids
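
    # NOTE (added, illustrative only): for Fuyu, image_patches_indices marks
    # where image patches sit in the token stream (-1 everywhere else). Mapping
    # patch index k to vocab_size + k turns those slots into fake prompt ids;
    # e.g. with vocab_size=10, input_ids [5, 9, 9, 7] and patch indices
    # [-1, 0, 1, -1] become [5, 10, 11, 7].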

    def ptuning_setup_llava_next(self, visual_features, pre_prompt,
                                 post_prompt):
        fake_prompt_ids = list(
            range(self.model_config.vocab_size,
                  self.model_config.vocab_size + visual_features.shape[0]))
        input_ids = self.tokenizer.encode(
            pre_prompt[0]) + fake_prompt_ids + self.tokenizer.encode(
                post_prompt[0])[self.tokenizer.add_bos_token:]
        input_ids = [input_ids] * len(pre_prompt)
        input_ids = torch.tensor(input_ids)
        return input_ids

    def ptuning_setup_phi3(self, visual_features, input_ids, num_img_tokens):
        fake_prompt_id = torch.arange(
            self.model_config.vocab_size,
            self.model_config.vocab_size + visual_features.shape[0])
        MAX_INPUT_ID = int(1e9)
        positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID),
                                  as_tuple=False)
        idx = 0
        for _, cnt in enumerate(num_img_tokens):
            input_ids[positions[idx, 0], positions[idx, 1]:positions[idx, 1] +
                      cnt] = fake_prompt_id[idx:idx + cnt]
            idx += cnt
        return input_ids

    def ptuning_setup(self, prompt_table, input_ids, input_lengths):
        hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size
        if prompt_table is not None:
            task_vocab_size = torch.tensor(
                [prompt_table.shape[1]],
                dtype=torch.int32,
            ).cuda()
            prompt_table = prompt_table.view(
                (prompt_table.shape[0] * prompt_table.shape[1],
                 prompt_table.shape[2]))

            assert prompt_table.shape[
                1] == hidden_size, "Prompt table dimensions do not match hidden size"

            if hasattr(self.model_config, 'dtype'):
                prompt_table = prompt_table.cuda().to(
                    dtype=str_dtype_to_torch(self.model_config.dtype))
            else:
                prompt_table = prompt_table.cuda().to(dtype=self.model.dtype)
        else:
            prompt_table = torch.empty([1, hidden_size]).cuda()
            task_vocab_size = torch.zeros([1]).cuda()

        remove_input_padding = self.model_config.remove_input_padding if hasattr(
            self.model_config,
            'remove_input_padding') else self.model_config.use_packed_input
        if remove_input_padding:
            tasks = torch.zeros([torch.sum(input_lengths)],
                                dtype=torch.int32).cuda()
            if self.decoder_llm:
                tasks = tasks.unsqueeze(0)
        else:
            tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()

        return [prompt_table, tasks, task_vocab_size]
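
    # NOTE (added, illustrative only): ptuning_setup returns the triple consumed
    # by the prompt-tuning path of the runner: a flattened
    # [num_visual_tokens, hidden_size] prompt table, a per-token `tasks` tensor
    # (all zeros, i.e. a single task), and `task_vocab_size`, the number of
    # virtual tokens. Token ids in [vocab_size, vocab_size + task_vocab_size)
    # are looked up in this table instead of the regular embedding matrix.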

    def load_test_image(self):

        def load_images(image_paths):
            if isinstance(image_paths, str):
                image_paths = [image_paths]
            images = []
            for image_path in image_paths:
                if image_path.startswith("http") or image_path.startswith(
                        "https"):
                    logger.info(f"downloading image from url {image_path}")
                    response = requests.get(image_path, timeout=5)
                    image = Image.open(BytesIO(response.content)).convert("RGB")
                else:
                    image = Image.open(image_path).convert("RGB")
                images.append(image)
            return images if len(images) > 1 else images[0]

        if "vila" in self.model_type:
            if self.args.image_path is None:
                img_urls = [
                    'https://github.com/Efficient-Large-Model/VILA/raw/main/demo_images/av.png',
                    'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png'
                ] * 4
                img_urls = img_urls[:self.args.batch_size]
                self.args.image_path = ",".join(img_urls)
                images = load_images(img_urls)
            else:
                images = load_images(
                    self.args.image_path.split(self.args.path_sep))

        elif "nougat" in self.model_type:
            filepath = hf_hub_download(
                repo_id="hf-internal-testing/fixtures_docvqa",
                filename="nougat_paper.png",
                repo_type="dataset")
            images = Image.open(filepath)
        elif "fuyu" in self.model_type:
            filepath = hf_hub_download(repo_id="adept/fuyu-8b",
                                       filename="skateboard.png",
                                       repo_type='model')
            images = Image.open(filepath)
        elif "kosmos" in self.model_type:
            if self.args.image_path is None:
                self.args.image_path = 'https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.png'
            images = load_images(self.args.image_path)
        elif "pix2struct" in self.model_type:
            if self.args.image_path is None:
                self.args.image_path = 'https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_40963.png'
            images = load_images(self.args.image_path)
        elif "video-neva" in self.model_type:
            images = self.args.video_path
        elif "internvl" in self.model_type:
            if self.args.image_path is None:
                img_url = 'https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/examples/image1.jpg'
                images = Image.open(
                    requests.get(img_url, stream=True,
                                 timeout=5).raw).convert('RGB')
            else:
                images = Image.open(self.args.image_path).convert('RGB')
        elif "qwen2_vl" in self.model_type:
            if self.args.image_path is None:
                img_url = 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'
                images = Image.open(
                    requests.get(img_url, stream=True,
                                 timeout=5).raw).convert('RGB')
                images = images.resize(
                    (images.size[0] // 2, images.size[1] // 2))
            else:
                images = Image.open(self.args.image_path).convert('RGB')
        elif "llava_onevision" in self.model_type and self.args.video_path is not None:
            if self.args.video_path == 'llava-onevision-accuracy':
                self.args.video_path = hf_hub_download(
                    repo_id="raushan-testing-hf/videos-test",
                    filename="sample_demo_1.mp4",
                    repo_type="dataset")
            import av
            with av.open(self.args.video_path) as container:
                total_frames = container.streams.video[0].frames
                assert total_frames >= self.num_frames
                indices = np.arange(0, total_frames,
                                    total_frames / self.num_frames).astype(int)
                frames = []
                container.seek(0)
                start_index = indices[0]
                end_index = indices[-1]
                for i, frame in enumerate(container.decode(video=0)):
                    if i > end_index:
                        break
                    if i >= start_index and i in indices:
                        frames.append(frame)
                images = np.stack(
                    [x.to_ndarray(format="rgb24") for x in frames])
            images = torch.tensor(images)
        else:
            if self.args.image_path is None and self.model_type != 'mllama':
                self.args.image_path = 'https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png'
            images = load_images(self.args.image_path
                                 ) if self.args.image_path is not None else None
        return images

    def setup_inputs(self, input_text, raw_image):
        from ..tools.multimodal_builder import compute_rotary_pos_emb
        other_vision_inputs = {}
        other_decoder_inputs = {}

        if 'blip2' in self.model_type:
            image = self.processor(raw_image, input_text,
                                   return_tensors="pt")['pixel_values']
            if input_text is None:
                input_text = "Question: which city is this? Answer:"
            pre_prompt = input_text
            post_prompt = None
        elif 'qwen2_vl' in self.model_type:
            from qwen_vl_utils import process_vision_info
            from transformers.models.qwen2_vl.modeling_qwen2_vl import \
                VisionRotaryEmbedding
            processor = AutoProcessor.from_pretrained(self.args.hf_model_dir)
            hf_config = AutoConfig.from_pretrained(self.args.hf_model_dir)
            if input_text is None:
                input_text = "Question: Describe this image. Answer:"
            messages = [{
                "role":
                "user",
                "content": [
                    {
                        "type": "image",
                        "image": raw_image,
                    },
                    {
                        "type": "text",
                        "text": input_text
                    },
                ],
            }]

            text = processor.apply_chat_template(messages,
                                                 tokenize=False,
                                                 add_generation_prompt=True)
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.device)
            image = inputs['pixel_values']
            image_grid_thw = inputs['image_grid_thw']
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            cu_seqlens = torch.repeat_interleave(
                image_grid_thw[:, 1] * image_grid_thw[:, 2],
                image_grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32)
            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
            seq_length = image.shape[0]
            attention_mask_vit = torch.zeros([1, seq_length, seq_length],
                                             device=image.device,
                                             dtype=torch.bool)
            for i in range(1, len(cu_seqlens)):
                attention_mask_vit[..., cu_seqlens[i - 1]:cu_seqlens[i],
                                   cu_seqlens[i - 1]:cu_seqlens[i]] = True

            decoder_input_ids = None
            post_prompt = None
            pre_prompt = None
            input_text = None
            images_qwenvl = {
                "image": image,
                "input_ids": input_ids,
            }
            rotary_pos_emb = compute_rotary_pos_emb(
                image_grid_thw, hf_config, VisionRotaryEmbedding).to("cuda")
            other_vision_inputs['attention_mask_llm'] = attention_mask
            other_vision_inputs['image_grid_thw'] = image_grid_thw
            other_vision_inputs['attention_mask'] = attention_mask_vit
            other_vision_inputs['rotary_pos_emb'] = rotary_pos_emb
            return input_text, pre_prompt, post_prompt, images_qwenvl, decoder_input_ids, other_vision_inputs, other_decoder_inputs
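
        # NOTE (added, illustrative only): the qwen2_vl branch above returns
        # early because it already carries tokenized input_ids plus the extra
        # vision-encoder inputs (ViT block-diagonal attention mask,
        # image_grid_thw, rotary_pos_emb), so the batch-expansion and
        # decoder_input_ids logic at the end of this method does not apply.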
        elif 'nougat' in self.model_type:
            image = self.processor(raw_image,
                                   return_tensors="pt")['pixel_values']
            # Nougat doesn't need a text prompt (mBART uses a single token to start generation), so just leave a dummy one here
            if input_text is None:
                input_text = "Question: which city is this? Answer:"
            pre_prompt = input_text
            post_prompt = None

        elif 'cogvlm' in self.model_type:
            image = self.transform(raw_image).unsqueeze(0)
            if input_text is None:
                input_text = " [INST] which city is this? [/INST] "
            pre_prompt = input_text
            post_prompt = None

        elif 'phi-3-vision' in self.model_type:
            pre_prompt = "<|user|>\n<|image_1|>\n"
            if input_text is None:
                input_text = "Which city is this?"
            post_prompt = input_text + "<|end|>\n<|assistant|>\n"
            prompt = pre_prompt + post_prompt
            image = self.processor(text=prompt,
                                   images=raw_image,
                                   return_tensors="pt")

        elif 'internvl' in self.model_type:
            pre_prompt = "<|system|>\n你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。<|end|><|user|>\n<image>\n"
            if input_text is None:
                input_text = "Please describe the image shortly."
            post_prompt = input_text + "<|end|><|assistant|>\n"
            prompt = pre_prompt + post_prompt
            image = self.processor(images=raw_image,
                                   return_tensors='pt').pixel_values

        elif self.model_type == "pix2struct":
            if input_text is None:
                input_text = ""
            inputs = self.processor(
                images=raw_image,
                text=input_text,
                return_tensors="pt",
            )
            image = inputs['flattened_patches']
            image = image.expand(self.args.batch_size, -1, -1).contiguous()
            pre_prompt = ""
            post_prompt = None

        elif self.model_type == "neva":
            image = self.transform(raw_image).unsqueeze(0)

            if input_text is None:
                input_text = "Hi! What is in this image?"

            pre_prompt = "<extra_id_0>System\n\n<extra_id_1>User\n"
            post_prompt = f"\n{input_text}\n<extra_id_1>Assistant\n"

        elif self.model_type == "video-neva":
            image = self.video_preprocess(
                raw_image)  # shape (1, num_frames, 3, H, W)

            if input_text is None:
                input_text = "Hi! What is in this video?"

            # SteerLM prompt template
            pre_prompt = """<extra_id_0>System\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n<extra_id_1>User"""
            post_prompt = f"\n{input_text}\n<extra_id_1>Assistant\n<extra_id_2>quality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4\n"

        elif self.model_type == "llava_next":
            if self.llm_name == "mistralai/Mistral-7B-Instruct-v0.2":
                pre_prompt = "[INST] "
                if input_text is None:
                    input_text = "Question: which city is this? Answer:"
                post_prompt = f"\n{input_text} [/INST]"
                prompt = pre_prompt + post_prompt

            elif self.llm_name == "NousResearch/Nous-Hermes-2-Yi-34B":
                pre_prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n"
                if input_text is None:
                    input_text = "Question: which city is this? Answer:"
                post_prompt = f"\n{input_text}<|im_end|><|im_start|>assistant\n"
                prompt = pre_prompt + post_prompt

            else:
                raise Exception(
                    f"Prompt template for {self.llm_name} is not included currently"
                )

            image = self.processor(text=prompt,
                                   images=raw_image,
                                   return_tensors="pt")

        elif self.model_type in ['llava', 'vila', 'fuyu', 'kosmos-2']:
            # LLaVA, VILA, Fuyu and Kosmos-2
            if self.model_type == "llava":
                pre_prompt = "USER:\n"
                if input_text is None:
                    input_text = "Question: which city is this? Answer:"
            elif self.model_type == "vila":
                pre_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "
                if input_text is None:
                    input_text = "<image>\n Please elaborate what you see in the images?"
            elif self.model_type == 'fuyu':
                pre_prompt = "Describe this image:"
                if input_text is None:
                    input_text = "Answer the following VQAv2 question based on the image: How many people are in the image?\n"
            elif self.model_type == "kosmos-2":
                pre_prompt = ""
                if input_text is None:
                    input_text = "<grounding>An image of"

            if self.model_type not in ['fuyu', 'kosmos-2']:
                post_prompt = input_text + " ASSISTANT:"
            else:
                post_prompt = None

            if self.model_type == "vila":
                if not isinstance(raw_image, list):
                    raw_image = [raw_image]
                image = self.processor(raw_image)
            else:
                if self.model_type in ['fuyu', 'kosmos-2']:
                    image = self.processor(text=input_text,
                                           images=raw_image,
                                           return_tensors='pt')
                else:
                    image = self.processor(text=input_text,
                                           images=raw_image,
                                           return_tensors="pt")['pixel_values']

        elif self.model_type in ['mllama']:
            if raw_image is not None:
                inputs = self.processor(images=raw_image,
                                        text=input_text,
                                        return_tensors="pt")
                other_vision_inputs = {
                    "aspect_ratio_ids":
                    inputs["aspect_ratio_ids"].to(self.device).expand(
                        self.args.batch_size, -1).contiguous(),
                    "aspect_ratio_mask":
                    inputs["aspect_ratio_mask"].to(self.device).expand(
                        self.args.batch_size, -1, -1).contiguous(),
                }
                other_decoder_inputs = {
                    "cross_attention_mask":
                    inputs["cross_attention_mask"].to(self.device).expand(
                        self.args.batch_size, -1, -1, -1).contiguous(),
                }
                pre_prompt = input_text
                post_prompt = None
                image = inputs["pixel_values"]
            else:
                pre_prompt = input_text
                post_prompt = None
                image = None
                logger.warning(
                    "image_path is None. Will not pass image as input, skipping the vision encoder."
                )
        elif self.model_type in ['llava_onevision']:
            pre_prompt = "<|im_start|>user "
            if input_text is None:
                input_text = "Question: which city is this? Answer:" if self.args.video_path is None else "Why is this video funny?"
            post_prompt = f"\n{input_text}<|im_end|><|im_start|>assistant\n"
            prompt = pre_prompt + post_prompt

            processor = AutoProcessor.from_pretrained(self.args.hf_model_dir)
            if self.args.video_path is None:
                image = processor(images=raw_image,
                                  text=prompt,
                                  return_tensors="pt")
            else:
                image = processor(videos=raw_image,
                                  text=prompt,
                                  return_tensors="pt")

        # Repeat inputs to match batch size
        pre_prompt = [pre_prompt] * self.args.batch_size
        post_prompt = [post_prompt] * self.args.batch_size
        if self.model_type not in [
                'fuyu', 'pix2struct', 'kosmos-2', 'vila', 'phi-3-vision',
                'llava_next', 'internvl', 'llava_onevision'
        ]:
            if image is not None:
                if image.dim() == 5:
                    image = image.expand(self.args.batch_size, -1, -1, -1,
                                         -1).contiguous()
                elif image.dim() == 6:
                    image = image.expand(self.args.batch_size, -1, -1, -1, -1,
                                         -1).contiguous()
                else:
                    image = image.expand(self.args.batch_size, -1, -1,
                                         -1).contiguous()
        if image is not None:
            image = image.to(self.device)

        # Generate decoder_input_ids for enc-dec models
        # Custom prompts can be added as:
        # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids
        if self.decoder_llm:
            decoder_input_ids = None
        else:
            config = AutoConfig.from_pretrained(self.args.hf_model_dir)
            if "blip2" in self.model_type:
                decoder_start_id = config.text_config.decoder_start_token_id  # T5
            elif "nougat" in self.model_type:
                decoder_start_id = config.decoder.bos_token_id  # Nougat
            else:
                decoder_start_id = config.decoder_start_token_id

            decoder_input_ids = torch.IntTensor([[decoder_start_id]])
            decoder_input_ids = decoder_input_ids.repeat(
                (self.args.batch_size, 1))

        return input_text, pre_prompt, post_prompt, image, decoder_input_ids, other_vision_inputs, other_decoder_inputs
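
    # NOTE (added, illustrative only): setup_inputs always hands back the same
    # 7-tuple (input_text, pre_prompt, post_prompt, image, decoder_input_ids,
    # other_vision_inputs, other_decoder_inputs); branches that have no use for
    # a field return None or an empty dict for it, so generate() can treat every
    # model type uniformly.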

    def run(self, input_text, input_image, max_new_tokens):
        input_text, pre_prompt, post_prompt, processed_image, decoder_input_ids, other_vision_inputs, other_decoder_inputs = self.setup_inputs(
            input_text, input_image)
        output_text = self.generate(pre_prompt,
                                    post_prompt,
                                    processed_image,
                                    decoder_input_ids,
                                    max_new_tokens,
                                    warmup=False,
                                    other_vision_inputs=other_vision_inputs,
                                    other_decoder_inputs=other_decoder_inputs)

        return input_text, output_text
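
    # NOTE (added, illustrative only; the class name and CLI arguments below are
    # assumptions, not guaranteed by this file): a typical smoke-test call path
    # looks like
    #
    #   runner = MultimodalModelRunner(args)
    #   raw_image = runner.load_test_image()
    #   input_text, output_text = runner.run(args.input_text, raw_image,
    #                                        args.max_new_tokens)
    #
    # i.e. run() wires setup_inputs() into generate() with warmup disabled.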