be4afa0bb4
* find & replace all FloatTensors to Tensor * apply formatting * Update torch.FloatTensor to torch.Tensor in the remaining files * formatting * Fix the rest of the places where FloatTensor is used as well as in documentation * formatting * Update new file from FloatTensor to Tensor
1038 lines
38 KiB
Python
1038 lines
38 KiB
Python
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import math
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
|
|
from ..utils import deprecate
|
|
from .activations import get_activation
|
|
from .attention_processor import Attention
|
|
|
|
|
|
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Build sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (`torch.Tensor`):
            1-D tensor with one (possibly fractional) timestep per batch element.
        embedding_dim (`int`):
            Width of every output embedding. An odd width gets one trailing column of zeros.
        flip_sin_to_cos (`bool`, defaults to `False`):
            When `True` the cosine half comes first, otherwise the sine half does.
        downscale_freq_shift (`float`, defaults to 1):
            Subtracted from `embedding_dim // 2` in the denominator of the frequency exponent.
        scale (`float`, defaults to 1):
            Multiplier applied to the timestep/frequency products before taking sin and cos.
        max_period (`int`, defaults to 10000):
            Controls the minimum frequency of the embeddings.

    Returns:
        `torch.Tensor`: float32 tensor of shape `(N, embedding_dim)`.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half = embedding_dim // 2
    freq_index = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    # Geometric progression of frequencies from 1 down towards 1 / max_period.
    frequencies = torch.exp(-math.log(max_period) * freq_index / (half - downscale_freq_shift))

    angles = scale * (timesteps[:, None].float() * frequencies[None, :])
    sin_part, cos_part = torch.sin(angles), torch.cos(angles)

    if flip_sin_to_cos:
        emb = torch.cat([cos_part, sin_part], dim=-1)
    else:
        emb = torch.cat([sin_part, cos_part], dim=-1)

    # An odd width cannot be split evenly between sin and cos; pad the last column with zeros.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
|
|
|
|
|
def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Create 2-D sine-cosine positional embeddings for a grid of patches.

    Args:
        embed_dim (`int`): Width of each embedding; must be even.
        grid_size (`int` or `(height, width)`): Grid extent. A single int means a square grid.
        cls_token (`bool`): Together with `extra_tokens > 0`, prepends all-zero rows to the result.
        extra_tokens (`int`): Number of zero rows to prepend when `cls_token` is set.
        interpolation_scale (`float`): Divides the (rescaled) grid coordinates.
        base_size (`int`): Reference grid extent; coordinates along each axis are multiplied by
            `base_size / extent`.

    Returns:
        `np.ndarray` of shape `(height * width, embed_dim)`, or `(extra_tokens + height * width, embed_dim)` when
        zero rows are prepended.
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    height, width = grid_size[0], grid_size[1]

    rows = np.arange(height, dtype=np.float32) / (height / base_size) / interpolation_scale
    cols = np.arange(width, dtype=np.float32) / (width / base_size) / interpolation_scale
    # meshgrid's default "xy" indexing puts the width coordinate first and makes each plane (height, width).
    coords = np.stack(np.meshgrid(cols, rows), axis=0)
    coords = coords.reshape([2, 1, height, width])

    embedding = get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
    if cls_token and extra_tokens > 0:
        embedding = np.concatenate([np.zeros([extra_tokens, embed_dim]), embedding], axis=0)
    return embedding
|
|
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a stack of two coordinate planes with sine-cosine embeddings.

    Args:
        embed_dim (`int`): Output width; must be even. The first half of the channels encodes `grid[0]`, the second
            half encodes `grid[1]`.
        grid: Array whose leading axis indexes the two coordinate planes.

    Returns:
        `np.ndarray` of shape `(M, embed_dim)`, where M is the number of elements in one plane.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    first, second = (get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1))
    return np.concatenate([first, second], axis=1)  # (M, D)
|
|
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode a set of positions with sine-cosine embeddings.

    Args:
        embed_dim (`int`): Output width D per position; must be even.
        pos: Array of positions; it is flattened, so any shape with M elements is accepted.

    Returns:
        `np.ndarray` of shape `(M, D)`: sines in the first D/2 columns, cosines in the last D/2.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2:
        raise ValueError("embed_dim must be divisible by 2")

    # Frequencies 1 / 10000**(2i / D) for i in [0, D/2).
    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**exponents  # (D/2,)

    angles = np.outer(pos.reshape(-1), omega)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
