TensorRT-LLMs/examples/models/contrib/stdit/utils.py
# Copyright 2024 HPC-AI Technology Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference: https://github.com/hpcaitech/Open-Sora/blob/main/opensora/utils/ckpt_utils.py
import html
import json
import os
import random
import re
import sys
import urllib.parse as ul
import ftfy
import numpy as np
import pandas as pd
import requests
import torch
import video_transforms
from aspect import get_image_size, get_num_frames
from bs4 import BeautifulSoup
from colossalai.checkpoint_io import GeneralCheckpointIO
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader
from torchvision.io import read_video, write_video
from torchvision.utils import save_image
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
URL_REGEX = re.compile(
r"^(?:http|ftp)s?://" # http:// or https://
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # ...or ip
r"(?::\d+)?" # optional port
r"(?:/?|[/?]\S+)$",
re.IGNORECASE,
)
HF_ENDPOINT = os.environ.get("HF_ENDPOINT")
if HF_ENDPOINT is None:
HF_ENDPOINT = "https://huggingface.co"
PRETRAINED_MODELS = {
"DiT-XL-2-512x512.pt":
"https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt",
"DiT-XL-2-256x256.pt":
"https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt",
"Latte-XL-2-256x256-ucf101.pt":
HF_ENDPOINT + "/maxin-cn/Latte/resolve/main/ucf101.pt",
"PixArt-XL-2-256x256.pth":
HF_ENDPOINT +
"/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-256x256.pth",
"PixArt-XL-2-SAM-256x256.pth":
HF_ENDPOINT +
"/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-SAM-256x256.pth",
"PixArt-XL-2-512x512.pth":
HF_ENDPOINT +
"/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-512x512.pth",
"PixArt-XL-2-1024-MS.pth":
HF_ENDPOINT +
"/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth",
"OpenSora-v1-16x256x256.pth":
HF_ENDPOINT +
"/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-16x256x256.pth",
"OpenSora-v1-HQ-16x256x256.pth":
HF_ENDPOINT +
"/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x256x256.pth",
"OpenSora-v1-HQ-16x512x512.pth":
HF_ENDPOINT +
"/hpcai-tech/Open-Sora/resolve/main/OpenSora-v1-HQ-16x512x512.pth",
"PixArt-Sigma-XL-2-256x256.pth":
HF_ENDPOINT +
"/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-256x256.pth",
"PixArt-Sigma-XL-2-512-MS.pth":
HF_ENDPOINT +
"/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-512-MS.pth",
"PixArt-Sigma-XL-2-1024-MS.pth":
HF_ENDPOINT +
"/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-1024-MS.pth",
"PixArt-Sigma-XL-2-2K-MS.pth":
HF_ENDPOINT +
"/PixArt-alpha/PixArt-Sigma/resolve/main/PixArt-Sigma-XL-2-2K-MS.pth",
}
def is_img(path):
ext = os.path.splitext(path)[-1].lower()
return ext in IMG_EXTENSIONS
def is_vid(path):
ext = os.path.splitext(path)[-1].lower()
return ext in VID_EXTENSIONS
def is_url(url):
return re.match(URL_REGEX, url) is not None
def read_file(input_path):
if input_path.endswith(".csv"):
return pd.read_csv(input_path)
elif input_path.endswith(".parquet"):
return pd.read_parquet(input_path)
else:
raise NotImplementedError(f"Unsupported file format: {input_path}")
def download_url(input_path):
output_dir = "cache"
os.makedirs(output_dir, exist_ok=True)
base_name = os.path.basename(input_path)
output_path = os.path.join(output_dir, base_name)
try:
img_data = requests.get(input_path).content
except requests.exceptions.RequestException as e:
print(f"Error downloading URL {input_path}: {e}")
return None
with open(output_path, "wb") as handler:
handler.write(img_data)
print(f"URL {input_path} downloaded to {output_path}")
return output_path
def reparameter(ckpt, name=None, model=None):
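    """
    Adapt a downloaded checkpoint to the target model: unwrap wrapper keys
    ("ema", "state_dict"), reshape the patch-embedding weight for video input,
    drop fixed positional embeddings, pad or truncate the null text embedding
    ("y_embedder.y_embedding") to the model's text length, and rename
    "blocks.*" to "spatial_blocks.*" for STDiT3 with PixArt-Sigma weights.
    """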
model_name = name
name = os.path.basename(name)
if not torch.distributed.is_initialized() or torch.distributed.get_rank(
) == 0:
print("loading pretrained model: %s", model_name)
if name in ["DiT-XL-2-512x512.pt", "DiT-XL-2-256x256.pt"]:
ckpt["x_embedder.proj.weight"] = ckpt[
"x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
if name in ["Latte-XL-2-256x256-ucf101.pt"]:
ckpt = ckpt["ema"]
ckpt["x_embedder.proj.weight"] = ckpt[
"x_embedder.proj.weight"].unsqueeze(2)
del ckpt["pos_embed"]
del ckpt["temp_embed"]
if name in [
"PixArt-XL-2-256x256.pth",
"PixArt-XL-2-SAM-256x256.pth",
"PixArt-XL-2-512x512.pth",
"PixArt-XL-2-1024-MS.pth",
"PixArt-Sigma-XL-2-256x256.pth",
"PixArt-Sigma-XL-2-512-MS.pth",
"PixArt-Sigma-XL-2-1024-MS.pth",
"PixArt-Sigma-XL-2-2K-MS.pth",
]:
ckpt = ckpt["state_dict"]
ckpt["x_embedder.proj.weight"] = ckpt[
"x_embedder.proj.weight"].unsqueeze(2)
if "pos_embed" in ckpt:
del ckpt["pos_embed"]
if name in [
"PixArt-1B-2.pth",
]:
ckpt = ckpt["state_dict"]
if "pos_embed" in ckpt:
del ckpt["pos_embed"]
    # pos_embed is not needed
if "pos_embed_temporal" in ckpt:
del ckpt["pos_embed_temporal"]
if "pos_embed" in ckpt:
del ckpt["pos_embed"]
# different text length
if "y_embedder.y_embedding" in ckpt:
if ckpt["y_embedder.y_embedding"].shape[
0] < model.y_embedder.y_embedding.shape[0]:
            print(f"Extend y_embedding from "
                  f"{ckpt['y_embedder.y_embedding'].shape[0]} to "
                  f"{model.y_embedder.y_embedding.shape[0]}")
additional_length = model.y_embedder.y_embedding.shape[0] - ckpt[
"y_embedder.y_embedding"].shape[0]
new_y_embedding = torch.zeros(additional_length,
model.y_embedder.y_embedding.shape[1])
new_y_embedding[:] = ckpt["y_embedder.y_embedding"][-1]
ckpt["y_embedder.y_embedding"] = torch.cat(
[ckpt["y_embedder.y_embedding"], new_y_embedding], dim=0)
elif ckpt["y_embedder.y_embedding"].shape[
0] > model.y_embedder.y_embedding.shape[0]:
            print(f"Shrink y_embedding from "
                  f"{ckpt['y_embedder.y_embedding'].shape[0]} to "
                  f"{model.y_embedder.y_embedding.shape[0]}")
ckpt["y_embedder.y_embedding"] = ckpt[
"y_embedder.y_embedding"][:model.y_embedder.y_embedding.
shape[0]]
# stdit3 special case
if type(model).__name__ == "STDiT3" and "PixArt-Sigma" in name:
ckpt_keys = list(ckpt.keys())
for key in ckpt_keys:
if "blocks." in key:
ckpt[key.replace("blocks.", "spatial_blocks.")] = ckpt[key]
del ckpt[key]
return ckpt
def download_model(model_name=None, local_path=None, url=None):
"""
Downloads a pre-trained DiT model from the web.
"""
if model_name is not None:
assert model_name in PRETRAINED_MODELS
local_path = f"pretrained_models/{model_name}"
web_path = PRETRAINED_MODELS[model_name]
else:
assert local_path is not None
assert url is not None
web_path = url
    if not os.path.isfile(local_path):
        os.makedirs("pretrained_models", exist_ok=True)
        dir_name = os.path.dirname(local_path)
        file_name = os.path.basename(local_path)
        # The module-level download_url() above only takes a URL and writes to
        # ./cache; use torchvision's download_url(url, root, filename) here so
        # the file lands at the expected local path.
        from torchvision.datasets.utils import download_url as tv_download_url
        tv_download_url(web_path, dir_name, file_name)
model = torch.load(local_path, map_location=lambda storage, loc: storage)
return model
def find_model(model_name, model=None):
"""
Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.
"""
if model_name in PRETRAINED_MODELS: # Find/download our pre-trained DiT checkpoints
model_ckpt = download_model(model_name)
model_ckpt = reparameter(model_ckpt, model_name, model=model)
else: # Load a custom DiT checkpoint:
assert os.path.isfile(
model_name), f"Could not find DiT checkpoint at {model_name}"
model_ckpt = torch.load(model_name,
map_location=lambda storage, loc: storage)
model_ckpt = reparameter(model_ckpt, model_name, model=model)
return model_ckpt
def load_from_sharded_state_dict(model,
ckpt_path,
model_name="model",
strict=False):
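    """Load a ColossalAI sharded checkpoint from <ckpt_path>/<model_name> into `model`."""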
ckpt_io = GeneralCheckpointIO()
ckpt_io.load_model(model,
os.path.join(ckpt_path, model_name),
strict=strict)
def load_checkpoint(model,
ckpt_path,
save_as_pt=False,
model_name="model",
strict=False):
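    """
    Load model weights from `ckpt_path`, dispatching on its format:
    .pt/.pth via find_model(), .safetensors via safetensors' load_file(),
    or a directory via the ColossalAI sharded loader (optionally re-saving
    the merged state dict as <model_name>_ckpt.pt when save_as_pt is True).
    """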
if ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
state_dict = find_model(ckpt_path, model=model)
missing_keys, unexpected_keys = model.load_state_dict(state_dict,
strict=strict)
print("Missing keys: %s", missing_keys)
print("Unexpected keys: %s", unexpected_keys)
elif ckpt_path.endswith(".safetensors"):
state_dict = load_file(ckpt_path)
missing_keys, unexpected_keys = model.load_state_dict(state_dict,
strict=False)
print(f"Missing keys: {missing_keys}")
print(f"Unexpected keys: {unexpected_keys}")
elif os.path.isdir(ckpt_path):
load_from_sharded_state_dict(model,
ckpt_path,
model_name,
strict=strict)
print("Model checkpoint loaded from %s", ckpt_path)
if save_as_pt:
save_path = os.path.join(ckpt_path, model_name + "_ckpt.pt")
torch.save(model.state_dict(), save_path)
print("Model checkpoint saved to %s", save_path)
else:
raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
def print_progress_bar(iteration, total, length=40):
iteration += 1
filled_length = int(length * iteration // total)
    # '█' (full block) is the progress-bar fill character
    bar = '█' * filled_length + '-' * (length - filled_length)
sys.stdout.write(f'\rDenoising steps: |{bar}| {iteration}/{total}')
sys.stdout.flush()
class PromptProcessor():
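    """Static helpers for loading, splitting, scoring, and cleaning text prompts."""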
@staticmethod
def load_prompts(prompt_path, start_idx=None, end_idx=None):
with open(prompt_path, "r") as f:
prompts = [line.strip() for line in f.readlines()]
prompts = prompts[start_idx:end_idx]
return prompts
@staticmethod
def extract_json_from_prompts(prompts, reference, mask_strategy):
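        """
        Split an optional trailing JSON payload off each prompt, e.g.
        'a beautiful day {"reference_path": "ref.mp4", "mask_strategy": "0"}',
        and write its values into the per-sample `reference` / `mask_strategy` lists.
        """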
ret_prompts = []
for i, prompt in enumerate(prompts):
parts = re.split(r"(?=[{])", prompt)
assert len(parts) <= 2, f"Invalid prompt: {prompt}"
ret_prompts.append(parts[0])
if len(parts) > 1:
additional_info = json.loads(parts[1])
for key in additional_info:
assert key in ["reference_path",
"mask_strategy"], f"Invalid key: {key}"
if key == "reference_path":
reference[i] = additional_info[key]
elif key == "mask_strategy":
mask_strategy[i] = additional_info[key]
return ret_prompts, reference, mask_strategy
@staticmethod
def split_prompt(prompt_text):
if prompt_text.startswith("|0|"):
# this is for prompts which look like
# |0| a beautiful day |1| a sunny day |2| a rainy day
# we want to parse it into a list of prompts with the loop index
prompt_list = prompt_text.split("|")[1:]
text_list = []
loop_idx = []
for i in range(0, len(prompt_list), 2):
start_loop = int(prompt_list[i])
text = prompt_list[i + 1].strip()
text_list.append(text)
loop_idx.append(start_loop)
return text_list, loop_idx
else:
return [prompt_text], None
@staticmethod
def merge_prompt(text_list, loop_idx_list=None):
if loop_idx_list is None:
return text_list[0]
else:
prompt = ""
for i, text in enumerate(text_list):
prompt += f"|{loop_idx_list[i]}|{text}"
return prompt
@staticmethod
def extract_prompts_loop(prompts, num_loop):
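        """
        For multi-segment prompts of the form "|0| text a |2| text b", expand the
        segments across loop indices and return the text for loop `num_loop`.
        """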
ret_prompts = []
for prompt in prompts:
if prompt.startswith("|0|"):
prompt_list = prompt.split("|")[1:]
text_list = []
for i in range(0, len(prompt_list), 2):
start_loop = int(prompt_list[i])
text = prompt_list[i + 1]
end_loop = int(prompt_list[i + 2]) if i + 2 < len(
prompt_list) else num_loop + 1
text_list.extend([text] * (end_loop - start_loop))
prompt = text_list[num_loop]
ret_prompts.append(prompt)
return ret_prompts
@staticmethod
def append_score_to_prompts(prompts,
aes=None,
flow=None,
camera_motion=None):
new_prompts = []
for prompt in prompts:
new_prompt = prompt
if aes is not None and "aesthetic score:" not in prompt:
new_prompt = f"{new_prompt} aesthetic score: {aes:.1f}."
if flow is not None and "motion score:" not in prompt:
new_prompt = f"{new_prompt} motion score: {flow:.1f}."
if camera_motion is not None and "camera motion:" not in prompt:
new_prompt = f"{new_prompt} camera motion: {camera_motion}."
new_prompts.append(new_prompt)
return new_prompts
@classmethod
def basic_clean(cls, text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
@classmethod
def clean_caption(cls, caption):
BAD_PUNCT_REGEX = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" +
"\]" + "\[" + "\}" + "\{" + "\|" + "\\" +
"\/" + "\*" + r"]{1,}") # noqa
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
        # normalize all quotation marks to a single standard
        caption = re.sub(r"[`´«»“”¨]", '"', caption)
        caption = re.sub(r"[‘’]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)",
"", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(BAD_PUNCT_REGEX, r" ",
caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = cls.basic_clean(caption)
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(
r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "",
caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ",
caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
        caption = caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
@classmethod
def text_preprocessing(cls, text, use_text_preprocessing: bool = True):
if use_text_preprocessing:
# The exact text cleaning as was in the training stage:
text = cls.clean_caption(text)
text = cls.clean_caption(text)
return text
else:
return text.lower().strip()
class VideoProcessor():
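    """Static helpers for reading and transforming images/videos, reference conditioning, and saving samples."""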
MASK_DEFAULT = ["0", "0", "0", "0", "1", "0"]
@staticmethod
def get_image_size(resolution, ar_ratio):
return get_image_size(resolution=resolution, ar_ratio=ar_ratio)
@staticmethod
def get_num_frames(num_frames):
return get_num_frames(num_frames=num_frames)
@staticmethod
def get_latent_size(vae, input_size):
return vae.get_latent_size(input_size=input_size)
@staticmethod
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size),
resample=Image.BOX)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(tuple(
round(x * scale) for x in pil_image.size),
resample=Image.BICUBIC)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y:crop_y + image_size,
crop_x:crop_x + image_size])
@staticmethod
def resize_crop_to_fill(pil_image, image_size):
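        """Resize `pil_image` so it fully covers the target (th, tw), then center-crop to that size."""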
w, h = pil_image.size # PIL is (W, H)
th, tw = image_size
rh, rw = th / h, tw / w
if rh > rw:
sh, sw = th, round(w * rh)
image = pil_image.resize((sw, sh), Image.BICUBIC)
i = 0
j = int(round((sw - tw) / 2.0))
else:
sh, sw = round(h * rw), tw
image = pil_image.resize((sw, sh), Image.BICUBIC)
i = int(round((sh - th) / 2.0))
j = 0
arr = np.array(image)
assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
return Image.fromarray(arr[i:i + th, j:j + tw])
@classmethod
def get_transforms_image(cls, name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[
1], "Image size must be square for center crop"
transform = transforms.Compose([
transforms.Lambda(lambda pil_image: cls.center_crop_arr(
pil_image, image_size[0])),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
inplace=True),
])
elif name == "resize_crop":
transform = transforms.Compose([
transforms.Lambda(lambda pil_image: cls.resize_crop_to_fill(
pil_image, image_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
inplace=True),
])
else:
raise NotImplementedError(f"Transform {name} not implemented")
return transform
@classmethod
def get_transforms_video(cls, name="center", image_size=(256, 256)):
if name is None:
return None
elif name == "center":
assert image_size[0] == image_size[
1], "image_size must be square for center crop"
transform_video = transforms.Compose([
video_transforms.ToTensorVideo(), # TCHW
# video_transforms.RandomHorizontalFlipVideo(),
video_transforms.UCFCenterCropVideo(image_size[0]),
transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
inplace=True),
])
elif name == "resize_crop":
transform_video = transforms.Compose([
video_transforms.ToTensorVideo(), # TCHW
video_transforms.ResizeCrop(image_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5],
inplace=True),
])
else:
raise NotImplementedError(f"Transform {name} not implemented")
return transform_video
@classmethod
def read_image_from_path(cls,
path,
transform=None,
transform_name="center",
num_frames=1,
image_size=(256, 256)):
image = pil_loader(path)
if transform is None:
transform = cls.get_transforms_image(image_size=image_size,
name=transform_name)
image = transform(image)
video = image.unsqueeze(0).repeat(num_frames, 1, 1, 1)
video = video.permute(1, 0, 2, 3)
return video
@classmethod
def read_video_from_path(cls,
path,
transform=None,
transform_name="center",
image_size=(256, 256)):
vframes, aframes, info = read_video(filename=path,
pts_unit="sec",
output_format="TCHW")
if transform is None:
transform = cls.get_transforms_video(image_size=image_size,
name=transform_name)
video = transform(vframes) # T C H W
video = video.permute(1, 0, 2, 3)
return video
@classmethod
def read_from_path(cls, path, image_size, transform_name="center"):
if is_url(path):
path = download_url(path)
ext = os.path.splitext(path)[-1].lower()
if ext.lower() in VID_EXTENSIONS:
return cls.read_video_from_path(path,
image_size=image_size,
transform_name=transform_name)
else:
assert ext.lower(
) in IMG_EXTENSIONS, f"Unsupported file format: {ext}"
return cls.read_image_from_path(path,
image_size=image_size,
transform_name=transform_name)
@classmethod
def collect_references_batch(cls, vae, reference_paths, image_size):
refs_x = [] # refs_x: [batch, ref_num, C, T, H, W]
for reference_path in reference_paths:
if reference_path == "":
refs_x.append([])
continue
ref_path = reference_path.split(";")
ref = []
for r_path in ref_path:
r = cls.read_from_path(r_path,
image_size,
transform_name="resize_crop")
r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
r_x = r_x.squeeze(0)
ref.append(r_x)
refs_x.append(ref)
return refs_x
@staticmethod
def append_generated(vae, generated_video, refs_x, mask_strategy, loop_i,
condition_frame_length, condition_frame_edit):
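        """
        Encode `generated_video` with the VAE, append each sample's latent as a
        new reference in `refs_x`, and extend `mask_strategy` so loop `loop_i`
        is conditioned on the last `condition_frame_length` frames of that reference.
        """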
ref_x = vae.encode(generated_video)
for j, refs in enumerate(refs_x):
if refs is None:
refs_x[j] = [ref_x[j]]
else:
refs.append(ref_x[j])
if mask_strategy[j] is None or mask_strategy[j] == "":
mask_strategy[j] = ""
else:
mask_strategy[j] += ";"
mask_strategy[
j] += f"{loop_i},{len(refs)-1},-{condition_frame_length},0,{condition_frame_length},{condition_frame_edit}"
return refs_x, mask_strategy
@classmethod
def parse_mask_strategy(cls, mask_strategy):
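        """
        Parse a ";"-separated mask strategy string. Each entry has up to six
        comma-separated fields (missing ones filled from MASK_DEFAULT):
        loop_id, ref_id, ref_start, target_start, length, edit_ratio.
        """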
mask_batch = []
if mask_strategy == "" or mask_strategy is None:
return mask_batch
mask_strategy = mask_strategy.split(";")
for mask in mask_strategy:
mask_group = mask.split(",")
num_group = len(mask_group)
assert num_group >= 1 and num_group <= 6, f"Invalid mask strategy: {mask}"
mask_group.extend(cls.MASK_DEFAULT[num_group:])
for i in range(5):
mask_group[i] = int(mask_group[i])
mask_group[5] = float(mask_group[5])
mask_batch.append(mask_group)
return mask_batch
@classmethod
def find_nearest_point(cls, value, point, max_value):
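        """Snap `value` down to a multiple of `point`; round up instead when the remainder exceeds point/2 and there is still headroom before `max_value`."""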
t = value // point
if value % point > point / 2 and t < max_value // point - 1:
t += 1
return t * point
@classmethod
def apply_mask_strategy(cls, z, refs_x, mask_strategys, loop_i, align=None):
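        """
        Copy reference latents into `z` according to each sample's mask strategy
        for loop `loop_i`, and return a per-frame mask of edit ratios
        (1.0 where untouched), or None when no strategies are given.
        """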
masks = []
no_mask = True
for i, mask_strategy in enumerate(mask_strategys):
no_mask = False
mask = torch.ones(z.shape[2], dtype=torch.float, device=z.device)
mask_strategy = cls.parse_mask_strategy(mask_strategy)
for mst in mask_strategy:
loop_id, m_id, m_ref_start, m_target_start, m_length, edit_ratio = mst
if loop_id != loop_i:
continue
ref = refs_x[i][m_id]
if m_ref_start < 0:
# ref: [C, T, H, W]
m_ref_start = ref.shape[1] + m_ref_start
if m_target_start < 0:
# z: [B, C, T, H, W]
m_target_start = z.shape[2] + m_target_start
if align is not None:
m_ref_start = cls.find_nearest_point(
m_ref_start, align, ref.shape[1])
m_target_start = cls.find_nearest_point(
m_target_start, align, z.shape[2])
m_length = min(m_length, z.shape[2] - m_target_start,
ref.shape[1] - m_ref_start)
z[i, :, m_target_start:m_target_start +
m_length] = ref[:, m_ref_start:m_ref_start + m_length]
mask[m_target_start:m_target_start + m_length] = edit_ratio
masks.append(mask)
if no_mask:
return None
masks = torch.stack(masks)
return masks
@staticmethod
def dframe_to_frame(num):
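        """Convert `num` dframes (a multiple of 5) into real frames: every 5 dframes correspond to 17 frames."""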
assert num % 5 == 0, f"Invalid num: {num}"
return num // 5 * 17
@staticmethod
def save_sample(x,
save_path=None,
fps=8,
normalize=True,
value_range=(-1, 1),
force_video=False,
verbose=True):
"""
Args:
x (Tensor): shape [C, T, H, W]
"""
assert x.ndim == 4
if not force_video and x.shape[1] == 1: # T = 1: save as image
save_path += ".png"
x = x.squeeze(1)
save_image([x],
save_path,
normalize=normalize,
value_range=value_range)
else:
save_path += ".mp4"
if normalize:
low, high = value_range
x.clamp_(min=low, max=high)
x.sub_(low).div_(max(high - low, 1e-5))
x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to(
"cpu", torch.uint8)
write_video(save_path, x, fps=fps, video_codec="h264")
if verbose:
print(f"Saved to {save_path}")
return save_path
@staticmethod
def add_watermark(
input_video_path,
watermark_image_path="./assets/images/watermark/watermark.png",
output_video_path=None):
        # run the ffmpeg command with subprocess and
        # return whether the process succeeded
if output_video_path is None:
output_video_path = input_video_path.replace(
".mp4", "_watermark.mp4")
        cmd = f'ffmpeg -y -i {input_video_path} -i {watermark_image_path}'
        cmd += f' -filter_complex "[1][0]scale2ref=oh*mdar:ih*0.1[logo][video];[video][logo]overlay" {output_video_path}'
import subprocess
exit_code = subprocess.call(cmd, shell=True)
is_success = exit_code == 0
return is_success
class DataProcessor():
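    """
    Thin facade that binds a text encoder and a VAE and re-exposes the
    PromptProcessor / VideoProcessor helpers through a single object.
    """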
def __init__(self, text_encoder, vae):
self._text_encoder = text_encoder
self._vae = vae
@staticmethod
def set_random_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
return seed
@staticmethod
def prepare_multi_resolution_info(model_type, batch_size, image_size,
num_frames, fps, device, dtype):
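        """
        Build per-sample conditioning tensors (fps, height, width, num_frames,
        aspect ratio) on `device`/`dtype` for STDiT3 / OpenSora models; a single
        image (num_frames == 1) uses the fixed IMG_FPS value.
        """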
IMG_FPS = 120
ret = {}
if model_type in ["STDiT3", "OpenSora"]:
fps = fps if num_frames > 1 else IMG_FPS
fps = torch.tensor([fps], device=device,
dtype=dtype).repeat(batch_size)
height = torch.tensor([image_size[0]], device=device,
dtype=dtype).repeat(batch_size)
width = torch.tensor([image_size[1]], device=device,
dtype=dtype).repeat(batch_size)
num_frames = torch.tensor([num_frames], device=device,
dtype=dtype).repeat(batch_size)
ar = torch.tensor([image_size[0] / image_size[1]],
device=device,
dtype=dtype).repeat(batch_size)
ret = dict(height=height,
width=width,
num_frames=num_frames,
ar=ar,
fps=fps)
else:
raise NotImplementedError(f"Model type is {model_type}")
return ret
@staticmethod
def get_save_path_name(
save_dir,
sample_name=None, # prefix
sample_idx=None, # sample index
prompt=None, # used prompt
prompt_as_path=False, # use prompt as path
num_sample=1, # number of samples to generate for one prompt
k=None, # kth sample
):
if sample_name is None:
sample_name = "" if prompt_as_path else "sample"
sample_name_suffix = prompt if prompt_as_path else f"_{sample_idx:04d}"
save_path = os.path.join(save_dir, f"{sample_name}{sample_name_suffix}")
if num_sample != 1:
save_path = f"{save_path}-{k}"
return save_path
@staticmethod
def load_prompts(prompt_path, start_idx=None, end_idx=None):
return PromptProcessor.load_prompts(prompt_path=prompt_path,
start_idx=start_idx,
end_idx=end_idx)
@staticmethod
def extract_json_from_prompts(prompts, reference, mask_strategy):
return PromptProcessor.extract_json_from_prompts(
prompts=prompts, reference=reference, mask_strategy=mask_strategy)
@staticmethod
def split_prompt(prompt_text):
return PromptProcessor.split_prompt(prompt_text=prompt_text)
@staticmethod
def merge_prompt(text_list, loop_idx_list=None):
return PromptProcessor.merge_prompt(text_list=text_list,
loop_idx_list=loop_idx_list)
@staticmethod
def extract_prompts_loop(prompts, num_loop):
return PromptProcessor.extract_prompts_loop(prompts=prompts,
num_loop=num_loop)
@staticmethod
def append_score_to_prompts(prompts,
aes=None,
flow=None,
camera_motion=None):
return PromptProcessor.append_score_to_prompts(
prompts=prompts, aes=aes, flow=flow, camera_motion=camera_motion)
@staticmethod
def text_preprocessing(text, use_text_preprocessing: bool = True):
return PromptProcessor.text_preprocessing(
text=text, use_text_preprocessing=use_text_preprocessing)
@staticmethod
def get_image_size(resolution, ar_ratio):
return VideoProcessor.get_image_size(resolution=resolution,
ar_ratio=ar_ratio)
@staticmethod
def get_num_frames(num_frames):
return VideoProcessor.get_num_frames(num_frames=num_frames)
def get_latent_size(self, input_size):
return VideoProcessor.get_latent_size(self._vae, input_size=input_size)
def collect_references_batch(self, reference_paths, image_size):
return VideoProcessor.collect_references_batch(
self._vae, reference_paths=reference_paths, image_size=image_size)
def append_generated(self, generated_video, refs_x, mask_strategy, loop_i,
condition_frame_length, condition_frame_edit):
return VideoProcessor.append_generated(
self._vae,
generated_video=generated_video,
refs_x=refs_x,
mask_strategy=mask_strategy,
loop_i=loop_i,
condition_frame_length=condition_frame_length,
condition_frame_edit=condition_frame_edit)
@staticmethod
def apply_mask_strategy(z, refs_x, mask_strategys, loop_i, align=None):
return VideoProcessor.apply_mask_strategy(z=z,
refs_x=refs_x,
mask_strategys=mask_strategys,
loop_i=loop_i,
align=align)
@staticmethod
def dframe_to_frame(num):
return VideoProcessor.dframe_to_frame(num=num)
@staticmethod
def save_sample(x,
save_path=None,
fps=8,
normalize=True,
value_range=(-1, 1),
force_video=False,
verbose=True):
return VideoProcessor.save_sample(x=x,
save_path=save_path,
fps=fps,
normalize=normalize,
value_range=value_range,
force_video=force_video,
verbose=verbose)
@staticmethod
def add_watermark(
input_video_path,
watermark_image_path="./assets/images/watermark/watermark.png",
output_video_path=None):
return VideoProcessor.add_watermark(
input_video_path=input_video_path,
watermark_image_path=watermark_image_path,
output_video_path=output_video_path)