d61889fc17
* init pixart alpha pipeline * fix: import * script * script * script * add: vae to the pipeline * add: vae_scale_factor * add: checkpoint_path * clean conversion script a bit. * size embeddings. * fix: size embedding * update scrip * support for interpolation of position embedding. * support for conditioning. * .. * .. * .. * final layer * final layer * align if encode_prompt * support for caption embedding * refactor * refactor * refactor * start cross attention * start cross attention * cross_attention_dim * cross * cross * support for resolution and aspect_ratio * support for caption projection * refactor patch embeddings * batch_size * up * commit * commit * commit. * squeeze * squeeze * squeeze * squeeze * squeeze * squeeze * squeeze * squeeze * squeeze * squeeze * squeeze * squeeze. * squeeze. * fix final block./ * fix final block./ * fix final block./ * clean * fix: interpolation scale. * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging' * debugging * debugging * debugging * debugging * debugging * debugging * debugging * make --checkpoint_path non-required. 
* debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * remove num_tokens * timesteps -> timestep * timesteps -> timestep * timesteps -> timestep * timesteps -> timestep * timesteps -> timestep * timesteps -> timestep * debug * debug * update conversion script. * update conversion script. * update conversion script. * debug * debug * debug * clean * debug * debug * debug * debug * debug * debug * debug * debug * deug * debug * debug * debug * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * clean * fix * fix * boom * boom * some changes * boom * save * up * remove i * fix more tests * DPMSolverMultistepScheduler * fix * offloading * fix conversion script * fix conversion script * remove print * remove support for negative prompt embeds. * typo. * remove extra kwargs * bring conversion script to where it was * fix * trying mu luck * trying my luck again * again * again * again * clean up * up * up * update example * support for 512 * remove spacing * finalize docs. * test debug * fix: assertion values. * debug * debug * debug * fix: repeat * remove prints. * Apply suggestions from code review * Apply suggestions from code review * Correct more * Apply suggestions from code review * Change all * Clean more * fix more * Fix more * Fix more * Correct more * address patrick's comments. * remove unneeded args * clean up pipeline. * sty;e * make the use of additional conditions better conditioned. * None better * dtype * height and width validation * add a note about size brackets. * fix * spit out slow test outputs. * fix? 
* fix optional test * fix more * remove unneeded comment * debug --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
149 lines
5.4 KiB
Python
149 lines
5.4 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 HuggingFace Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .activations import get_activation
|
|
from .embeddings import CombinedTimestepLabelEmbeddings, CombinedTimestepSizeEmbeddings
|
|
|
|
|
|
class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    The timestep index is looked up in a learned embedding table and passed through SiLU and a linear layer. The
    result is split into a `scale` and a `shift` that modulate a non-affine `LayerNorm` of the input:
    `norm(x) * (1 + scale) + shift`.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector; also the size of the last dimension of `x`.
        num_embeddings (`int`): The size of the embeddings dictionary, i.e. the number of distinct timestep indices.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        # Produces 2 * embedding_dim values: the first half is the scale, the second half the shift.
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        # No learnable affine parameters: scale/shift come from the timestep embedding instead.
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (`torch.Tensor`): Input whose last dimension has size `embedding_dim`.
            timestep (`torch.LongTensor`): Integer index (or indices) into the embedding table. Either a 0-dim
                tensor, which applies one modulation to all of `x`, or a 1-D tensor of shape `(batch,)`, which
                applies one modulation per leading (batch) entry of `x`.

        Returns:
            `torch.Tensor`: The normalized and modulated input, with the same shape as `x`.
        """
        emb = self.linear(self.silu(self.emb(timestep)))
        # Split along the feature axis. For a 0-dim timestep `emb` is 1-D, so this equals splitting dim 0;
        # for batched timesteps, dim 0 is the batch axis and must not be split.
        scale, shift = emb.chunk(2, dim=-1)
        if emb.ndim == 2:
            # Batched: (batch, embedding_dim) -> (batch, 1, ..., 1, embedding_dim) so each sample's
            # modulation broadcasts over the middle dimensions of `x`.
            broadcast_shape = (emb.shape[0],) + (1,) * (x.ndim - 2) + (scale.shape[-1],)
            scale = scale.reshape(broadcast_shape)
            shift = shift.reshape(broadcast_shape)
        x = self.norm(x) * (1 + scale) + shift
        return x
|
|
|
|
|
|
class AdaLayerNormZero(nn.Module):
    r"""
    Adaptive layer norm zero (adaLN-Zero).

    A combined timestep/class-label embedding is projected to six modulation tensors. Two of them (shift and scale)
    are applied here to a non-affine `LayerNorm` of the input. The remaining four are handed back to the caller:
    the attention gate, plus the shift, scale and gate for the feed-forward branch.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector; also the size of the last dimension of `x`.
        num_embeddings (`int`): The size of the embeddings dictionary (passed to the label embedding).
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()

        # Creation order is kept stable so parameter initialisation under a fixed seed does not change.
        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        # One projection yields all six modulation tensors, each of width embedding_dim.
        self.linear = nn.Linear(embedding_dim, embedding_dim * 6, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x (`torch.Tensor`): Hidden states; presumably `(batch, seq, embedding_dim)` — TODO confirm with callers.
            timestep (`torch.Tensor`): Timesteps forwarded to the combined embedding.
            class_labels (`torch.LongTensor`): Class labels forwarded to the combined embedding.
            hidden_dtype (`torch.dtype`, *optional*): Forwarded to the combined embedding.

        Returns:
            A 5-tuple `(modulated_x, gate_msa, shift_mlp, scale_mlp, gate_mlp)`.
        """
        conditioning = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        modulation = self.linear(self.silu(conditioning))
        shift_attn, scale_attn, gate_attn, shift_ff, scale_ff, gate_ff = modulation.chunk(6, dim=1)

        # Insert a singleton axis after the batch so the per-sample modulation broadcasts over axis 1 of `x`.
        normed = self.norm(x)
        modulated = normed * (1 + scale_attn.unsqueeze(1)) + shift_attn.unsqueeze(1)
        return modulated, gate_attn, shift_ff, scale_ff, gate_ff
|
|
|
|
|
|
class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    This module does not normalize anything itself. It embeds the timestep (optionally together with size
    conditions) and projects the embedding to `6 * embedding_dim` modulation values for the caller to apply.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = CombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            timestep (`torch.Tensor`): Timesteps forwarded to the combined timestep/size embedding.
            added_cond_kwargs (`Dict[str, torch.Tensor]`, *optional*): Extra conditions forwarded as keyword
                arguments to the embedding, e.g. `resolution` and `aspect_ratio`. Any of those two keys that is
                missing (or the whole dict, when `None`) is passed as `None`.
            batch_size (`int`, *optional*): Forwarded to the embedding.
            hidden_dtype (`torch.dtype`, *optional*): Forwarded to the embedding.

        Returns:
            A 2-tuple `(modulation, embedded_timestep)`, where `modulation` is the `6 * embedding_dim` projection
            and `embedded_timestep` is the raw embedding it was computed from.
        """
        # No modulation happening here.
        # The embedding's forward takes `resolution` and `aspect_ratio` as required arguments, so default them
        # to None; unpacking a bare `None` would otherwise raise a TypeError. The caller's dict is not mutated.
        cond_kwargs = {"resolution": None, "aspect_ratio": None, **(added_cond_kwargs or {})}
        embedded_timestep = self.emb(timestep, **cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
|
|
|
|
|
class AdaGroupNorm(nn.Module):
    r"""
    GroupNorm layer modified to incorporate timestep embeddings.

    The conditioning embedding (optionally passed through an activation) is projected to `2 * out_dim` values. These
    are split into a per-channel `scale` and `shift` that modulate a non-affine group norm of the input:
    `group_norm(x) * (1 + scale) + shift`.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector (input features of the projection).
        out_dim (`int`): The number of channels of `x` being modulated; the projection outputs `2 * out_dim` values.
        num_groups (`int`): The number of groups to separate the channels into.
        act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
        eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        self.act = None if act_fn is None else get_activation(act_fn)

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (`torch.Tensor`): Input of shape `(batch, out_dim, *spatial)`, with any number of spatial dims.
            emb (`torch.Tensor`): Conditioning embedding of shape `(batch, embedding_dim)`.

        Returns:
            `torch.Tensor`: The normalized and modulated input, with the same shape as `x`.
        """
        # Explicit None check: truthiness of an nn.Module is not a reliable "is configured" test.
        if self.act is not None:
            emb = self.act(emb)
        emb = self.linear(emb)
        # (batch, 2 * out_dim) -> (batch, 2 * out_dim, 1, ..., 1) so scale/shift broadcast over every
        # spatial axis of `x`; for 4-D `x` this equals `emb[:, :, None, None]`.
        emb = emb.reshape(*emb.shape, *([1] * (x.ndim - 2)))
        scale, shift = emb.chunk(2, dim=1)

        x = F.group_norm(x, self.num_groups, eps=self.eps)
        x = x * (1 + scale) + shift
        return x
|