|
|
|
|
|
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding

    Splits the input into non-overlapping `patch_size` x `patch_size` patches with a strided convolution. It then
    optionally flattens them into a token sequence (BCHW -> BNC) and layer-normalizes them. Finally it adds fixed
    2-D sine-cosine positional embeddings and casts the result back to the convolution output's dtype.

    The positional embeddings have shape `(1, N, embed_dim)`, so they line up with the flattened `(B, N, C)` layout.

    Args:
        height (`int`, defaults to 224): Input height the cached positional embeddings are built for.
        width (`int`, defaults to 224): Input width the cached positional embeddings are built for.
        patch_size (`int`, defaults to 16): Side of each square patch (kernel size and stride of the projection).
        in_channels (`int`, defaults to 3): Number of input channels.
        embed_dim (`int`, defaults to 768): Width of each patch embedding.
        layer_norm (`bool`, defaults to `False`): Apply a non-affine LayerNorm (eps=1e-6) to the patch tokens.
        flatten (`bool`, defaults to `True`): Flatten the spatial grid into a token sequence.
        bias (`bool`, defaults to `True`): Whether the projection convolution has a bias.
        interpolation_scale (`float`, defaults to 1): Passed to `get_2d_sincos_pos_embed`.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()

        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        # Build the cached grid from the real (rows, cols) patch counts. A square grid of side
        # int(sqrt(num_patches)) has the wrong number of positions whenever height != width, which made the
        # cached buffer impossible to add to the patch sequence. For square inputs both forms are identical.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim,
            (self.height, self.width),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        # Patch-grid extent of this particular input (may differ from the size configured at init).
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
|
|
|
|
|
|
class TimestepEmbedding(nn.Module):
    """
    Two-layer MLP that turns timestep features into a time embedding.

    Args:
        in_channels (`int`): Width of the incoming `sample`.
        time_embed_dim (`int`): Hidden width (output of `linear_1`).
        act_fn (`str`, defaults to `"silu"`): Activation name resolved via `get_activation`, applied between the
            two linear layers.
        out_dim (`int`, *optional*): Output width; defaults to `time_embed_dim`.
        post_act_fn (`str`, *optional*): Activation name applied after `linear_2`.
        cond_proj_dim (`int`, *optional*): If given, a bias-free linear maps a `condition` tensor of this width to
            `in_channels`, and the result is added to `sample` before the MLP.
        sample_proj_bias (`bool`, defaults to `True`): Whether `linear_1` and `linear_2` use a bias.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        self.cond_proj = None if cond_proj_dim is None else nn.Linear(cond_proj_dim, in_channels, bias=False)
        self.act = get_activation(act_fn)

        final_dim = time_embed_dim if out_dim is None else out_dim
        self.linear_2 = nn.Linear(time_embed_dim, final_dim, sample_proj_bias)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            # Requires the module to have been built with `cond_proj_dim`.
            sample = sample + self.cond_proj(condition)

        hidden = self.linear_1(sample)
        if self.act is not None:
            hidden = self.act(hidden)
        hidden = self.linear_2(hidden)
        if self.post_act is not None:
            hidden = self.post_act(hidden)
        return hidden
|
|
|
|
|
|
class Timesteps(nn.Module):
    """
    Parameter-free wrapper around `get_timestep_embedding`.

    Args:
        num_channels (`int`): Width of the produced sinusoidal embedding.
        flip_sin_to_cos (`bool`): Put the cosine half before the sine half.
        downscale_freq_shift (`float`): Forwarded to `get_timestep_embedding`.
    """

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        # `timesteps` must be 1-D (one value per batch element); the result is (N, num_channels).
        return get_timestep_embedding(
            timesteps,
            embedding_dim=self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
|
|
|
|
|
|
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Each scalar input is multiplied by `embedding_size` fixed (non-trainable) random frequencies drawn from
    N(0, scale**2) and by 2*pi. The result is the concatenated sine and cosine features, of shape
    `(N, 2 * embedding_size)`.

    Args:
        embedding_size (`int`, defaults to 256): Number of random frequencies.
        scale (`float`, defaults to 1.0): Standard deviation multiplier for the frequencies.
        set_W_to_weight (`bool`, defaults to `True`): Also register the frequencies under the name `W`.
        log (`bool`, defaults to `True`): Take `torch.log` of the input first.
        flip_sin_to_cos (`bool`, defaults to `False`): Put the cosine half before the sine half.
    """

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later
            # A second random draw is registered as `W` and `weight` is re-pointed at it, so the first draw is
            # discarded. Both draws are kept so the random-number sequence stays the same.
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
            self.weight = self.W

    def forward(self, x):
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
        sin_feat, cos_feat = torch.sin(x_proj), torch.cos(x_proj)
        halves = [cos_feat, sin_feat] if self.flip_sin_to_cos else [sin_feat, cos_feat]
        return torch.cat(halves, dim=-1)
