# Copyright 2024 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# reference: https://github.com/hpcaitech/Open-Sora/blob/main/opensora/utils/ckpt_utils.py

import html
import json
import os
import random
import re
import sys
import urllib.parse as ul

import ftfy
import numpy as np
import pandas as pd
import requests
import torch
import video_transforms
from aspect import get_image_size, get_num_frames
from bs4 import BeautifulSoup
from colossalai.checkpoint_io import GeneralCheckpointIO
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
# download_model() needs to download into a specific directory; the local
# download_url() defined below only accepts a URL, so use torchvision's helper.
from torchvision.datasets.utils import download_url as tv_download_url
from torchvision.io import read_video, write_video
from torchvision.utils import save_image

VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

URL_REGEX = re.compile(
    r"^(?:http|ftp)s?://"  # http:// or https://
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"  # domain...
    r"localhost|"  # localhost...
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"  # ...or ip
    r"(?::\d+)?"  # optional port
    r"(?:/?|[/?]\S+)$",
    re.IGNORECASE,
)

HF_ENDPOINT = os.environ.get("HF_ENDPOINT")
if HF_ENDPOINT is None:
    HF_ENDPOINT = "https://huggingface.co"

PRETRAINED_MODELS = {
    "DiT-XL-2-512x512.pt":
    "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt",
    "DiT-XL-2-256x256.pt":
    "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt",
    "Latte-XL-2-256x256-ucf101.pt":
    HF_ENDPOINT + "/maxin-cn/Latte/resolve/main/ucf101.pt",
    "PixArt-XL-2-256x256.pth":
    HF_ENDPOINT +
    "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-256x256.pth",
    "PixArt-XL-2-SAM-256x256.pth":
    HF_ENDPOINT +
    "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-SAM-256x256.pth",
    "PixArt-XL-2-512x512.pth":
    HF_ENDPOINT +
    "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-512x512.pth",
    "PixArt-XL-2-1024-MS.pth":
    HF_ENDPOINT +
    "/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth",
    "OpenSora-v1-16x256x256.pth":
    HF_ENDPOINT +
    "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth",
    "OpenSora-v1-HQ-16x256x256.pth":
    HF_ENDPOINT +
    "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth",
    "OpenSora-v1-HQ-16x512x512.pth":
    HF_ENDPOINT +
    "/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth",
    "PixArt-Sigma-XL-2-256x256.pth":
    HF_ENDPOINT +
    "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
    "PixArt-Sigma-XL-2-512-MS.pth":
    HF_ENDPOINT +
    "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
    "PixArt-Sigma-XL-2-1024-MS.pth":
    HF_ENDPOINT +
    "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
    "PixArt-Sigma-XL-2-2K-MS.pth":
    HF_ENDPOINT +
    "/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
}


def is_img(path):
    ext = os.path.splitext(path)[-1].lower()
    return ext in IMG_EXTENSIONS


def is_vid(path):
    ext = os.path.splitext(path)[-1].lower()
    return ext in VID_EXTENSIONS


def is_url(url):
    return re.match(URL_REGEX, url) is not None


def read_file(input_path):
    if input_path.endswith(".csv"):
        return pd.read_csv(input_path)
    elif input_path.endswith(".parquet"):
        return pd.read_parquet(input_path)
    else:
        raise NotImplementedError(f"Unsupported file format: {input_path}")


def download_url(input_path):
    output_dir = "cache"
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.basename(input_path)
    output_path = os.path.join(output_dir, base_name)
    try:
        img_data = requests.get(input_path).content
    except requests.exceptions.RequestException as e:
        print(f"Error downloading URL {input_path}: {e}")
        return None
    with open(output_path, "wb") as handler:
        handler.write(img_data)
    print(f"URL {input_path} downloaded to {output_path}")
    return output_path


def reparameter(ckpt, name=None, model=None):
    model_name = name
    name = os.path.basename(name)
    if not torch.distributed.is_initialized() or torch.distributed.get_rank(
    ) == 0:
        print(f"loading pretrained model: {model_name}")
    if name in ["DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"]:
        ckpt["x_embedder.proj.weight"] = ckpt[
            "x_embedder.proj.weight"].unsqueeze(2)
        del ckpt["pos_embed"]
    if name in ["Latte-XL-2-256x256-ucf101.pt"]:
        ckpt = ckpt["ema"]
        ckpt["x_embedder.proj.weight"] = ckpt[
            "x_embedder.proj.weight"].unsqueeze(2)
        del ckpt["pos_embed"]
        del ckpt["temp_embed"]
    if name in [
            "PixArt-XL-2-256x256.pth",
            "PixArt-XL-2-SAM-256x256.pth",
            "PixArt-XL-2-512x512.pth",
            "PixArt-XL-2-1024-MS.pth",
            "PixArt-Sigma-XL-2-256x256.pth",
            "PixArt-Sigma-XL-2-512-MS.pth",
            "PixArt-Sigma-XL-2-1024-MS.pth",
            "PixArt-Sigma-XL-2-2K-MS.pth",
    ]:
        ckpt = ckpt["state_dict"]
        ckpt["x_embedder.proj.weight"] = ckpt[
            "x_embedder.proj.weight"].unsqueeze(2)
        if "pos_embed" in ckpt:
            del ckpt["pos_embed"]

    if name in [
            "PixArt-1B-2.pth",
    ]:
        ckpt = ckpt["state_dict"]
        if "pos_embed" in ckpt:
            del ckpt["pos_embed"]

    # no need pos_embed
    if "pos_embed_temporal" in ckpt:
        del ckpt["pos_embed_temporal"]
    if "pos_embed" in ckpt:
        del ckpt["pos_embed"]
    # different text length
    if "y_embedder.y_embedding" in ckpt:
        if ckpt["y_embedder.y_embedding"].shape[
                0] < model.y_embedder.y_embedding.shape[0]:
            print(f"Extend y_embedding from "
                  f"{ckpt['y_embedder.y_embedding'].shape[0]} to "
                  f"{model.y_embedder.y_embedding.shape[0]}")
            additional_length = model.y_embedder.y_embedding.shape[0] - ckpt[
                "y_embedder.y_embedding"].shape[0]
            new_y_embedding = torch.zeros(
                additional_length, model.y_embedder.y_embedding.shape[1])
            new_y_embedding[:] = ckpt["y_embedder.y_embedding"][-1]
            ckpt["y_embedder.y_embedding"] = torch.cat(
                [ckpt["y_embedder.y_embedding"], new_y_embedding], dim=0)
        elif ckpt["y_embedder.y_embedding"].shape[
                0] > model.y_embedder.y_embedding.shape[0]:
            print(f"Shrink y_embedding from "
                  f"{ckpt['y_embedder.y_embedding'].shape[0]} to "
                  f"{model.y_embedder.y_embedding.shape[0]}")
            ckpt["y_embedder.y_embedding"] = ckpt[
                "y_embedder.y_embedding"][:model.y_embedder.y_embedding.
                                          shape[0]]
    # stdit3 special case
    if type(model).__name__ == "STDiT3" and "PixArt-Sigma" in name:
        ckpt_keys = list(ckpt.keys())
        for key in ckpt_keys:
            if "blocks." in key:
                ckpt[key.replace("blocks.", "spatial_blocks.")] = ckpt[key]
                del ckpt[key]

    return ckpt


def download_model(model_name=None, local_path=None, url=None):
    """
    Downloads a pre-trained DiT model from the web.
    """
    if model_name is not None:
        assert model_name in PRETRAINED_MODELS
        local_path = f"pretrained_models/{model_name}"
        web_path = PRETRAINED_MODELS[model_name]
    else:
        assert local_path is not None
        assert url is not None
        web_path = url
    if not os.path.isfile(local_path):
        os.makedirs("pretrained_models", exist_ok=True)
        dir_name = os.path.dirname(local_path)
        file_name = os.path.basename(local_path)
        # use torchvision's downloader here: the download_url() defined in
        # this file takes only a URL and always writes into ./cache
        tv_download_url(web_path, dir_name, file_name)
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model


def find_model(model_name, model=None):
    """
    Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
    """
    if model_name in PRETRAINED_MODELS:  # Find/download our pre-trained DiT checkpoints
        model_ckpt = download_model(model_name)
        model_ckpt = reparameter(model_ckpt, model_name, model=model)
    else:  # Load a custom DiT checkpoint:
        assert os.path.isfile(
            model_name), f"Could not find DiT checkpoint at {model_name}"
        model_ckpt = torch.load(model_name,
                                map_location=lambda storage, loc: storage)
        model_ckpt = reparameter(model_ckpt, model_name, model=model)
    return model_ckpt


def load_from_sharded_state_dict(model,
                                 ckpt_path,
                                 model_name="model",
                                 strict=False):
    ckpt_io = GeneralCheckpointIO()
    ckpt_io.load_model(model,
                       os.path.join(ckpt_path, model_name),
                       strict=strict)


def load_checkpoint(model,
                    ckpt_path,
                    save_as_pt=False,
                    model_name="model",
                    strict=False):
    if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
        state_dict = find_model(ckpt_path, model=model)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
                                                              strict=strict)
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
    elif ckpt_path.endswith(".safetensors"):
        state_dict = load_file(ckpt_path)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
                                                              strict=False)
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
    elif os.path.isdir(ckpt_path):
        load_from_sharded_state_dict(model,
                                     ckpt_path,
                                     model_name,
                                     strict=strict)
        print(f"Model checkpoint loaded from {ckpt_path}")
        if save_as_pt:
            save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")
            torch.save(model.state_dict(), save_path)
            print(f"Model checkpoint saved to {save_path}")
    else:
        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")

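# Usage sketch (paths are hypothetical): load_checkpoint() dispatches on the
# checkpoint format, so all of these go through the same entry point:
#   load_checkpoint(model, "OpenSora-v1-16x256x256.pth")  # known name, auto-downloaded
#   load_checkpoint(model, "weights/model.safetensors")   # local safetensors file
#   load_checkpoint(model, "ckpts/epoch_10/", save_as_pt=True)  # sharded ColossalAI dir

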
def print_progress_bar(iteration, total, length=40):
    iteration += 1
    filled_length = int(length * iteration // total)
    bar = '█' * filled_length + '-' * (length - filled_length)
    sys.stdout.write(f'\rDenoising steps: |{bar}| {iteration}/{total}')
    sys.stdout.flush()

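# Example (illustrative): draw a 40-character bar over 30 denoising steps.
# Note the function adds 1 to `iteration`, so it expects 0-based step indices:
#   for step in range(30):
#       print_progress_bar(step, 30)

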
class PromptProcessor():

    @staticmethod
    def load_prompts(prompt_path, start_idx=None, end_idx=None):
        with open(prompt_path, "r") as f:
            prompts = [line.strip() for line in f.readlines()]
        prompts = prompts[start_idx:end_idx]
        return prompts

    @staticmethod
    def extract_json_from_prompts(prompts, reference, mask_strategy):
        ret_prompts = []
        for i, prompt in enumerate(prompts):
            parts = re.split(r"(?=[{])", prompt)
            assert len(parts) <= 2, f"Invalid prompt: {prompt}"
            ret_prompts.append(parts[0])
            if len(parts) > 1:
                additional_info = json.loads(parts[1])
                for key in additional_info:
                    assert key in ["reference_path",
                                   "mask_strategy"], f"Invalid key: {key}"
                    if key == "reference_path":
                        reference[i] = additional_info[key]
                    elif key == "mask_strategy":
                        mask_strategy[i] = additional_info[key]
        return ret_prompts, reference, mask_strategy

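    # Example (illustrative): a prompt may end with a JSON object that sets
    # the per-sample reference video and mask strategy:
    #   prompts = ['a cat playing {"reference_path": "ref.mp4", "mask_strategy": "0"}']
    #   prompts, refs, masks = PromptProcessor.extract_json_from_prompts(
    #       prompts, [""], [""])
    #   # -> prompts == ['a cat playing '], refs == ['ref.mp4'], masks == ['0']
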
    @staticmethod
    def split_prompt(prompt_text):
        if prompt_text.startswith("|0|"):
            # this is for prompts which look like
            # |0| a beautiful day |1| a sunny day |2| a rainy day
            # we want to parse it into a list of prompts with the loop index
            prompt_list = prompt_text.split("|")[1:]
            text_list = []
            loop_idx = []
            for i in range(0, len(prompt_list), 2):
                start_loop = int(prompt_list[i])
                text = prompt_list[i + 1].strip()
                text_list.append(text)
                loop_idx.append(start_loop)
            return text_list, loop_idx
        else:
            return [prompt_text], None

    @staticmethod
    def merge_prompt(text_list, loop_idx_list=None):
        if loop_idx_list is None:
            return text_list[0]
        else:
            prompt = ""
            for i, text in enumerate(text_list):
                prompt += f"|{loop_idx_list[i]}|{text}"
            return prompt

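    # Example (illustrative): multi-loop prompts are encoded as "|idx| text":
    #   PromptProcessor.split_prompt("|0| a sunny day |2| a rainy day")
    #   # -> (['a sunny day', 'a rainy day'], [0, 2])
    #   PromptProcessor.merge_prompt(['a sunny day', 'a rainy day'], [0, 2])
    #   # -> '|0|a sunny day|2|a rainy day' (whitespace is not round-tripped)
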
    @staticmethod
    def extract_prompts_loop(prompts, num_loop):
        ret_prompts = []
        for prompt in prompts:
            if prompt.startswith("|0|"):
                prompt_list = prompt.split("|")[1:]
                text_list = []
                for i in range(0, len(prompt_list), 2):
                    start_loop = int(prompt_list[i])
                    text = prompt_list[i + 1]
                    end_loop = int(prompt_list[i + 2]) if i + 2 < len(
                        prompt_list) else num_loop + 1
                    text_list.extend([text] * (end_loop - start_loop))
                prompt = text_list[num_loop]
            ret_prompts.append(prompt)
        return ret_prompts

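    # Example (illustrative): each "|idx|" marks the first loop iteration in
    # which its text becomes active, so for "|0|day|2|night":
    #   PromptProcessor.extract_prompts_loop(["|0|day|2|night"], num_loop=0)
    #   # -> ['day']
    #   PromptProcessor.extract_prompts_loop(["|0|day|2|night"], num_loop=2)
    #   # -> ['night']
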
    @staticmethod
    def append_score_to_prompts(prompts,
                                aes=None,
                                flow=None,
                                camera_motion=None):
        new_prompts = []
        for prompt in prompts:
            new_prompt = prompt
            if aes is not None and "aesthetic score:" not in prompt:
                new_prompt = f"{new_prompt} aesthetic score: {aes:.1f}."
            if flow is not None and "motion score:" not in prompt:
                new_prompt = f"{new_prompt} motion score: {flow:.1f}."
            if camera_motion is not None and "camera motion:" not in prompt:
                new_prompt = f"{new_prompt} camera motion: {camera_motion}."
            new_prompts.append(new_prompt)
        return new_prompts

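    # Example (illustrative): scores are appended only when not already present:
    #   PromptProcessor.append_score_to_prompts(["a cat"], aes=6.5, flow=0.4)
    #   # -> ['a cat aesthetic score: 6.5. motion score: 0.4.']
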
    @classmethod
    def basic_clean(cls, text):
        text = ftfy.fix_text(text)
        text = html.unescape(html.unescape(text))
        return text.strip()

    @classmethod
    def clean_caption(cls, caption):
        BAD_PUNCT_REGEX = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" +
                                     "\]" + "\[" + "\}" + "\{" + "\|" + "\\" +
                                     "\/" + "\*" + r"]{1,}")  # noqa
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub("<person>", "person", caption)
        # urls:
        caption = re.sub(
            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        caption = re.sub(
            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
            "",
            caption,
        )  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features="html.parser").text
        # @<nickname>
        caption = re.sub(r"@[\w\d]+\b", "", caption)
        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
        #######################################################

        # all types of dash --> "-"
        caption = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
            "-",
            caption,
        )
        # normalize quotation marks to one standard
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)
        # &quot;
        caption = re.sub(r"&quot;?", "", caption)
        # &amp
        caption = re.sub(r"&amp", "", caption)
        # ip addresses:
        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
        # article ids:
        caption = re.sub(r"\d:\d\d\s+$", "", caption)
        # \n
        caption = re.sub(r"\\n", " ", caption)
        # "#123"
        caption = re.sub(r"#\d{1,3}\b", "", caption)
        # "#12345.."
        caption = re.sub(r"#\d{5,}\b", "", caption)
        # "123456.."
        caption = re.sub(r"\b\d{6,}\b", "", caption)
        # filenames:
        caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)",
                         "", caption)
        #
        caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""
        caption = re.sub(BAD_PUNCT_REGEX, r" ",
                         caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "
        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r"(?:\-|\_)")
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, " ", caption)
        caption = cls.basic_clean(caption)
        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231
        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
        caption = re.sub(
            r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "",
            caption)
        caption = re.sub(r"\bpage\s+\d+\b", "", caption)
        caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ",
                         caption)  # j2d1a2a...
        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
        caption = re.sub(r"\s+", " ", caption)
        caption = caption.strip()
        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
        caption = re.sub(r"^\.\S+$", "", caption)
        return caption.strip()

    @classmethod
    def text_preprocessing(cls, text, use_text_preprocessing: bool = True):
        if use_text_preprocessing:
            # The exact text cleaning as was in the training stage:
            text = cls.clean_caption(text)
            text = cls.clean_caption(text)
            return text
        else:
            return text.lower().strip()


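# Example (illustrative): clean_caption() strips markup, URLs, ids and hashes,
# and lowercases, so the cleaning is aggressive by design:
#   PromptProcessor.text_preprocessing("<p>A CAT https://example.com #123</p>")
#   # -> 'a cat'

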
class VideoProcessor():
    MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"]

    @staticmethod
    def get_image_size(resolution, ar_ratio):
        return get_image_size(resolution=resolution, ar_ratio=ar_ratio)

    @staticmethod
    def get_num_frames(num_frames):
        return get_num_frames(num_frames=num_frames)

    @staticmethod
    def get_latent_size(vae, input_size):
        return vae.get_latent_size(input_size=input_size)

    @staticmethod
    def center_crop_arr(pil_image, image_size):
        """
        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        while min(*pil_image.size) >= 2 * image_size:
            pil_image = pil_image.resize(tuple(x // 2
                                               for x in pil_image.size),
                                         resample=Image.BOX)

        scale = image_size / min(*pil_image.size)
        pil_image = pil_image.resize(tuple(
            round(x * scale) for x in pil_image.size),
                                     resample=Image.BICUBIC)

        arr = np.array(pil_image)
        crop_y = (arr.shape[0] - image_size) // 2
        crop_x = (arr.shape[1] - image_size) // 2
        return Image.fromarray(arr[crop_y:crop_y + image_size,
                                   crop_x:crop_x + image_size])

    @staticmethod
    def resize_crop_to_fill(pil_image, image_size):
        w, h = pil_image.size  # PIL is (W, H)
        th, tw = image_size
        rh, rw = th / h, tw / w
        if rh > rw:
            sh, sw = th, round(w * rh)
            image = pil_image.resize((sw, sh), Image.BICUBIC)
            i = 0
            j = int(round((sw - tw) / 2.0))
        else:
            sh, sw = round(h * rw), tw
            image = pil_image.resize((sw, sh), Image.BICUBIC)
            i = int(round((sh - th) / 2.0))
            j = 0
        arr = np.array(image)
        assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
        return Image.fromarray(arr[i:i + th, j:j + tw])

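    # resize_crop_to_fill scales the image just enough to cover the target
    # (th, tw), then center-crops the overflowing dimension; e.g. a 640x480
    # image targeted at (256, 256) is resized to 341x256 (WxH) and cropped
    # to 256x256.
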
    @classmethod
    def get_transforms_image(cls, name="center", image_size=(256, 256)):
        if name is None:
            return None
        elif name == "center":
            assert image_size[0] == image_size[
                1], "Image size must be square for center crop"
            transform = transforms.Compose([
                transforms.Lambda(lambda pil_image: cls.center_crop_arr(
                    pil_image, image_size[0])),
                # transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5],
                                     inplace=True),
            ])
        elif name == "resize_crop":
            transform = transforms.Compose([
                transforms.Lambda(lambda pil_image: cls.resize_crop_to_fill(
                    pil_image, image_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5],
                                     inplace=True),
            ])
        else:
            raise NotImplementedError(f"Transform {name} not implemented")
        return transform

    @classmethod
    def get_transforms_video(cls, name="center", image_size=(256, 256)):
        if name is None:
            return None
        elif name == "center":
            assert image_size[0] == image_size[
                1], "image_size must be square for center crop"
            transform_video = transforms.Compose([
                video_transforms.ToTensorVideo(),  # TCHW
                # video_transforms.RandomHorizontalFlipVideo(),
                video_transforms.UCFCenterCropVideo(image_size[0]),
                transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5],
                                     inplace=True),
            ])
        elif name == "resize_crop":
            transform_video = transforms.Compose([
                video_transforms.ToTensorVideo(),  # TCHW
                video_transforms.ResizeCrop(image_size),
                transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5],
                                     inplace=True),
            ])
        else:
            raise NotImplementedError(f"Transform {name} not implemented")
        return transform_video

    @classmethod
    def read_image_from_path(cls,
                             path,
                             transform=None,
                             transform_name="center",
                             num_frames=1,
                             image_size=(256, 256)):
        image = pil_loader(path)
        if transform is None:
            transform = cls.get_transforms_image(image_size=image_size,
                                                 name=transform_name)
        image = transform(image)
        video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
        video = video.permute(1, 0, 2, 3)
        return video

    @classmethod
    def read_video_from_path(cls,
                             path,
                             transform=None,
                             transform_name="center",
                             image_size=(256, 256)):
        vframes, aframes, info = read_video(filename=path,
                                            pts_unit="sec",
                                            output_format="TCHW")
        if transform is None:
            transform = cls.get_transforms_video(image_size=image_size,
                                                 name=transform_name)
        video = transform(vframes)  # T C H W
        video = video.permute(1, 0, 2, 3)
        return video

    @classmethod
    def read_from_path(cls, path, image_size, transform_name="center"):
        if is_url(path):
            path = download_url(path)
        ext = os.path.splitext(path)[-1].lower()
        if ext in VID_EXTENSIONS:
            return cls.read_video_from_path(path,
                                            image_size=image_size,
                                            transform_name=transform_name)
        else:
            assert ext in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
            return cls.read_image_from_path(path,
                                            image_size=image_size,
                                            transform_name=transform_name)

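    # Example (illustrative): read_from_path() accepts a local file or a URL
    # and returns a [C, T, H, W] tensor; a still image becomes a one-frame
    # video. Non-square sizes need the "resize_crop" transform:
    #   clip = VideoProcessor.read_from_path("ref.mp4", (240, 426),
    #                                        transform_name="resize_crop")
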
    @classmethod
    def collect_references_batch(cls, vae, reference_paths, image_size):
        refs_x = []  # refs_x: [batch, ref_num, C, T, H, W]
        for reference_path in reference_paths:
            if reference_path == "":
                refs_x.append([])
                continue
            ref_path = reference_path.split(";")
            ref = []
            for r_path in ref_path:
                r = cls.read_from_path(r_path,
                                       image_size,
                                       transform_name="resize_crop")
                r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
                r_x = r_x.squeeze(0)
                ref.append(r_x)
            refs_x.append(ref)
        return refs_x

    @staticmethod
    def append_generated(vae, generated_video, refs_x, mask_strategy, loop_i,
                         condition_frame_length, condition_frame_edit):
        ref_x = vae.encode(generated_video)
        for j, refs in enumerate(refs_x):
            if refs is None:
                refs_x[j] = [ref_x[j]]
            else:
                refs.append(ref_x[j])
            if mask_strategy[j] is None or mask_strategy[j] == "":
                mask_strategy[j] = ""
            else:
                mask_strategy[j] += ";"
            mask_strategy[
                j] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}"
        return refs_x, mask_strategy

    @classmethod
    def parse_mask_strategy(cls, mask_strategy):
        mask_batch = []
        if mask_strategy == "" or mask_strategy is None:
            return mask_batch
        mask_strategy = mask_strategy.split(";")
        for mask in mask_strategy:
            mask_group = mask.split(",")
            num_group = len(mask_group)
            assert num_group >= 1 and num_group <= 6, f"Invalid mask strategy: {mask}"
            mask_group.extend(cls.MASK_DEFAULT[num_group:])
            for i in range(5):
                mask_group[i] = int(mask_group[i])
            mask_group[5] = float(mask_group[5])
            mask_batch.append(mask_group)
        return mask_batch

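    # Example (illustrative): a mask strategy is ";"-separated groups of up
    # to six comma-separated fields, padded from MASK_DEFAULT:
    #   loop_id, ref_id, ref_start, target_start, length, edit_ratio
    #   VideoProcessor.parse_mask_strategy("0,0,-8,0,8,0.3")
    #   # -> [[0, 0, -8, 0, 8, 0.3]]
    #   VideoProcessor.parse_mask_strategy("0")
    #   # -> [[0, 0, 0, 0, 1, 0.0]]
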
    @classmethod
    def find_nearest_point(cls, value, point, max_value):
        t = value // point
        if value % point > point / 2 and t < max_value // point - 1:
            t += 1
        return t * point

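    # Example (illustrative): snap a value to the nearest multiple of `point`
    # while staying below max_value:
    #   VideoProcessor.find_nearest_point(7, 4, 16)   # -> 8
    #   VideoProcessor.find_nearest_point(15, 4, 16)  # -> 12 (rounding up is capped)
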
    @classmethod
    def apply_mask_strategy(cls, z, refs_x, mask_strategys, loop_i,
                            align=None):
        masks = []
        no_mask = True
        for i, mask_strategy in enumerate(mask_strategys):
            no_mask = False
            mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
            mask_strategy = cls.parse_mask_strategy(mask_strategy)
            for mst in mask_strategy:
                loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst
                if loop_id != loop_i:
                    continue
                ref = refs_x[i][m_id]

                if m_ref_start < 0:
                    # ref: [C, T, H, W]
                    m_ref_start = ref.shape[1] + m_ref_start
                if m_target_start < 0:
                    # z: [B, C, T, H, W]
                    m_target_start = z.shape[2] + m_target_start
                if align is not None:
                    m_ref_start = cls.find_nearest_point(
                        m_ref_start, align, ref.shape[1])
                    m_target_start = cls.find_nearest_point(
                        m_target_start, align, z.shape[2])
                m_length = min(m_length, z.shape[2] - m_target_start,
                               ref.shape[1] - m_ref_start)
                z[i, :, m_target_start:m_target_start +
                  m_length] = ref[:, m_ref_start:m_ref_start + m_length]
                mask[m_target_start:m_target_start + m_length] = edit_ratio
            masks.append(mask)
        if no_mask:
            return None
        masks = torch.stack(masks)
        return masks

    @staticmethod
    def dframe_to_frame(num):
        assert num % 5 == 0, f"Invalid num: {num}"
        return num // 5 * 17

    @staticmethod
    def save_sample(x,
                    save_path=None,
                    fps=8,
                    normalize=True,
                    value_range=(-1, 1),
                    force_video=False,
                    verbose=True):
        """
        Args:
            x (Tensor): shape [C, T, H, W]
        """
        assert x.ndim == 4
        if not force_video and x.shape[1] == 1:  # T = 1: save as image
            save_path += ".png"
            x = x.squeeze(1)
            save_image([x],
                       save_path,
                       normalize=normalize,
                       value_range=value_range)
        else:
            save_path += ".mp4"
            if normalize:
                low, high = value_range
                x.clamp_(min=low, max=high)
                x.sub_(low).div_(max(high - low, 1e-5))
            x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to(
                "cpu", torch.uint8)
            write_video(save_path, x, fps=fps, video_codec="h264")
        if verbose:
            print(f"Saved to {save_path}")
        return save_path

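    # Example (illustrative): a [C, T, H, W] tensor in [-1, 1] is written as
    # "out/sample.mp4" (or "out/sample.png" when T == 1 and force_video=False):
    #   VideoProcessor.save_sample(x, save_path="out/sample", fps=8)
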
    @staticmethod
    def add_watermark(
            input_video_path,
            watermark_image_path="./assets/images/watermark/watermark.png",
            output_video_path=None):
        # execute this command in terminal with subprocess
        # return if the process is successful
        if output_video_path is None:
            output_video_path = input_video_path.replace(
                ".mp4", "_watermark.mp4")
        cmd = f'ffmpeg -y -i {input_video_path} -i {watermark_image_path}'
        # note the leading space: without it the watermark path and the
        # -filter_complex flag would run together into a single argument
        cmd += f' -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" {output_video_path}'
        import subprocess
        exit_code = subprocess.call(cmd, shell=True)
        is_success = exit_code == 0
        return is_success


class DataProcessor():

    def __init__(self, text_encoder, vae):
        self._text_encoder = text_encoder
        self._vae = vae

    @staticmethod
    def set_random_seed(seed: int):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        return seed

    @staticmethod
    def prepare_multi_resolution_info(model_type, batch_size, image_size,
                                      num_frames, fps, device, dtype):
        IMG_FPS = 120
        ret = {}
        if model_type in ["STDiT3", "OpenSora"]:
            fps = fps if num_frames > 1 else IMG_FPS
            fps = torch.tensor([fps], device=device,
                               dtype=dtype).repeat(batch_size)
            height = torch.tensor([image_size[0]], device=device,
                                  dtype=dtype).repeat(batch_size)
            width = torch.tensor([image_size[1]], device=device,
                                 dtype=dtype).repeat(batch_size)
            num_frames = torch.tensor([num_frames], device=device,
                                      dtype=dtype).repeat(batch_size)
            ar = torch.tensor([image_size[0] / image_size[1]],
                              device=device,
                              dtype=dtype).repeat(batch_size)
            ret = dict(height=height,
                       width=width,
                       num_frames=num_frames,
                       ar=ar,
                       fps=fps)
        else:
            raise NotImplementedError(f"Model type is {model_type}")
        return ret

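    # Example (illustrative): for STDiT3/OpenSora the returned dict holds one
    # scalar tensor per sample, e.g. batch_size=2, image_size=(240, 426):
    #   info = DataProcessor.prepare_multi_resolution_info(
    #       "STDiT3", 2, (240, 426), num_frames=51, fps=24,
    #       device="cuda", dtype=torch.float16)
    #   # info["height"].shape == info["fps"].shape == (2,)
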
    @staticmethod
    def get_save_path_name(
            save_dir,
            sample_name=None,  # prefix
            sample_idx=None,  # sample index
            prompt=None,  # used prompt
            prompt_as_path=False,  # use prompt as path
            num_sample=1,  # number of samples to generate for one prompt
            k=None,  # kth sample
    ):
        if sample_name is None:
            sample_name = "" if prompt_as_path else "sample"
        sample_name_suffix = prompt if prompt_as_path else f"_{sample_idx:04d}"
        save_path = os.path.join(save_dir,
                                 f"{sample_name}{sample_name_suffix}")
        if num_sample != 1:
            save_path = f"{save_path}-{k}"
        return save_path

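    # Example (illustrative):
    #   DataProcessor.get_save_path_name("outputs", sample_idx=3)
    #   # -> 'outputs/sample_0003'
    #   DataProcessor.get_save_path_name("outputs", prompt="a cat",
    #                                    prompt_as_path=True, num_sample=2, k=1)
    #   # -> 'outputs/a cat-1'
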
    @staticmethod
    def load_prompts(prompt_path, start_idx=None, end_idx=None):
        return PromptProcessor.load_prompts(prompt_path=prompt_path,
                                            start_idx=start_idx,
                                            end_idx=end_idx)

    @staticmethod
    def extract_json_from_prompts(prompts, reference, mask_strategy):
        return PromptProcessor.extract_json_from_prompts(
            prompts=prompts, reference=reference, mask_strategy=mask_strategy)

    @staticmethod
    def split_prompt(prompt_text):
        return PromptProcessor.split_prompt(prompt_text=prompt_text)

    @staticmethod
    def merge_prompt(text_list, loop_idx_list=None):
        return PromptProcessor.merge_prompt(text_list=text_list,
                                            loop_idx_list=loop_idx_list)

    @staticmethod
    def extract_prompts_loop(prompts, num_loop):
        return PromptProcessor.extract_prompts_loop(prompts=prompts,
                                                    num_loop=num_loop)

    @staticmethod
    def append_score_to_prompts(prompts,
                                aes=None,
                                flow=None,
                                camera_motion=None):
        return PromptProcessor.append_score_to_prompts(
            prompts=prompts, aes=aes, flow=flow, camera_motion=camera_motion)

    @staticmethod
    def text_preprocessing(text, use_text_preprocessing: bool = True):
        return PromptProcessor.text_preprocessing(
            text=text, use_text_preprocessing=use_text_preprocessing)

    @staticmethod
    def get_image_size(resolution, ar_ratio):
        return VideoProcessor.get_image_size(resolution=resolution,
                                             ar_ratio=ar_ratio)

    @staticmethod
    def get_num_frames(num_frames):
        return VideoProcessor.get_num_frames(num_frames=num_frames)

    def get_latent_size(self, input_size):
        return VideoProcessor.get_latent_size(self._vae,
                                              input_size=input_size)

    def collect_references_batch(self, reference_paths, image_size):
        return VideoProcessor.collect_references_batch(
            self._vae, reference_paths=reference_paths, image_size=image_size)

    def append_generated(self, generated_video, refs_x, mask_strategy, loop_i,
                         condition_frame_length, condition_frame_edit):
        return VideoProcessor.append_generated(
            self._vae,
            generated_video=generated_video,
            refs_x=refs_x,
            mask_strategy=mask_strategy,
            loop_i=loop_i,
            condition_frame_length=condition_frame_length,
            condition_frame_edit=condition_frame_edit)

    @staticmethod
    def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
        return VideoProcessor.apply_mask_strategy(
            z=z,
            refs_x=refs_x,
            mask_strategys=mask_strategys,
            loop_i=loop_i,
            align=align)

    @staticmethod
    def dframe_to_frame(num):
        return VideoProcessor.dframe_to_frame(num=num)

    @staticmethod
    def save_sample(x,
                    save_path=None,
                    fps=8,
                    normalize=True,
                    value_range=(-1, 1),
                    force_video=False,
                    verbose=True):
        return VideoProcessor.save_sample(x=x,
                                          save_path=save_path,
                                          fps=fps,
                                          normalize=normalize,
                                          value_range=value_range,
                                          force_video=force_video,
                                          verbose=verbose)

    @staticmethod
    def add_watermark(
            input_video_path,
            watermark_image_path="./assets/images/watermark/watermark.png",
            output_video_path=None):
        return VideoProcessor.add_watermark(
            input_video_path=input_video_path,
            watermark_image_path=watermark_image_path,
            output_video_path=output_video_path)