4039815276
amused rename Update docs/source/en/api/pipelines/amused.md Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> AdaLayerNormContinuous default values custom micro conditioning micro conditioning docs put lookup from codebook in constructor fix conversion script remove manual fused flash attn kernel add training script temp remove training script add dummy gradient checkpointing func clarify temperatures is an instance variable by setting it remove additional SkipFF block args hardcode norm args rename tests folder fix paths and samples fix tests add training script training readme lora saving and loading non-lora saving/loading some readme fixes guards Update docs/source/en/api/pipelines/amused.md Co-authored-by: Suraj Patil <surajp815@gmail.com> Update examples/amused/README.md Co-authored-by: Suraj Patil <surajp815@gmail.com> Update examples/amused/train_amused.py Co-authored-by: Suraj Patil <surajp815@gmail.com> vae upcasting add fp16 integration tests use tuple for micro cond copyrights remove casts delegate to torch.nn.LayerNorm move temperature to pipeline call upsampling/downsampling changes
255 lines
9.3 KiB
Python
255 lines
9.3 KiB
Python
# coding=utf-8
|
|
# Copyright 2023 HuggingFace Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import numbers
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from ..utils import is_torch_version
|
|
from .activations import get_activation
|
|
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
|
|
|
|
|
|
class AdaLayerNorm(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    The timestep index is looked up in an embedding table, passed through SiLU and a linear projection, and the
    projection is split into a `scale` and a `shift` that modulate the (non-affine) layer-normalized input.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        # Projects to twice the width so the result can be split into scale and shift.
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        r"""
        Normalize `x` and modulate it with the scale/shift derived from `timestep`.

        Args:
            x (`torch.Tensor`): Input whose last dimension is `embedding_dim`.
            timestep (`torch.Tensor`):
                Integer index tensor. A 0-d tensor applies one timestep to the whole input; a tensor of shape
                `(batch,)` applies one timestep per leading-batch entry of `x`.

        Returns:
            `torch.Tensor`: `norm(x) * (1 + scale) + shift`, with the same shape as `x`.
        """
        emb = self.linear(self.silu(self.emb(timestep)))

        # Split along the feature axis. For a 0-d `timestep` the projection is 1-D, so this is identical to the
        # previous default `dim=0` split; for batched timesteps it no longer cuts the batch axis in half.
        scale, shift = emb.chunk(2, dim=-1)

        # Batched case: insert singleton axes between the batch and feature axes so that `(batch, dim)` broadcasts
        # against `(batch, ..., dim)`.
        missing_dims = x.dim() - scale.dim()
        if scale.dim() > 1 and missing_dims > 0:
            target_shape = (scale.shape[0],) + (1,) * missing_dims + tuple(scale.shape[1:])
            scale = scale.reshape(target_shape)
            shift = shift.reshape(target_shape)

        return self.norm(x) * (1 + scale) + shift
|
|
|
|
|
|
class AdaLayerNormZero(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).

    A combined timestep/class-label embedding is projected to six modulation tensors. The first two (shift and
    scale for the attention branch) are applied to the normalized input here; the remaining four are returned for
    the caller to apply.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()

        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        self.silu = nn.SiLU()
        # Six chunks: shift/scale/gate for the attention branch and shift/scale/gate for the MLP branch.
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        class_labels: torch.LongTensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Returns:
            A 5-tuple `(x, gate_msa, shift_mlp, scale_mlp, gate_mlp)` where `x` is the normalized input already
            modulated by `scale_msa` / `shift_msa`.
        """
        conditioning = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        modulation = self.linear(self.silu(conditioning))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(modulation, 6, dim=1)

        # Insert an axis at position 1 so the per-sample modulation broadcasts over that axis of `x`
        # (presumably the sequence axis -- confirm against callers).
        modulated = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
|
|
|
|
|
class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Despite the name, this module performs no normalization itself: it only produces the projected conditioning
    together with the raw timestep embedding.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
        )

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            timestep (`torch.Tensor`): Timesteps forwarded to the embedding module.
            added_cond_kwargs (`dict`, *optional*):
                Extra keyword arguments forwarded to the embedding module. `None` is treated as an empty dict.
            batch_size (`int`, *optional*): Forwarded to the embedding module.
            hidden_dtype (`torch.dtype`, *optional*): Forwarded to the embedding module.

        Returns:
            A 2-tuple `(projected, embedded_timestep)`: the SiLU + linear projection of the embedding (last
            dimension `6 * embedding_dim`) and the embedding itself.
        """
        # `**None` raises "argument after ** must be a mapping", so the documented default has to be normalized.
        # NOTE(review): whether the embedding module accepts a call without the extra conditions is not visible
        # from here -- confirm against its `forward` signature.
        if added_cond_kwargs is None:
            added_cond_kwargs = {}

        # No modulation happening here.
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
|
|
|
|
|
class AdaGroupNorm(nn.Module):
    r"""
    GroupNorm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of the conditioning embedding fed to the projection.
        out_dim (`int`): Half the width of the projection; it is split into a scale and a shift of this size.
        num_groups (`int`): The number of groups to separate the channels into.
        act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
        eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
    """

    def __init__(
        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps

        # The activation is optional; without a name the embedding goes straight into the projection.
        self.act = get_activation(act_fn) if act_fn is not None else None

        self.linear = nn.Linear(embedding_dim, out_dim * 2)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        r"""
        Group-normalize `x` (no learned affine) and modulate it with the scale/shift projected from `emb`.
        """
        if self.act:
            emb = self.act(emb)

        # Two trailing singleton axes are appended so the projection broadcasts over the trailing axes of `x`.
        modulation = self.linear(emb)[:, :, None, None]
        scale, shift = torch.chunk(modulation, 2, dim=1)

        normalized = F.group_norm(x, self.num_groups, eps=self.eps)
        return normalized * (1 + scale) + shift
|
|
|
|
|
|
class AdaLayerNormContinuous(nn.Module):
    r"""
    Norm layer whose output is scaled and shifted by a projection of a continuous conditioning embedding.

    Parameters:
        embedding_dim (`int`): Size of the normalized dimension of the input.
        conditioning_embedding_dim (`int`): Size of the conditioning embedding fed to the projection.
        elementwise_affine (`bool`, defaults to `True`): Passed to the underlying norm layer.
        eps (`float`, defaults to `1e-5`): Passed to the underlying norm layer.
        bias (`bool`, defaults to `True`): Used for the projection and, for `"layer_norm"`, the norm layer.
        norm_type (`str`, defaults to `"layer_norm"`): Either `"layer_norm"` or `"rms_norm"`.

    Raises:
        ValueError: If `norm_type` is neither `"layer_norm"` nor `"rms_norm"`.
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        # Twice the width: one half becomes the scale, the other the shift.
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)

        if norm_type not in ("layer_norm", "rms_norm"):
            raise ValueError(f"unknown norm_type {norm_type}")

        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        else:
            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        r"""
        Normalize `x` and apply `(1 + scale)` and `shift` obtained from `conditioning_embedding`.
        """
        projected = self.linear(self.silu(conditioning_embedding))
        scale, shift = projected.chunk(2, dim=1)

        # An axis is inserted at position 1 so the per-sample modulation broadcasts over that axis of `x`.
        scale = (1 + scale)[:, None, :]
        shift = shift[:, None, :]
        return self.norm(x) * scale + shift
|
|
|
|
|
|
if is_torch_version(">=", "2.1.0"):
    # From torch 2.1.0 onwards the built-in layer norm is used directly.
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        r"""
        Fallback layer norm with an optional bias, built on `torch.nn.functional.layer_norm`.

        Parameters:
            dim (`int` or sequence of `int`): Shape of the trailing dimensions to normalize over.
            eps (`float`, defaults to `1e-5`): Value added for numerical stability.
            elementwise_affine (`bool`, defaults to `True`): Whether to learn a per-element weight (and bias).
            bias (`bool`, defaults to `True`): Whether to learn a bias; ignored when `elementwise_affine` is False.
        """

        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            self.eps = eps

            # A bare integer means "normalize over a single trailing dimension of that size".
            shape = (dim,) if isinstance(dim, numbers.Integral) else dim
            self.dim = torch.Size(shape)

            self.weight = nn.Parameter(torch.ones(shape)) if elementwise_affine else None
            self.bias = nn.Parameter(torch.zeros(shape)) if elementwise_affine and bias else None

        def forward(self, input):
            return F.layer_norm(input, self.dim, weight=self.weight, bias=self.bias, eps=self.eps)
|
|
|
|
|
|
class RMSNorm(nn.Module):
    r"""
    Root-mean-square normalization over the last dimension, with an optional learned weight.

    Parameters:
        dim (`int` or sequence of `int`): Shape of the learned weight.
        eps (`float`): Value added to the mean square for numerical stability.
        elementwise_affine (`bool`, defaults to `True`): Whether to learn a per-element weight.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()

        self.eps = eps

        # A bare integer is treated as a one-element shape.
        shape = (dim,) if isinstance(dim, numbers.Integral) else dim
        self.dim = torch.Size(shape)

        self.weight = nn.Parameter(torch.ones(shape)) if elementwise_affine else None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype

        # The statistic is computed in float32 regardless of the input dtype.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normalized = hidden_states * torch.rsqrt(mean_square + self.eps)

        # Without a weight the result is simply cast back to the input dtype.
        if self.weight is None:
            return normalized.to(input_dtype)

        # convert into half-precision if necessary
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normalized = normalized.to(self.weight.dtype)

        return normalized * self.weight
|
|
|
|
|
|
class GlobalResponseNorm(nn.Module):
    r"""
    Global response normalization with a residual connection.

    Parameters:
        dim (`int`): Size of the last dimension of the learned `gamma` and `beta` parameters.
    """

    # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        # Both start at zero, so a freshly constructed module returns its input unchanged.
        self.gamma = nn.Parameter(torch.zeros(*param_shape))
        self.beta = nn.Parameter(torch.zeros(*param_shape))

    def forward(self, x):
        # L2 norm over axes 1 and 2 (presumably the spatial axes of a channels-last input -- confirm with callers).
        global_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divide by the mean norm across the last axis; the small constant guards against division by zero.
        mean_norm = global_norm.mean(dim=-1, keepdim=True)
        relative_norm = global_norm / (mean_norm + 1e-6)
        return self.gamma * (x * relative_norm) + self.beta + x
|