|
|
|
|
|
|
class SinusoidalPositionalEmbedding(nn.Module):
    """Apply positional information to a sequence of embeddings.

    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds fixed sinusoidal
    positional embeddings to them. Even channels hold sines and odd channels hold cosines.

    Args:
        embed_dim: (int): Dimension of the positional embedding. Odd values are supported; the last channel is then
            a sine with no matching cosine.
        max_seq_length: Maximum sequence length to apply positional embeddings

    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # There are embed_dim // 2 odd channels but ceil(embed_dim / 2) frequencies. For odd embed_dim the
        # unsliced assignment raised a shape mismatch; for even embed_dim the slice is a no-op.
        pe[0, :, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (batch_size, seq_length, embed_dim); seq_length must not exceed max_seq_length.
        _, seq_length, _ = x.shape
        x = x + self.pe[:, :seq_length]
        return x
|
|
|
|
|
|
class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        token_emb = self.emb(index)
        device = index.device

        row_ids = torch.arange(self.height, device=device).view(1, self.height)
        col_ids = torch.arange(self.width, device=device).view(1, self.width)
        rows = self.height_emb(row_ids).unsqueeze(2)  # 1 x H x D -> 1 x H x 1 x D
        cols = self.width_emb(col_ids).unsqueeze(1)  # 1 x W x D -> 1 x 1 x W x D

        # Broadcast to 1 x H x W x D, then flatten row-major to 1 x (H*W) x D.
        grid = (rows + cols).view(1, self.height * self.width, -1)

        # Sequences shorter than H*W only receive the leading positions.
        return token_emb + grid[:, : token_emb.shape[1], :]
|
|
|
|
|
|
class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    When `dropout_prob > 0` the table has one extra row (index `num_classes`); dropped labels are mapped to it.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        Args:
            labels (`torch.LongTensor`): 1-D tensor of class indices.
            force_drop_ids (*optional*): Per-label flags (tensor, array or sequence); entries equal to 1 are
                dropped. If `None`, each label is dropped independently with probability `dropout_prob`.

        Returns:
            `torch.LongTensor`: `labels` with dropped entries replaced by `num_classes`.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # Convert first (on the labels' device), then compare element-wise. The previous
            # `torch.tensor(force_drop_ids == 1)` had three problems: it built the mask on the default device;
            # it warned when copy-constructing from a tensor; and for a plain list it compared the whole list
            # to 1, which is always False, so nothing was dropped.
            drop_ids = torch.as_tensor(force_drop_ids, device=labels.device) == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        # Random dropout only in training mode; forced drops apply in any mode.
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings
|
|
|
|
|
|
class TextImageProjection(nn.Module):
    """
    Projects a pooled image embedding into `num_image_text_embeds` tokens and prepends them to the projected text
    tokens.

    Args:
        text_embed_dim (`int`, defaults to 1024): Width of the input text embeddings.
        image_embed_dim (`int`, defaults to 768): Width of the input image embedding.
        cross_attention_dim (`int`, defaults to 768): Width of every output token.
        num_image_text_embeds (`int`, defaults to 10): Number of tokens produced from the image embedding.
    """

    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        batch = text_embeds.shape[0]

        image_tokens = self.image_embeds(image_embeds).reshape(batch, self.num_image_text_embeds, -1)
        text_tokens = self.text_proj(text_embeds)

        # Image tokens first, then text tokens, along the sequence axis.
        return torch.cat([image_tokens, text_tokens], dim=1)
|
|
|
|
|
|
class ImageProjection(nn.Module):
    """
    Projects a pooled image embedding into `num_image_text_embeds` LayerNorm-ed tokens.

    Args:
        image_embed_dim (`int`, defaults to 768): Width of the input image embedding.
        cross_attention_dim (`int`, defaults to 768): Width of every output token.
        num_image_text_embeds (`int`, defaults to 32): Number of output tokens.
    """

    def __init__(
        self,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 32,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        batch = image_embeds.shape[0]
        tokens = self.image_embeds(image_embeds).reshape(batch, self.num_image_text_embeds, -1)
        return self.norm(tokens)
|
|
|
|
|
|
class IPAdapterFullImageProjection(nn.Module):
    """
    Projects image embeddings with a GELU `FeedForward` (mult=1) from `image_embed_dim` to `cross_attention_dim`,
    followed by LayerNorm.
    """

    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
        super().__init__()
        from .attention import FeedForward

        self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        projected = self.ff(image_embeds)
        return self.norm(projected)
|
|
|
|
|
|
class IPAdapterFaceIDImageProjection(nn.Module):
    """
    Maps each image/ID embedding to `num_tokens` LayerNorm-ed tokens of width `cross_attention_dim`, using a GELU
    `FeedForward`.
    """

    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
        super().__init__()
        from .attention import FeedForward

        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        # Split the flat FeedForward output into (batch, num_tokens, cross_attention_dim).
        tokens = self.ff(image_embeds).reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)
|
|
|
|
|
|
class CombinedTimestepLabelEmbeddings(nn.Module):
    """
    Sum of a timestep embedding and a class-label embedding, both of width `embedding_dim`.

    The timestep path uses 256 sinusoidal channels (cosine first, no frequency shift... `downscale_freq_shift=1`)
    followed by a `TimestepEmbedding` MLP. The label path is a `LabelEmbedding` with `class_dropout_prob`.
    """

    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        time_features = self.time_proj(timestep).to(dtype=hidden_dtype)
        time_emb = self.timestep_embedder(time_features)  # (N, D)
        label_emb = self.class_embedder(class_labels)  # (N, D)
        return time_emb + label_emb  # (N, D)
|
|
|
|
|
|
class TextTimeEmbedding(nn.Module):
    """
    Pools a sequence of encoder hidden states into a single time-embedding vector.

    The pipeline is LayerNorm -> `AttentionPooling` -> Linear -> LayerNorm.
    """

    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        pooled = self.pool(self.norm1(hidden_states))
        return self.norm2(self.proj(pooled))
|
|
|
|
|
|
class TextImageTimeEmbedding(nn.Module):
    """
    Time embedding built from a text embedding (projected and LayerNorm-ed) plus a projected image embedding.
    """

    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        text_part = self.text_norm(self.text_proj(text_embeds))
        image_part = self.image_proj(image_embeds)  # no normalization on the image branch
        return image_part + text_part
|
|
|
|
|
|
class ImageTimeEmbedding(nn.Module):
    """
    Time embedding from an image embedding: Linear projection followed by LayerNorm.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

    def forward(self, image_embeds: torch.Tensor):
        return self.image_norm(self.image_proj(image_embeds))
|
|
|
|
|
|
class ImageHintTimeEmbedding(nn.Module):
    """
    Time embedding from an image embedding plus a convolutional encoding of a 3-channel hint image.

    The hint encoder is eight 3x3 convolutions (padding 1) with SiLU between them. Three of them have stride 2, so
    the spatial size shrinks by about 8x; the output has 4 channels.
    """

    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

        # (in_channels, out_channels, stride) for every conv that is followed by a SiLU.
        conv_spec = (
            (3, 16, 1),
            (16, 16, 1),
            (16, 32, 2),
            (32, 32, 1),
            (32, 96, 2),
            (96, 96, 1),
            (96, 256, 2),
        )
        layers = []
        for c_in, c_out, stride in conv_spec:
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1, stride=stride))
            layers.append(nn.SiLU())
        layers.append(nn.Conv2d(256, 4, 3, padding=1))
        self.input_hint_block = nn.Sequential(*layers)

    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
        time_image_embeds = self.image_norm(self.image_proj(image_embeds))
        encoded_hint = self.input_hint_block(hint)
        return time_image_embeds, encoded_hint
|
|
|
|
|
|
class AttentionPooling(nn.Module):
    # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
    """
    Multi-head attention pooling: a single query token (the sequence mean plus a learned positional embedding)
    attends over itself and the input sequence, producing one vector per batch element.

    Args:
        num_heads (`int`): Number of attention heads.
        embed_dim (`int`): Width of the input tokens. Presumably must be divisible by `num_heads` — TODO confirm;
            the reshapes in `forward` fail otherwise.
        dtype (*optional*): dtype for the q/k/v projections only.
    """

    def __init__(self, num_heads, embed_dim, dtype=None):
        super().__init__()
        self.dtype = dtype
        # Learned offset for the pooled query token. It is created in the default dtype (the `dtype` argument is not
        # applied here) and cast to the input dtype in `forward`.
        self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
        self.num_heads = num_heads
        self.dim_per_head = embed_dim // self.num_heads

    def forward(self, x):
        """
        Pool a token sequence.

        Args:
            x (`torch.Tensor`): Tensor of shape `(bs, length, width)`, with `width` expected to equal `embed_dim`.

        Returns:
            `torch.Tensor` of shape `(bs, width)`.
        """
        bs, length, width = x.size()

        def shape(x):
            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
            x = x.view(bs, -1, self.num_heads, self.dim_per_head)
            # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
            x = x.transpose(1, 2)
            # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
            x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
            # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
            x = x.transpose(1, 2)
            return x

        # Query token: mean over the sequence plus the learned positional embedding, (bs, 1, width).
        class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
        x = torch.cat([class_token, x], dim=1)  # (bs, length+1, width)

        # (bs*n_heads, dim_per_head, 1)
        q = shape(self.q_proj(class_token))
        # (bs*n_heads, dim_per_head, length+1)
        k = shape(self.k_proj(x))
        v = shape(self.v_proj(x))

        # q and k are each scaled by dim_per_head**-0.25, i.e. 1/sqrt(dim_per_head) overall.
        # weight: (bs*n_heads, 1, length+1)
        scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        # Softmax over the keys is computed in float32, then cast back to the original dtype.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)

        # (bs*n_heads, dim_per_head, 1)
        a = torch.einsum("bts,bcs->bct", weight, v)

        # (bs*n_heads, dim_per_head, 1) -> (bs, n_heads*dim_per_head, 1) -> (bs, 1, width)
        a = a.reshape(bs, -1, 1).transpose(1, 2)

        return a[:, 0, :]  # cls_token
|
|
|
|
|
|
def get_fourier_embeds_from_boundingbox(embed_dim, box):
    """
    Fourier-encode bounding-box coordinates.

    Every coordinate of every box is multiplied by the `embed_dim` frequencies `100 ** (i / embed_dim)` and passed
    through sine and cosine.

    Args:
        embed_dim: int
            Number of frequencies per coordinate.
        box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline

    Returns:
        [B x N x (embed_dim * 2 * 4)] tensor of positional embeddings, flattened from a
        (frequency, sin/cos, coordinate) layout.
    """
    batch_size, num_boxes = box.shape[:2]

    # Frequencies are computed in the default dtype on CPU, then moved to the box's device and dtype.
    freqs = 100 ** (torch.arange(embed_dim) / embed_dim)
    freqs = freqs[None, None, None].to(device=box.device, dtype=box.dtype)  # (1, 1, 1, embed_dim)
    phases = freqs * box.unsqueeze(-1)  # (B, N, 4, embed_dim)

    features = torch.stack((phases.sin(), phases.cos()), dim=-1)  # (B, N, 4, embed_dim, 2)
    features = features.permute(0, 1, 3, 4, 2)  # (B, N, embed_dim, 2, 4)
    return features.reshape(batch_size, num_boxes, embed_dim * 2 * 4)
|
|
|
|
|
|
class GLIGENTextBoundingboxProjection(nn.Module):
    """
    Projects GLIGEN grounding inputs (boxes plus per-box text and/or image features) into grounding tokens.

    Args:
        positive_len (`int`): Width of the per-box text (and image) feature vectors.
        out_dim (`int` or `tuple`): Output token width. If a tuple, its first element is used to build the layers;
            the value is stored unchanged in `self.out_dim`.
        feature_type (`str`, defaults to `"text-only"`): Either `"text-only"` or `"text-image"`.
        fourier_freqs (`int`, defaults to 8): Number of Fourier frequencies per box coordinate.

    Raises:
        ValueError: If `feature_type` is not one of the supported values.
    """

    def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.fourier_embedder_dim = fourier_freqs
        self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy

        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        if feature_type == "text-only":
            self.linears = self._build_mlp(out_dim)
            self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        elif feature_type == "text-image":
            self.linears_text = self._build_mlp(out_dim)
            self.linears_image = self._build_mlp(out_dim)
            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        else:
            # Fail fast: otherwise the module would be built without projection layers and only break in `forward`.
            raise ValueError(f"`feature_type` must be 'text-only' or 'text-image', got {feature_type!r}.")

        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

    def _build_mlp(self, out_dim):
        """Three-layer SiLU MLP from concatenated (feature, Fourier position) vectors to `out_dim`."""
        return nn.Sequential(
            nn.Linear(self.positive_len + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

    def forward(
        self,
        boxes,
        masks,
        positive_embeddings=None,
        phrases_masks=None,
        image_masks=None,
        phrases_embeddings=None,
        image_embeddings=None,
    ):
        """
        Args:
            boxes (`torch.Tensor`): `[B x N x 4]` box coordinates.
            masks (`torch.Tensor`): Per-box mask, expected `[B x N]` (it is unsqueezed to broadcast over channels).
                1 keeps a box's position embedding; 0 replaces it with the learned null embedding.
            positive_embeddings (*optional*): Per-box text features. When given, the text-only path is used.
            phrases_masks, image_masks (*optional*): Per-box masks for the text / image features (text-image path).
            phrases_embeddings, image_embeddings (*optional*): Per-box text / image features (text-image path).

        Returns:
            `torch.Tensor`: N tokens per batch element on the text-only path. On the text-image path it is 2N
            tokens: the text tokens followed by the image tokens.
        """
        masks = masks.unsqueeze(-1)

        # Embed box positions (may include padding boxes used as placeholders).
        xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)  # B*N*4 -> B*N*C

        # learnable null embedding
        xyxy_null = self.null_position_feature.view(1, 1, -1)

        # replace padding with learnable null embedding
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        # positionet with text-only information
        if positive_embeddings is not None:
            # learnable null embedding
            positive_null = self.null_positive_feature.view(1, 1, -1)

            # replace padding with learnable null embedding
            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null

            objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))

        # positionet with text and image information
        else:
            phrases_masks = phrases_masks.unsqueeze(-1)
            image_masks = image_masks.unsqueeze(-1)

            # learnable null embedding
            text_null = self.null_text_feature.view(1, 1, -1)
            image_null = self.null_image_feature.view(1, 1, -1)

            # replace padding with learnable null embedding
            phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
            image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null

            objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
            objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
            objs = torch.cat([objs_text, objs_image], dim=1)

        return objs
|
|
|
|
|
|
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Timestep embedding of width `embedding_dim`. With `use_additional_conditions`, per-sample resolution and
    aspect-ratio embeddings (built from `size_emb_dim`-wide pieces) are concatenated and added to it; that
    concatenation must therefore match `embedding_dim` in width.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.use_additional_conditions = use_additional_conditions
        if use_additional_conditions:
            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_emb = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_dtype))  # (N, D)
        if not self.use_additional_conditions:
            return timesteps_emb

        # Every scalar in `resolution` / `aspect_ratio` gets its own sinusoidal embedding; the per-scalar results
        # are then regrouped per sample.
        res_features = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
        resolution_emb = self.resolution_embedder(res_features).reshape(batch_size, -1)
        ar_features = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
        aspect_ratio_emb = self.aspect_ratio_embedder(ar_features).reshape(batch_size, -1)
        return timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
|
|
|
|
|
|
class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings with Linear -> GELU (tanh approximation) -> Linear.

    No dropout or classifier-free-guidance handling happens here. The `num_tokens` argument is accepted but unused.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, num_tokens=120):
        super().__init__()
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)

    def forward(self, caption):
        return self.linear_2(self.act_1(self.linear_1(caption)))
|
|
|
|
|
|
class IPAdapterPlusImageProjection(nn.Module):
    """Resampler of IP-Adapter Plus.

    A fixed set of learned query latents is refined over `depth` blocks. In each block the normalized latents attend
    over the normalized projected image features concatenated with themselves, and a feed-forward layer follows;
    both steps are residual. A final projection and LayerNorm produce the output.

    Args:
        embed_dims (`int`, defaults to 768): The feature dimension of the input image embeddings.
        output_dims (`int`, defaults to 1024): The number of output channels, which is the same as the number of
            channels in `unet.config.cross_attention_dim`.
        hidden_dims (`int`, defaults to 1280): The number of hidden channels.
        depth (`int`, defaults to 4): The number of attention/feed-forward blocks.
        dim_head (`int`, defaults to 64): The number of channels per attention head.
        heads (`int`, defaults to 16): The number of parallel attention heads.
        num_queries (`int`, defaults to 8): The number of learned query latents.
        ffn_ratio (`float`, defaults to 4): The expansion ratio of the feed-forward hidden layer.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 1024,
        hidden_dims: int = 1280,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_queries: int = 8,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        from .attention import FeedForward  # Lazy import to avoid circular import

        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)

        self.proj_in = nn.Linear(embed_dims, hidden_dims)

        self.proj_out = nn.Linear(hidden_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        # Each block is [ln for image features, ln for latents, attention, (ln + feed-forward)].
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        nn.LayerNorm(hidden_dims),
                        nn.LayerNorm(hidden_dims),
                        Attention(
                            query_dim=hidden_dims,
                            dim_head=dim_head,
                            heads=heads,
                            out_bias=False,
                        ),
                        nn.Sequential(
                            nn.LayerNorm(hidden_dims),
                            FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
                        ),
                    ]
                )
                for _ in range(depth)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): Input image features, last dimension `embed_dims`.

        Returns:
            torch.Tensor: `num_queries` tokens of width `output_dims` per batch element (assuming `Attention`
            returns one output per query).
        """
        latents = self.latents.repeat(x.size(0), 1, 1)
        features = self.proj_in(x)

        for ln0, ln1, attn, ff in self.layers:
            normed_latents = ln1(latents)
            context = torch.cat([ln0(features), normed_latents], dim=-2)
            # Residual is taken from the latents before normalization.
            latents = attn(normed_latents, context) + latents
            latents = ff(latents) + latents

        return self.norm_out(self.proj_out(latents))
|
|
|
|
|
|
class IPAdapterPlusImageProjectionBlock(nn.Module):
    """
    One resampler block: normalized latents attend over [normalized features; normalized latents]. The caller
    supplies `residual`, which is added back; a residual feed-forward layer follows.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        dim_head: int = 64,
        heads: int = 16,
        ffn_ratio: float = 4,
    ) -> None:
        super().__init__()
        from .attention import FeedForward

        self.ln0 = nn.LayerNorm(embed_dims)
        self.ln1 = nn.LayerNorm(embed_dims)
        self.attn = Attention(
            query_dim=embed_dims,
            dim_head=dim_head,
            heads=heads,
            out_bias=False,
        )
        self.ff = nn.Sequential(
            nn.LayerNorm(embed_dims),
            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
        )

    def forward(self, x, latents, residual):
        normed_latents = self.ln1(latents)
        context = torch.cat([self.ln0(x), normed_latents], dim=-2)
        attended = self.attn(normed_latents, context) + residual
        return self.ff(attended) + attended
|
|
|
|
|
|
class IPAdapterFaceIDPlusImageProjection(nn.Module):
    """FacePerceiverResampler of IP-Adapter Plus.

    Args:
        embed_dims (int): The feature dimension. Defaults to 768.
        output_dims (int): The number of output channels, that is the same number of the channels in the
            `unet.config.cross_attention_dim`. Defaults to 768.
        hidden_dims (int): The number of hidden channels, i.e. the last dimension of `clip_embeds`
            (it is the input width of `proj_in`). Defaults to 1280.
        id_embeddings_dim (int): The dimension of the input ID embeddings. Defaults to 512.
        depth (int): The number of blocks. Defaults to 4.
        dim_head (int): The number of head channels. Defaults to 64.
        heads (int): Parallel attention heads. Defaults to 16.
        num_tokens (int): Number of tokens each ID embedding is expanded into. Defaults to 4.
        num_queries (int): The number of queries. Currently not used by this module. Defaults to 8.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        ffproj_ratio (int): The expansion ratio of feedforward network hidden
            layer channels (for ID embeddings). Defaults to 2.

    Note:
        `clip_embeds` is initialized to `None`. It must be assigned by the caller before `forward` is called.
    """

    def __init__(
        self,
        embed_dims: int = 768,
        output_dims: int = 768,
        hidden_dims: int = 1280,
        id_embeddings_dim: int = 512,
        depth: int = 4,
        dim_head: int = 64,
        heads: int = 16,
        num_tokens: int = 4,
        num_queries: int = 8,
        ffn_ratio: float = 4,
        ffproj_ratio: int = 2,
    ) -> None:
        super().__init__()
        from .attention import FeedForward

        self.num_tokens = num_tokens
        self.embed_dim = embed_dims
        # Image embeddings attended to in `forward`; must be assigned externally before calling it.
        self.clip_embeds = None
        # When enabled, `forward` adds the normalized ID tokens back onto its output, scaled by `shortcut_scale`.
        self.shortcut = False
        self.shortcut_scale = 1.0

        # Expands each ID embedding to `num_tokens * embed_dims` features (reshaped into tokens in `forward`).
        self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
        self.norm = nn.LayerNorm(embed_dims)

        # Maps `clip_embeds` (last dim `hidden_dims`) to the block width `embed_dims`.
        self.proj_in = nn.Linear(hidden_dims, embed_dims)

        self.proj_out = nn.Linear(embed_dims, output_dims)
        self.norm_out = nn.LayerNorm(output_dims)

        self.layers = nn.ModuleList(
            [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
        )

    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            id_embeds (torch.Tensor): Input Tensor (ID embeds); its last dimension feeds `proj`, so it must
                equal `id_embeddings_dim`.

        Returns:
            torch.Tensor: Output Tensor, presumably shaped `(N, num_tokens, output_dims)`.

        Raises:
            ValueError: If `self.clip_embeds` has not been assigned.
        """
        if self.clip_embeds is None:
            raise ValueError(
                "`clip_embeds` must be set on `IPAdapterFaceIDPlusImageProjection` before calling `forward`."
            )

        # Run the ID branch in the same dtype as the image embeddings.
        id_embeds = id_embeds.to(self.clip_embeds.dtype)
        id_embeds = self.proj(id_embeds)
        id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
        id_embeds = self.norm(id_embeds)
        latents = id_embeds

        clip_embeds = self.proj_in(self.clip_embeds)
        # Fold all leading dims except the last two into one batch axis.
        # Assumes a 4-D input, e.g. (batch, num_images, seq_len, dim) — TODO confirm.
        x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])

        for block in self.layers:
            residual = latents
            latents = block(x, latents, residual)

        latents = self.proj_out(latents)
        out = self.norm_out(latents)
        if self.shortcut:
            # NOTE(review): `id_embeds` is `embed_dims` wide while `out` is `output_dims` wide.
            # This sum only works when the two widths are compatible (e.g. equal, as with the defaults).
            out = id_embeds + self.shortcut_scale * out
        return out
|
|
|
|
|
|
class MultiIPAdapterImageProjection(nn.Module):
    """Holds one image projection layer per IP-Adapter and applies each layer to its own embeddings.

    Args:
        IPAdapterImageProjectionLayers (`List[nn.Module]` or `Tuple[nn.Module]`):
            Projection layers, one per IP-Adapter, in the same order as the embeddings passed to `forward`.
    """

    def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
        super().__init__()
        self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)

    def forward(self, image_embeds: Union[torch.Tensor, List[torch.Tensor]]) -> List[torch.Tensor]:
        """Project each IP-Adapter's image embeddings with its matching projection layer.

        Args:
            image_embeds: Either of two forms.
                - A list of `n` tensors, where `n` is the number of IP-Adapters. Each tensor has shape
                  `[batch_size, num_images, embed_dim]` or `[batch_size, num_images, sequence_length, embed_dim]`.
                - Deprecated: a single tensor of shape `[batch_size, embed_dim]` or
                  `[batch_size, sequence_length, embed_dim]`, treated as one image per sample.

        Returns:
            A list with one projected tensor per IP-Adapter; each keeps the leading `[batch_size, num_images]` dims.

        Raises:
            ValueError: If the number of embeddings differs from the number of projection layers.
        """
        if not isinstance(image_embeds, list):
            deprecation_message = (
                "You have passed a tensor as `image_embeds`. This is deprecated and will be removed in a future release."
                " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
            )
            deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
            # Insert a num_images axis of size 1 so the tensor matches the list format.
            image_embeds = [image_embeds.unsqueeze(1)]

        if len(image_embeds) != len(self.image_projection_layers):
            raise ValueError(
                f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
            )

        projected_image_embeds = []
        for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
            batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
            # Merge the batch and image axes so the projection layer sees one sample per image, then split them back.
            image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
            image_embed = image_projection_layer(image_embed)
            image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])

            projected_image_embeds.append(image_embed)

        return projected_image_embeds
|