Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1e3d230904 | |||
| b8f650bac3 | |||
| 6eca4655eb | |||
| 04d18a2669 | |||
| 80d9ddb061 | |||
| 6c825c9497 | |||
| 64c542a0ac | |||
| 04c4d39738 | |||
| 10fe7eeb8a | |||
| 9682a04624 | |||
| 5266ab7935 | |||
| 7f724a930e | |||
| 9bef9f4be7 |
@@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
|
||||
| Example | Description | Code Example | Colab | Author |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|
||||
| HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
|
||||
| Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [](https://huggingface.co/spaces/toshas/marigold) [](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |
|
||||
| LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) |
|
||||
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
|
||||
@@ -75,6 +76,48 @@ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custo
|
||||
|
||||
## Example usages
|
||||
|
||||
### HD-Painter
|
||||
|
||||
Implementation of [HD-Painter: High-Resolution and Prompt-Faithful Text-Guided Image Inpainting with Diffusion Models](https://arxiv.org/abs/2312.14091).
|
||||
|
||||

|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
Recent progress in text-guided image inpainting, based on the unprecedented success of text-to-image diffusion models, has led to exceptionally realistic and visually plausible results.
|
||||
However, there is still significant potential for improvement in current text-to-image inpainting models, particularly in better aligning the inpainted area with user prompts and performing high-resolution inpainting.
|
||||
Therefore, in this paper we introduce _HD-Painter_, a completely **training-free** approach that **accurately follows to prompts** and coherently **scales to high-resolution** image inpainting.
|
||||
To this end, we design the _Prompt-Aware Introverted Attention (PAIntA)_ layer enhancing self-attention scores by prompt information and resulting in better text alignment generations.
|
||||
To further improve the prompt coherence we introduce the _Reweighting Attention Score Guidance (RASG)_ mechanism seamlessly integrating a post-hoc sampling strategy into general form of DDIM to prevent out-of-distribution latent shifts.
|
||||
Moreover, HD-Painter allows extension to larger scales by introducing a specialized super-resolution technique customized for inpainting, enabling the completion of missing regions in images of up to 2K resolution.
|
||||
Our experiments demonstrate that HD-Painter surpasses existing state-of-the-art approaches qualitatively and quantitatively, achieving an impressive generation accuracy improvement of **61.4** vs **51.9**.
|
||||
We will make the codes publicly available.
|
||||
|
||||
You can find additional information about Text2Video-Zero in the [paper](https://arxiv.org/abs/2312.14091) or the [original codebase](https://github.com/Picsart-AI-Research/HD-Painter).
|
||||
|
||||
#### Usage example
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DDIMScheduler
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-inpainting",
|
||||
custom_pipeline="hd_painter"
|
||||
)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
prompt = "wooden boat"
|
||||
init_image = load_image("https://raw.githubusercontent.com/Picsart-AI-Research/HD-Painter/main/__assets__/samples/images/2.jpg")
|
||||
mask_image = load_image("https://raw.githubusercontent.com/Picsart-AI-Research/HD-Painter/main/__assets__/samples/masks/2.png")
|
||||
|
||||
image = pipe (prompt, init_image, mask_image, use_rasg = True, use_painta = True, generator=torch.manual_seed(12345)).images[0]
|
||||
|
||||
make_image_grid([init_image, mask_image, image], rows=1, cols=3)
|
||||
|
||||
```
|
||||
|
||||
### Marigold Depth Estimation
|
||||
|
||||
Marigold is a universal monocular depth estimator that delivers accurate and sharp predictions in the wild. Based on Stable Diffusion, it is trained exclusively with synthetic depth data and excels in zero-shot adaptation to real-world imagery. This pipeline is an official implementation of the inference process. More details can be found on our [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) (also implemented with diffusers).
|
||||
|
||||
@@ -0,0 +1,994 @@
|
||||
import math
|
||||
import numbers
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
from diffusers.models import AsymmetricAutoencoderKL, ImageProjection
|
||||
from diffusers.models.attention_processor import Attention, AttnProcessor
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
|
||||
StableDiffusionInpaintPipeline,
|
||||
retrieve_timesteps,
|
||||
)
|
||||
from diffusers.utils import deprecate
|
||||
|
||||
|
||||
class RASGAttnProcessor:
|
||||
def __init__(self, mask, token_idx, scale_factor):
|
||||
self.attention_scores = None # Stores the last output of the similarity matrix here. Each layer will get its own RASGAttnProcessor assigned
|
||||
self.mask = mask
|
||||
self.token_idx = token_idx
|
||||
self.scale_factor = scale_factor
|
||||
self.mask_resoltuion = mask.shape[-1] * mask.shape[-2] # 64 x 64 if the image is 512x512
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
# Same as the default AttnProcessor up untill the part where similarity matrix gets saved
|
||||
downscale_factor = self.mask_resoltuion // hidden_states.shape[1]
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
# Automatically recognize the resolution and save the attention similarity values
|
||||
# We need to use the values before the softmax function, hence the rewritten get_attention_scores function.
|
||||
if downscale_factor == self.scale_factor**2:
|
||||
self.attention_scores = get_attention_scores(attn, query, key, attention_mask)
|
||||
attention_probs = self.attention_scores.softmax(dim=-1)
|
||||
attention_probs = attention_probs.to(query.dtype)
|
||||
else:
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask) # Original code
|
||||
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PAIntAAttnProcessor:
|
||||
def __init__(self, transformer_block, mask, token_idx, do_classifier_free_guidance, scale_factors):
|
||||
self.transformer_block = transformer_block # Stores the parent transformer block.
|
||||
self.mask = mask
|
||||
self.scale_factors = scale_factors
|
||||
self.do_classifier_free_guidance = do_classifier_free_guidance
|
||||
self.token_idx = token_idx
|
||||
self.shape = mask.shape[2:]
|
||||
self.mask_resoltuion = mask.shape[-1] * mask.shape[-2] # 64 x 64
|
||||
self.default_processor = AttnProcessor()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
# Automatically recognize the resolution of the current attention layer and resize the masks accordingly
|
||||
downscale_factor = self.mask_resoltuion // hidden_states.shape[1]
|
||||
|
||||
mask = None
|
||||
for factor in self.scale_factors:
|
||||
if downscale_factor == factor**2:
|
||||
shape = (self.shape[0] // factor, self.shape[1] // factor)
|
||||
mask = F.interpolate(self.mask, shape, mode="bicubic") # B, 1, H, W
|
||||
break
|
||||
if mask is None:
|
||||
return self.default_processor(attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale)
|
||||
|
||||
# STARTS HERE
|
||||
residual = hidden_states
|
||||
# Save the input hidden_states for later use
|
||||
input_hidden_states = hidden_states
|
||||
|
||||
# ================================================== #
|
||||
# =============== SELF ATTENTION 1 ================= #
|
||||
# ================================================== #
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
# self_attention_probs = attn.get_attention_scores(query, key, attention_mask) # We can't use post-softmax attention scores in this case
|
||||
self_attention_scores = get_attention_scores(
|
||||
attn, query, key, attention_mask
|
||||
) # The custom function returns pre-softmax probabilities
|
||||
self_attention_probs = self_attention_scores.softmax(
|
||||
dim=-1
|
||||
) # Manually compute the probabilities here, the scores will be reused in the second part of PAIntA
|
||||
self_attention_probs = self_attention_probs.to(query.dtype)
|
||||
|
||||
hidden_states = torch.bmm(self_attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
# x = x + self.attn1(self.norm1(x))
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection: # So many residuals everywhere
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
self_attention_output_hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
# ================================================== #
|
||||
# ============ BasicTransformerBlock =============== #
|
||||
# ================================================== #
|
||||
# We use a hack by running the code from the BasicTransformerBlock that is between Self and Cross attentions here
|
||||
# The other option would've been modifying the BasicTransformerBlock and adding this functionality here.
|
||||
# I assumed that changing the BasicTransformerBlock would have been a bigger deal and decided to use this hack isntead.
|
||||
|
||||
# The SelfAttention block recieves the normalized latents from the BasicTransformerBlock,
|
||||
# But the residual of the output is the non-normalized version.
|
||||
# Therefore we unnormalize the input hidden state here
|
||||
unnormalized_input_hidden_states = (
|
||||
input_hidden_states + self.transformer_block.norm1.bias
|
||||
) * self.transformer_block.norm1.weight
|
||||
|
||||
# TODO: return if neccessary
|
||||
# if self.use_ada_layer_norm_zero:
|
||||
# attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
# elif self.use_ada_layer_norm_single:
|
||||
# attn_output = gate_msa * attn_output
|
||||
|
||||
transformer_hidden_states = self_attention_output_hidden_states + unnormalized_input_hidden_states
|
||||
if transformer_hidden_states.ndim == 4:
|
||||
transformer_hidden_states = transformer_hidden_states.squeeze(1)
|
||||
|
||||
# TODO: return if neccessary
|
||||
# 2.5 GLIGEN Control
|
||||
# if gligen_kwargs is not None:
|
||||
# transformer_hidden_states = self.fuser(transformer_hidden_states, gligen_kwargs["objs"])
|
||||
# NOTE: we experimented with using GLIGEN and HDPainter together, the results were not that great
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.transformer_block.use_ada_layer_norm:
|
||||
# transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states, timestep)
|
||||
raise NotImplementedError()
|
||||
elif self.transformer_block.use_ada_layer_norm_zero or self.transformer_block.use_layer_norm:
|
||||
transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states)
|
||||
elif self.transformer_block.use_ada_layer_norm_single:
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
transformer_norm_hidden_states = transformer_hidden_states
|
||||
elif self.transformer_block.use_ada_layer_norm_continuous:
|
||||
# transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
raise ValueError("Incorrect norm")
|
||||
|
||||
if self.transformer_block.pos_embed is not None and self.transformer_block.use_ada_layer_norm_single is False:
|
||||
transformer_norm_hidden_states = self.transformer_block.pos_embed(transformer_norm_hidden_states)
|
||||
|
||||
# ================================================== #
|
||||
# ================= CROSS ATTENTION ================ #
|
||||
# ================================================== #
|
||||
|
||||
# We do an initial pass of the CrossAttention up to obtaining the similarity matrix here.
|
||||
# The similarity matrix is used to obtain scaling coefficients for the attention matrix of the self attention
|
||||
# We reuse the previously computed self-attention matrix, and only repeat the steps after the softmax
|
||||
|
||||
cross_attention_input_hidden_states = (
|
||||
transformer_norm_hidden_states # Renaming the variable for the sake of readability
|
||||
)
|
||||
|
||||
# TODO: check if classifier_free_guidance is being used before splitting here
|
||||
if self.do_classifier_free_guidance:
|
||||
# Our scaling coefficients depend only on the conditional part, so we split the inputs
|
||||
(
|
||||
_cross_attention_input_hidden_states_unconditional,
|
||||
cross_attention_input_hidden_states_conditional,
|
||||
) = cross_attention_input_hidden_states.chunk(2)
|
||||
|
||||
# Same split for the encoder_hidden_states i.e. the tokens
|
||||
# Since the SelfAttention processors don't get the encoder states as input, we inject them into the processor in the begining.
|
||||
_encoder_hidden_states_unconditional, encoder_hidden_states_conditional = self.encoder_hidden_states.chunk(
|
||||
2
|
||||
)
|
||||
else:
|
||||
cross_attention_input_hidden_states_conditional = cross_attention_input_hidden_states
|
||||
encoder_hidden_states_conditional = self.encoder_hidden_states.chunk(2)
|
||||
|
||||
# Rename the variables for the sake of readability
|
||||
# The part below is the beginning of the __call__ function of the following CrossAttention layer
|
||||
cross_attention_hidden_states = cross_attention_input_hidden_states_conditional
|
||||
cross_attention_encoder_hidden_states = encoder_hidden_states_conditional
|
||||
|
||||
attn2 = self.transformer_block.attn2
|
||||
|
||||
if attn2.spatial_norm is not None:
|
||||
cross_attention_hidden_states = attn2.spatial_norm(cross_attention_hidden_states, temb)
|
||||
|
||||
input_ndim = cross_attention_hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = cross_attention_hidden_states.shape
|
||||
cross_attention_hidden_states = cross_attention_hidden_states.view(
|
||||
batch_size, channel, height * width
|
||||
).transpose(1, 2)
|
||||
|
||||
(
|
||||
batch_size,
|
||||
sequence_length,
|
||||
_,
|
||||
) = cross_attention_hidden_states.shape # It is definitely a cross attention, so no need for an if block
|
||||
# TODO: change the attention_mask here
|
||||
attention_mask = attn2.prepare_attention_mask(
|
||||
None, sequence_length, batch_size
|
||||
) # I assume the attention mask is the same...
|
||||
|
||||
if attn2.group_norm is not None:
|
||||
cross_attention_hidden_states = attn2.group_norm(cross_attention_hidden_states.transpose(1, 2)).transpose(
|
||||
1, 2
|
||||
)
|
||||
|
||||
query2 = attn2.to_q(cross_attention_hidden_states)
|
||||
|
||||
if attn2.norm_cross:
|
||||
cross_attention_encoder_hidden_states = attn2.norm_encoder_hidden_states(
|
||||
cross_attention_encoder_hidden_states
|
||||
)
|
||||
|
||||
key2 = attn2.to_k(cross_attention_encoder_hidden_states)
|
||||
query2 = attn2.head_to_batch_dim(query2)
|
||||
key2 = attn2.head_to_batch_dim(key2)
|
||||
|
||||
cross_attention_probs = attn2.get_attention_scores(query2, key2, attention_mask)
|
||||
|
||||
# CrossAttention ends here, the remaining part is not used
|
||||
|
||||
# ================================================== #
|
||||
# ================ SELF ATTENTION 2 ================ #
|
||||
# ================================================== #
|
||||
# DEJA VU!
|
||||
|
||||
mask = (mask > 0.5).to(self_attention_output_hidden_states.dtype)
|
||||
m = mask.to(self_attention_output_hidden_states.device)
|
||||
# m = rearrange(m, 'b c h w -> b (h w) c').contiguous()
|
||||
m = m.permute(0, 2, 3, 1).reshape((m.shape[0], -1, m.shape[1])).contiguous() # B HW 1
|
||||
m = torch.matmul(m, m.permute(0, 2, 1)) + (1 - m)
|
||||
|
||||
# # Compute scaling coefficients for the similarity matrix
|
||||
# # Select the cross attention values for the correct tokens only!
|
||||
# cross_attention_probs = cross_attention_probs.mean(dim = 0)
|
||||
# cross_attention_probs = cross_attention_probs[:, self.token_idx].sum(dim=1)
|
||||
|
||||
# cross_attention_probs = cross_attention_probs.reshape(shape)
|
||||
# gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(self_attention_output_hidden_states.device)
|
||||
# cross_attention_probs = gaussian_smoothing(cross_attention_probs.unsqueeze(0))[0] # optional smoothing
|
||||
# cross_attention_probs = cross_attention_probs.reshape(-1)
|
||||
# cross_attention_probs = ((cross_attention_probs - torch.median(cross_attention_probs.ravel())) / torch.max(cross_attention_probs.ravel())).clip(0, 1)
|
||||
|
||||
# c = (1 - m) * cross_attention_probs.reshape(1, 1, -1) + m # PAIntA scaling coefficients
|
||||
|
||||
# Compute scaling coefficients for the similarity matrix
|
||||
# Select the cross attention values for the correct tokens only!
|
||||
|
||||
batch_size, dims, channels = cross_attention_probs.shape
|
||||
batch_size = batch_size // attn.heads
|
||||
cross_attention_probs = cross_attention_probs.reshape((batch_size, attn.heads, dims, channels)) # B, D, HW, T
|
||||
|
||||
cross_attention_probs = cross_attention_probs.mean(dim=1) # B, HW, T
|
||||
cross_attention_probs = cross_attention_probs[..., self.token_idx].sum(dim=-1) # B, HW
|
||||
cross_attention_probs = cross_attention_probs.reshape((batch_size,) + shape) # , B, H, W
|
||||
|
||||
gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(
|
||||
self_attention_output_hidden_states.device
|
||||
)
|
||||
cross_attention_probs = gaussian_smoothing(cross_attention_probs[:, None])[:, 0] # optional smoothing B, H, W
|
||||
|
||||
# Median normalization
|
||||
cross_attention_probs = cross_attention_probs.reshape(batch_size, -1) # B, HW
|
||||
cross_attention_probs = (
|
||||
cross_attention_probs - cross_attention_probs.median(dim=-1, keepdim=True).values
|
||||
) / cross_attention_probs.max(dim=-1, keepdim=True).values
|
||||
cross_attention_probs = cross_attention_probs.clip(0, 1)
|
||||
|
||||
c = (1 - m) * cross_attention_probs.reshape(batch_size, 1, -1) + m
|
||||
c = c.repeat_interleave(attn.heads, 0) # BD, HW
|
||||
if self.do_classifier_free_guidance:
|
||||
c = torch.cat([c, c]) # 2BD, HW
|
||||
|
||||
# Rescaling the original self-attention matrix
|
||||
self_attention_scores_rescaled = self_attention_scores * c
|
||||
self_attention_probs_rescaled = self_attention_scores_rescaled.softmax(dim=-1)
|
||||
|
||||
# Continuing the self attention normally using the new matrix
|
||||
hidden_states = torch.bmm(self_attention_probs_rescaled, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + input_hidden_states
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):
|
||||
def get_tokenized_prompt(self, prompt):
|
||||
out = self.tokenizer(prompt)
|
||||
return [self.tokenizer.decode(x) for x in out["input_ids"]]
|
||||
|
||||
def init_attn_processors(
|
||||
self,
|
||||
mask,
|
||||
token_idx,
|
||||
use_painta=True,
|
||||
use_rasg=True,
|
||||
painta_scale_factors=[2, 4], # 64x64 -> [16x16, 32x32]
|
||||
rasg_scale_factor=4, # 64x64 -> 16x16
|
||||
self_attention_layer_name="attn1",
|
||||
cross_attention_layer_name="attn2",
|
||||
list_of_painta_layer_names=None,
|
||||
list_of_rasg_layer_names=None,
|
||||
):
|
||||
default_processor = AttnProcessor()
|
||||
width, height = mask.shape[-2:]
|
||||
width, height = width // self.vae_scale_factor, height // self.vae_scale_factor
|
||||
|
||||
painta_scale_factors = [x * self.vae_scale_factor for x in painta_scale_factors]
|
||||
rasg_scale_factor = self.vae_scale_factor * rasg_scale_factor
|
||||
|
||||
attn_processors = {}
|
||||
for x in self.unet.attn_processors:
|
||||
if (list_of_painta_layer_names is None and self_attention_layer_name in x) or (
|
||||
list_of_painta_layer_names is not None and x in list_of_painta_layer_names
|
||||
):
|
||||
if use_painta:
|
||||
transformer_block = self.unet.get_submodule(x.replace(".attn1.processor", ""))
|
||||
attn_processors[x] = PAIntAAttnProcessor(
|
||||
transformer_block, mask, token_idx, self.do_classifier_free_guidance, painta_scale_factors
|
||||
)
|
||||
else:
|
||||
attn_processors[x] = default_processor
|
||||
elif (list_of_rasg_layer_names is None and cross_attention_layer_name in x) or (
|
||||
list_of_rasg_layer_names is not None and x in list_of_rasg_layer_names
|
||||
):
|
||||
if use_rasg:
|
||||
attn_processors[x] = RASGAttnProcessor(mask, token_idx, rasg_scale_factor)
|
||||
else:
|
||||
attn_processors[x] = default_processor
|
||||
|
||||
self.unet.set_attn_processor(attn_processors)
|
||||
# import json
|
||||
# with open('/home/hayk.manukyan/repos/diffusers/debug.txt', 'a') as f:
|
||||
# json.dump({x:str(y) for x,y in self.unet.attn_processors.items()}, f, indent=4)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
masked_image_latents: torch.FloatTensor = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
padding_mask_crop: Optional[int] = None,
|
||||
strength: float = 1.0,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
positive_prompt: Optional[str] = "",
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.01,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
clip_skip: int = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
use_painta=True,
|
||||
use_rasg=True,
|
||||
self_attention_layer_name=".attn1",
|
||||
cross_attention_layer_name=".attn2",
|
||||
painta_scale_factors=[2, 4], # 16 x 16 and 32 x 32
|
||||
rasg_scale_factor=4, # 16x16 by default
|
||||
list_of_painta_layer_names=None,
|
||||
list_of_rasg_layer_names=None,
|
||||
**kwargs,
|
||||
):
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
||||
)
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
#
|
||||
prompt_no_positives = prompt
|
||||
if isinstance(prompt, list):
|
||||
prompt = [x + positive_prompt for x in prompt]
|
||||
else:
|
||||
prompt = prompt + positive_prompt
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
mask_image,
|
||||
height,
|
||||
width,
|
||||
strength,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
padding_mask_crop,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# assert batch_size == 1, "Does not work with batch size > 1 currently"
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. set timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
timesteps, num_inference_steps = self.get_timesteps(
|
||||
num_inference_steps=num_inference_steps, strength=strength, device=device
|
||||
)
|
||||
# check that number of inference steps is not < 1 - as this doesn't make sense
|
||||
if num_inference_steps < 1:
|
||||
raise ValueError(
|
||||
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
||||
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
||||
)
|
||||
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
||||
is_strength_max = strength == 1.0
|
||||
|
||||
# 5. Preprocess mask and image
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
|
||||
resize_mode = "fill"
|
||||
else:
|
||||
crops_coords = None
|
||||
resize_mode = "default"
|
||||
|
||||
original_image = image
|
||||
init_image = self.image_processor.preprocess(
|
||||
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
|
||||
)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
num_channels_unet = self.unet.config.in_channels
|
||||
return_image_latents = num_channels_unet == 4
|
||||
|
||||
latents_outputs = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
image=init_image,
|
||||
timestep=latent_timestep,
|
||||
is_strength_max=is_strength_max,
|
||||
return_noise=True,
|
||||
return_image_latents=return_image_latents,
|
||||
)
|
||||
|
||||
if return_image_latents:
|
||||
latents, noise, image_latents = latents_outputs
|
||||
else:
|
||||
latents, noise = latents_outputs
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
mask_condition = self.mask_processor.preprocess(
|
||||
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
|
||||
)
|
||||
|
||||
if masked_image_latents is None:
|
||||
masked_image = init_image * (mask_condition < 0.5)
|
||||
else:
|
||||
masked_image = masked_image_latents
|
||||
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask_condition,
|
||||
masked_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 7.5 Setting up HD-Painter
|
||||
|
||||
# Get the indices of the tokens to be modified by both RASG and PAIntA
|
||||
token_idx = list(range(1, self.get_tokenized_prompt(prompt_no_positives).index("<|endoftext|>"))) + [
|
||||
self.get_tokenized_prompt(prompt).index("<|endoftext|>")
|
||||
]
|
||||
|
||||
# Setting up the attention processors
|
||||
self.init_attn_processors(
|
||||
mask_condition,
|
||||
token_idx,
|
||||
use_painta,
|
||||
use_rasg,
|
||||
painta_scale_factors=painta_scale_factors,
|
||||
rasg_scale_factor=rasg_scale_factor,
|
||||
self_attention_layer_name=self_attention_layer_name,
|
||||
cross_attention_layer_name=cross_attention_layer_name,
|
||||
list_of_painta_layer_names=list_of_painta_layer_names,
|
||||
list_of_rasg_layer_names=list_of_rasg_layer_names,
|
||||
)
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
||||
raise ValueError(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
raise ValueError(
|
||||
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
|
||||
)
|
||||
|
||||
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
if use_rasg:
|
||||
extra_step_kwargs["generator"] = None
|
||||
|
||||
# 9.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
|
||||
# 9.2 Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
if self.unet.config.time_cond_proj_dim is not None:
|
||||
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
||||
timestep_cond = self.get_guidance_scale_embedding(
|
||||
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
||||
).to(device=device, dtype=latents.dtype)
|
||||
|
||||
# 10. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
painta_active = True
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if t < 500 and painta_active:
|
||||
self.init_attn_processors(
|
||||
mask_condition,
|
||||
token_idx,
|
||||
False,
|
||||
use_rasg,
|
||||
painta_scale_factors=painta_scale_factors,
|
||||
rasg_scale_factor=rasg_scale_factor,
|
||||
self_attention_layer_name=self_attention_layer_name,
|
||||
cross_attention_layer_name=cross_attention_layer_name,
|
||||
list_of_painta_layer_names=list_of_painta_layer_names,
|
||||
list_of_rasg_layer_names=list_of_rasg_layer_names,
|
||||
)
|
||||
painta_active = False
|
||||
|
||||
with torch.enable_grad():
|
||||
self.unet.zero_grad()
|
||||
latents = latents.detach()
|
||||
latents.requires_grad = True
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
if num_channels_unet == 9:
|
||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
||||
|
||||
self.scheduler.latents = latents
|
||||
self.encoder_hidden_states = prompt_embeds
|
||||
for attn_processor in self.unet.attn_processors.values():
|
||||
attn_processor.encoder_hidden_states = prompt_embeds
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if use_rasg:
|
||||
# Perform RASG
|
||||
_, _, height, width = mask_condition.shape # 512 x 512
|
||||
scale_factor = self.vae_scale_factor * rasg_scale_factor # 8 * 4 = 32
|
||||
|
||||
# TODO: Fix for > 1 batch_size
|
||||
rasg_mask = F.interpolate(
|
||||
mask_condition, (height // scale_factor, width // scale_factor), mode="bicubic"
|
||||
)[0, 0] # mode is nearest by default, B, H, W
|
||||
|
||||
# Aggregate the saved attention maps
|
||||
attn_map = []
|
||||
for processor in self.unet.attn_processors.values():
|
||||
if hasattr(processor, "attention_scores") and processor.attention_scores is not None:
|
||||
if self.do_classifier_free_guidance:
|
||||
attn_map.append(processor.attention_scores.chunk(2)[1]) # (B/2) x H, 256, 77
|
||||
else:
|
||||
attn_map.append(processor.attention_scores) # B x H, 256, 77 ?
|
||||
|
||||
attn_map = (
|
||||
torch.cat(attn_map)
|
||||
.mean(0)
|
||||
.permute(1, 0)
|
||||
.reshape((-1, height // scale_factor, width // scale_factor))
|
||||
) # 77, 16, 16
|
||||
|
||||
# Compute the attention score
|
||||
attn_score = -sum(
|
||||
[
|
||||
F.binary_cross_entropy_with_logits(x - 1.0, rasg_mask.to(device))
|
||||
for x in attn_map[token_idx]
|
||||
]
|
||||
)
|
||||
|
||||
# Backward the score and compute the gradients
|
||||
attn_score.backward()
|
||||
|
||||
# Normalzie the gradients and compute the noise component
|
||||
variance_noise = latents.grad.detach()
|
||||
# print("VARIANCE SHAPE", variance_noise.shape)
|
||||
variance_noise -= torch.mean(variance_noise, [1, 2, 3], keepdim=True)
|
||||
variance_noise /= torch.std(variance_noise, [1, 2, 3], keepdim=True)
|
||||
else:
|
||||
variance_noise = None
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs, return_dict=False, variance_noise=variance_noise
|
||||
)[0]
|
||||
|
||||
if num_channels_unet == 4:
|
||||
init_latents_proper = image_latents
|
||||
if self.do_classifier_free_guidance:
|
||||
init_mask, _ = mask.chunk(2)
|
||||
else:
|
||||
init_mask = mask
|
||||
|
||||
if i < len(timesteps) - 1:
|
||||
noise_timestep = timesteps[i + 1]
|
||||
init_latents_proper = self.scheduler.add_noise(
|
||||
init_latents_proper, noise, torch.tensor([noise_timestep])
|
||||
)
|
||||
|
||||
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
mask = callback_outputs.pop("mask", mask)
|
||||
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
condition_kwargs = {}
|
||||
if isinstance(self.vae, AsymmetricAutoencoderKL):
|
||||
init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
|
||||
init_image_condition = init_image.clone()
|
||||
init_image = self._encode_vae_image(init_image, generator=generator)
|
||||
mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
|
||||
condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
|
||||
image = self.vae.decode(
|
||||
latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
|
||||
)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
|
||||
# ============= Utility Functions ============== #
|
||||
|
||||
|
||||
class GaussianSmoothing(nn.Module):
|
||||
"""
|
||||
Apply gaussian smoothing on a
|
||||
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
||||
in the input using a depthwise convolution.
|
||||
Arguments:
|
||||
channels (int, sequence): Number of channels of the input tensors. Output will
|
||||
have this number of channels as well.
|
||||
kernel_size (int, sequence): Size of the gaussian kernel.
|
||||
sigma (float, sequence): Standard deviation of the gaussian kernel.
|
||||
dim (int, optional): The number of dimensions of the data.
|
||||
Default value is 2 (spatial).
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size, sigma, dim=2):
|
||||
super(GaussianSmoothing, self).__init__()
|
||||
if isinstance(kernel_size, numbers.Number):
|
||||
kernel_size = [kernel_size] * dim
|
||||
if isinstance(sigma, numbers.Number):
|
||||
sigma = [sigma] * dim
|
||||
|
||||
# The gaussian kernel is the product of the
|
||||
# gaussian function of each dimension.
|
||||
kernel = 1
|
||||
meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
|
||||
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
||||
mean = (size - 1) / 2
|
||||
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
|
||||
|
||||
# Make sure sum of values in gaussian kernel equals 1.
|
||||
kernel = kernel / torch.sum(kernel)
|
||||
|
||||
# Reshape to depthwise convolutional weight
|
||||
kernel = kernel.view(1, 1, *kernel.size())
|
||||
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
||||
|
||||
self.register_buffer("weight", kernel)
|
||||
self.groups = channels
|
||||
|
||||
if dim == 1:
|
||||
self.conv = F.conv1d
|
||||
elif dim == 2:
|
||||
self.conv = F.conv2d
|
||||
elif dim == 3:
|
||||
self.conv = F.conv3d
|
||||
else:
|
||||
raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
Apply gaussian filter to input.
|
||||
Arguments:
|
||||
input (torch.Tensor): Input to apply gaussian filter on.
|
||||
Returns:
|
||||
filtered (torch.Tensor): Filtered output.
|
||||
"""
|
||||
return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding="same")
|
||||
|
||||
|
||||
def get_attention_scores(
|
||||
self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Compute the attention scores.
|
||||
|
||||
Args:
|
||||
query (`torch.Tensor`): The query tensor.
|
||||
key (`torch.Tensor`): The key tensor.
|
||||
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The attention probabilities/scores.
|
||||
"""
|
||||
if self.upcast_attention:
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
|
||||
if attention_mask is None:
|
||||
baddbmm_input = torch.empty(
|
||||
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
||||
)
|
||||
beta = 0
|
||||
else:
|
||||
baddbmm_input = attention_mask
|
||||
beta = 1
|
||||
|
||||
attention_scores = torch.baddbmm(
|
||||
baddbmm_input,
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=beta,
|
||||
alpha=self.scale,
|
||||
)
|
||||
del baddbmm_input
|
||||
|
||||
if self.upcast_softmax:
|
||||
attention_scores = attention_scores.float()
|
||||
|
||||
return attention_scores
|
||||
@@ -0,0 +1,172 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
|
||||
from diffusers import Transformer3DModel
|
||||
|
||||
|
||||
ckpt_id = "PixArt-alpha/PixArt-alpha"
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/scripts/inference.py#L125
|
||||
interpolation_scale = {256: 0.5, 512: 1}
|
||||
|
||||
|
||||
def main(args):
|
||||
state_dict = {}
|
||||
with safe_open(args.orig_ckpt_path, framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
converted_state_dict = {}
|
||||
|
||||
# Patch embeddings.
|
||||
converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# Caption projection.
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
|
||||
|
||||
# AdaLN-single LN
|
||||
converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Shared norm.
|
||||
converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
|
||||
converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")
|
||||
|
||||
for depth in range(28):
|
||||
# Transformer blocks.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
|
||||
f"blocks.{depth}.scale_shift_table"
|
||||
)
|
||||
|
||||
# Attention is all you need 🤘
|
||||
|
||||
# Self attention.
|
||||
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
|
||||
q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.bias"
|
||||
)
|
||||
|
||||
# Temporal attention.
|
||||
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn_temp.qkv.weight"), 3, dim=0)
|
||||
q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn_temp.qkv.bias"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_v.bias"] = v_bias
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn_temp.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn_temporal.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn_temp.proj.bias"
|
||||
)
|
||||
|
||||
# Feed-forward.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.fc2.bias"
|
||||
)
|
||||
|
||||
# Cross-attention.
|
||||
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
|
||||
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
|
||||
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.bias"
|
||||
)
|
||||
|
||||
# Final block.
|
||||
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
|
||||
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
|
||||
converted_state_dict["pos_embed_temporal"] = state_dict.pop("pos_embed_temporal")
|
||||
|
||||
# DiT XL/2
|
||||
transformer = Transformer3DModel(
|
||||
sample_size=(16, args.image_size // 8, args.image_size // 8),
|
||||
patch_size=(1, 2, 2),
|
||||
num_layers=28,
|
||||
attention_head_dim=72,
|
||||
num_attention_heads=16,
|
||||
in_channels=4,
|
||||
out_channels=8,
|
||||
cross_attention_dim=1152,
|
||||
num_embeds_ada_norm=1000,
|
||||
norm_eps=1e-6,
|
||||
caption_channels=4096,
|
||||
)
|
||||
transformer.load_state_dict(converted_state_dict, strict=True)
|
||||
|
||||
assert transformer.pos_embed.pos_embed is not None
|
||||
state_dict.pop("pos_embed")
|
||||
state_dict.pop("y_embedder.y_embedding")
|
||||
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
|
||||
|
||||
num_model_params = sum(p.numel() for p in transformer.parameters())
|
||||
print(f"Total number of transformer parameters: {num_model_params}")
|
||||
|
||||
transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_size",
|
||||
default=256,
|
||||
type=int,
|
||||
choices=[256, 512],
|
||||
required=False,
|
||||
help="Image size of pretrained model, either 256 or 512.",
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -90,6 +90,7 @@ else:
|
||||
"T2IAdapter",
|
||||
"T5FilmDecoder",
|
||||
"Transformer2DModel",
|
||||
"Transformer3DModel",
|
||||
"UNet1DModel",
|
||||
"UNet2DConditionModel",
|
||||
"UNet2DModel",
|
||||
@@ -256,6 +257,7 @@ else:
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"MusicLDMPipeline",
|
||||
"OpenSoraPipeline",
|
||||
"PaintByExamplePipeline",
|
||||
"PIAPipeline",
|
||||
"PixArtAlphaPipeline",
|
||||
@@ -483,6 +485,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
T2IAdapter,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
Transformer3DModel,
|
||||
UNet1DModel,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
@@ -628,6 +631,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
MusicLDMPipeline,
|
||||
OpenSoraPipeline,
|
||||
PaintByExamplePipeline,
|
||||
PIAPipeline,
|
||||
PixArtAlphaPipeline,
|
||||
|
||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_3d"] = ["Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
@@ -75,6 +76,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PriorTransformer,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
Transformer3DModel,
|
||||
TransformerTemporalModel,
|
||||
)
|
||||
from .unets import (
|
||||
|
||||
@@ -187,6 +187,75 @@ class PatchEmbed(nn.Module):
|
||||
return (latent + pos_embed).to(latent.dtype)
|
||||
|
||||
|
||||
class PatchEmbed3D(nn.Module):
|
||||
"""Video to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
height=224,
|
||||
width=224,
|
||||
patch_size=(1, 2, 2),
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
layer_norm=False,
|
||||
bias=True,
|
||||
interpolation_scale=1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_patches = (height // patch_size[1]) * (width // patch_size[2])
|
||||
self.layer_norm = layer_norm
|
||||
self.emed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
if layer_norm:
|
||||
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
self.patch_size = patch_size
|
||||
# See:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
|
||||
self.height, self.width = height // patch_size[1], width // patch_size[2]
|
||||
self.base_size = height // patch_size[1]
|
||||
self.interpolation_scale = interpolation_scale
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
|
||||
)
|
||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
||||
|
||||
def forward(self, latent):
|
||||
height, width = latent.shape[-2] // self.patch_size[1], latent.shape[-1] // self.patch_size[2]
|
||||
|
||||
latent = self.proj(latent) # (B C T H W)
|
||||
|
||||
if self.layer_norm:
|
||||
batch_size, _, num_frames, height, width = latent.size()
|
||||
latent = latent.flatten(2).transpose(1, 2)
|
||||
latent = self.norm(latent)
|
||||
latent = latent.transpose(1, 2).view(batch_size, self.emed_dim, num_frames, height, width)
|
||||
|
||||
latent = latent.flatten(3).permute(0, 2, 3, 1) # BCTHW -> BT(HW)C
|
||||
|
||||
# Interpolate positional embeddings if needed.
|
||||
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
|
||||
if self.height != height or self.width != width:
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
embed_dim=self.pos_embed.shape[-1],
|
||||
grid_size=(height, width),
|
||||
base_size=self.base_size,
|
||||
interpolation_scale=self.interpolation_scale,
|
||||
)
|
||||
pos_embed = torch.from_numpy(pos_embed)
|
||||
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
|
||||
else:
|
||||
pos_embed = self.pos_embed
|
||||
|
||||
latent = (latent + pos_embed).to(latent.dtype)
|
||||
latent = latent.flatten(1, 2) # BT(H*W)C -> B(T*H*W)C
|
||||
return latent
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -6,4 +6,5 @@ if is_torch_available():
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_3d import Transformer3DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
|
||||
@@ -0,0 +1,434 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward, _chunked_feed_forward
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import PatchEmbed3D, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer3DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Transformer3DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
||||
distributions for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
|
||||
pos = np.arange(0, length)[..., None] / scale
|
||||
return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class Transformer3DBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
cross_attention_dim: int,
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
norm_eps: float = 1e-6,
|
||||
num_temporal_patches: int = 16,
|
||||
num_spatial_patches: int = 256,
|
||||
):
|
||||
super().__init__()
|
||||
# We keep these boolean flags for backward-compatibility.
|
||||
|
||||
self.use_ada_layer_norm_single = True
|
||||
self.num_embeds_ada_norm = num_embeds_ada_norm
|
||||
self.num_temporal_patches = num_temporal_patches
|
||||
self.num_spatial_patches = num_spatial_patches
|
||||
|
||||
# Define 3 blocks. Each block has its own normalization layer.
|
||||
# 1. Spatial Self-Attn
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# 2. Temporal Self-Attn
|
||||
self.attn_temporal = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# 2. Cross-Attn
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, elementwise_affine=False)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = FeedForward(dim, activation_fn="gelu-approximate")
|
||||
|
||||
# 4. Scale-shift for PixArt-Alpha.
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temporal_pos_embed: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
|
||||
# 1. Spatial Self-Attention
|
||||
# reshape (batch, num_temporal_patches*num_spatial_patches, dim) -> (batch * num_temporal_patches, num_spatial_patches, dim)
|
||||
norm_hidden_states = norm_hidden_states.view(
|
||||
batch_size, self.num_temporal_patches, self.num_spatial_patches, -1
|
||||
)
|
||||
norm_hidden_states = norm_hidden_states.view(
|
||||
batch_size * self.num_temporal_patches, self.num_spatial_patches, -1
|
||||
)
|
||||
|
||||
attn_output = self.attn1(norm_hidden_states)
|
||||
|
||||
# reshape (batch * num_temporal_patches, num_spatial_patches, dim) -> (batch, num_temporal_patches*num_spatial_patches, dim)
|
||||
attn_output = attn_output.view(batch_size, self.num_temporal_patches, self.num_spatial_patches, -1)
|
||||
attn_output = attn_output.view(batch_size, self.num_temporal_patches * self.num_spatial_patches, -1)
|
||||
|
||||
attn_output = gate_msa * attn_output
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 2. Temporal Self-Attention
|
||||
# reshape (batch, num_temporal_patches*num_spatial_patches, dim) -> (batch * num_spatial_patches, num_temporal_patches, dim)
|
||||
temporal_hidden_states = (
|
||||
hidden_states.view(batch_size, self.num_temporal_patches, self.num_spatial_patches, -1)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
temporal_hidden_states = temporal_hidden_states.view(
|
||||
batch_size * self.num_spatial_patches, self.num_temporal_patches, -1
|
||||
)
|
||||
|
||||
if temporal_pos_embed is not None:
|
||||
temporal_hidden_states = temporal_hidden_states + temporal_pos_embed
|
||||
|
||||
attn_output = self.attn_temporal(temporal_hidden_states)
|
||||
|
||||
# reshape (batch * num_spatial_patches, num_temporal_patches, dim) -> (batch, num_temporal_patches*num_spatial_patches, dim)
|
||||
attn_output = (
|
||||
attn_output.view(batch_size, self.num_spatial_patches, self.num_temporal_patches, -1)
|
||||
.transpose(1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
attn_output = attn_output.view(batch_size, self.num_temporal_patches * self.num_spatial_patches, -1)
|
||||
|
||||
attn_output = gate_msa * attn_output
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 3. Cross-Attention
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
norm_hidden_states = hidden_states
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
ff_output = gate_mlp * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Transformer3DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A 3D Transformer model for image-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
num_embeds_ada_norm ( `int`, *optional*):
|
||||
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||
added to the hidden states.
|
||||
|
||||
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
sample_size: Tuple[int] = (2, 4, 4),
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 8,
|
||||
num_layers: int = 1,
|
||||
cross_attention_dim: int = 256,
|
||||
num_embeds_ada_norm: int = 1000,
|
||||
norm_eps: float = 1e-6,
|
||||
caption_channels: int = 256,
|
||||
interpolation_scale: float = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 2. Define input layers
|
||||
self.height = sample_size[1]
|
||||
self.width = sample_size[2]
|
||||
self.num_patches = np.prod([sample_size[i] // patch_size[i] for i in range(3)])
|
||||
self.num_temporal_patches = sample_size[0] // patch_size[0]
|
||||
self.num_spatial_patches = self.num_patches // self.num_temporal_patches
|
||||
|
||||
self.patch_size = patch_size
|
||||
interpolation_scale = (
|
||||
interpolation_scale if interpolation_scale is not None else max(self.config.sample_size[1] // 64, 1)
|
||||
)
|
||||
self.pos_embed = PatchEmbed3D(
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
Transformer3DBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
norm_eps=norm_eps,
|
||||
num_temporal_patches=self.num_temporal_patches,
|
||||
num_spatial_patches=self.num_spatial_patches,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(inner_dim, np.prod(patch_size) * self.out_channels)
|
||||
|
||||
# 5. PixArt-Alpha blocks.
|
||||
self.adaln_single = AdaLayerNormSingle(inner_dim)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
interpolation_scale = max(self.config.sample_size[0] // 16, 1)
|
||||
temporal_pos_embed = get_1d_sincos_pos_embed(inner_dim, self.num_temporal_patches, scale=interpolation_scale)
|
||||
temporal_pos_embed = torch.from_numpy(temporal_pos_embed).float().unsqueeze(0).requires_grad_(False)
|
||||
self.register_buffer("pos_embed_temporal", temporal_pos_embed)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Input
|
||||
height, width = hidden_states.shape[-2] // self.patch_size[1], hidden_states.shape[-1] // self.patch_size[2]
|
||||
# import ipdb; ipdb.set_trace()
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
batch_size = hidden_states.shape[0]
|
||||
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
temporal_pos_embed=self.pos_embed_temporal if i == 0 else None,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(
|
||||
-1,
|
||||
self.num_temporal_patches,
|
||||
height,
|
||||
width,
|
||||
self.patch_size[1],
|
||||
self.patch_size[2],
|
||||
self.out_channels,
|
||||
)
|
||||
)
|
||||
hidden_states = torch.einsum("nthwpqc->ncthpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(
|
||||
-1,
|
||||
self.out_channels,
|
||||
self.num_temporal_patches,
|
||||
height * self.patch_size[1],
|
||||
width * self.patch_size[2],
|
||||
)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer3DModelOutput(sample=output)
|
||||
@@ -311,10 +311,10 @@ class TransformerSpatioTemporalModel(nn.Module):
|
||||
time_context_first_timestep = time_context[None, :].reshape(
|
||||
batch_size, num_frames, -1, time_context.shape[-1]
|
||||
)[:, 0]
|
||||
time_context = time_context_first_timestep[None, :].broadcast_to(
|
||||
height * width, batch_size, 1, time_context.shape[-1]
|
||||
time_context = time_context_first_timestep[:, None].broadcast_to(
|
||||
batch_size, height * width, time_context.shape[-2], time_context.shape[-1]
|
||||
)
|
||||
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
|
||||
time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1])
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
|
||||
@@ -247,6 +247,7 @@ else:
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
]
|
||||
_import_structure["open_sora"] = ["OpenSoraPipeline"]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -438,6 +439,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .musicldm import MusicLDMPipeline
|
||||
from .open_sora import OpenSoraPipeline
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .pia import PIAPipeline
|
||||
from .pixart_alpha import PixArtAlphaPipeline
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_open_sora"] = ["OpenSoraPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_open_sora import OpenSoraPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,912 @@
|
||||
# Copyright 2024 Open-Sora Authors and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKL, Transformer3DModel
|
||||
from ...schedulers import DPMSolverMultistepScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_bs4_available():
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import OpenSoraPipeline
|
||||
|
||||
>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
|
||||
>>> pipe = OpenSoraPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
|
||||
>>> # Enable memory optimizations.
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> prompt = "A small cactus with a happy face in the Sahara desert."
|
||||
>>> image = pipe(prompt).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenSoraPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Stable Video Diffusion pipeline.
|
||||
|
||||
Args:
|
||||
frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.FloatTensor`]):
|
||||
List of denoised PIL images of length `batch_size` or numpy array or torch tensor
|
||||
of shape `(batch_size, num_frames, height, width, num_channels)`.
|
||||
"""
|
||||
|
||||
frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.FloatTensor]
|
||||
|
||||
|
||||
def _append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
||||
batch_output = processor.postprocess(batch_vid, output_type)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
||||
must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class OpenSoraPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Open-Sora.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Open-Sora uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/Open-Sora/Open-Sora/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`Transformer2DModel`]):
|
||||
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
bad_punct_regex = re.compile(
|
||||
r"["
|
||||
+ "#®•©™&@·º½¾¿¡§~"
|
||||
+ r"\)"
|
||||
+ r"\("
|
||||
+ r"\]"
|
||||
+ r"\["
|
||||
+ r"\}"
|
||||
+ r"\{"
|
||||
+ r"\|"
|
||||
+ "\\"
|
||||
+ r"\/"
|
||||
+ r"\*"
|
||||
+ r"]{1,}"
|
||||
) # noqa
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder"]
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKL,
|
||||
transformer: Transformer3DModel,
|
||||
scheduler: DPMSolverMultistepScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Adapted from https://github.com/Open-Sora/Open-Sora/blob/master/diffusion/model/utils.py
|
||||
def mask_text_embeddings(self, emb, mask):
|
||||
if emb.shape[0] == 1:
|
||||
keep_index = mask.sum().item()
|
||||
return emb[:, :, :keep_index, :], keep_index
|
||||
else:
|
||||
masked_feature = emb * mask[:, None, :, None]
|
||||
return masked_feature, emb.shape[2]
|
||||
|
||||
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str = "",
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 120,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
||||
Open-Sora, this should be "".
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For Open-Sora, it's should be the embeddings of the ""
|
||||
string.
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
|
||||
"""
|
||||
|
||||
if "mask_feature" in kwargs:
|
||||
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
|
||||
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# See Section 3.1. of the paper.
|
||||
max_length = max_sequence_length
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
elif self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens = [negative_prompt] * batch_size
|
||||
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
negative_prompt_attention_mask = uncond_input.attention_mask
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
else:
|
||||
negative_prompt_embeds = None
|
||||
negative_prompt_attention_mask = None
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_steps,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if clean_caption and not is_ftfy_available():
|
||||
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
|
||||
logger.warning("Setting `clean_caption` to False...")
|
||||
clean_caption = False
|
||||
|
||||
if not isinstance(text, (tuple, list)):
|
||||
text = [text]
|
||||
|
||||
def process(text: str):
|
||||
if clean_caption:
|
||||
text = self._clean_caption(text)
|
||||
text = self._clean_caption(text)
|
||||
else:
|
||||
text = text.lower().strip()
|
||||
return text
|
||||
|
||||
return [process(t) for t in text]
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
|
||||
def _clean_caption(self, caption):
|
||||
caption = str(caption)
|
||||
caption = ul.unquote_plus(caption)
|
||||
caption = caption.strip().lower()
|
||||
caption = re.sub("<person>", "person", caption)
|
||||
# urls:
|
||||
caption = re.sub(
|
||||
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
caption = re.sub(
|
||||
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
|
||||
"",
|
||||
caption,
|
||||
) # regex for urls
|
||||
# html:
|
||||
caption = BeautifulSoup(caption, features="html.parser").text
|
||||
|
||||
# @<nickname>
|
||||
caption = re.sub(r"@[\w\d]+\b", "", caption)
|
||||
|
||||
# 31C0—31EF CJK Strokes
|
||||
# 31F0—31FF Katakana Phonetic Extensions
|
||||
# 3200—32FF Enclosed CJK Letters and Months
|
||||
# 3300—33FF CJK Compatibility
|
||||
# 3400—4DBF CJK Unified Ideographs Extension A
|
||||
# 4DC0—4DFF Yijing Hexagram Symbols
|
||||
# 4E00—9FFF CJK Unified Ideographs
|
||||
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
|
||||
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
|
||||
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
|
||||
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
|
||||
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
|
||||
#######################################################
|
||||
|
||||
# все виды тире / all types of dash --> "-"
|
||||
caption = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
|
||||
"-",
|
||||
caption,
|
||||
)
|
||||
|
||||
# кавычки к одному стандарту
|
||||
caption = re.sub(r"[`´«»“”¨]", '"', caption)
|
||||
caption = re.sub(r"[‘’]", "'", caption)
|
||||
|
||||
# "
|
||||
caption = re.sub(r""?", "", caption)
|
||||
# &
|
||||
caption = re.sub(r"&", "", caption)
|
||||
|
||||
# ip adresses:
|
||||
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
|
||||
|
||||
# article ids:
|
||||
caption = re.sub(r"\d:\d\d\s+$", "", caption)
|
||||
|
||||
# \n
|
||||
caption = re.sub(r"\\n", " ", caption)
|
||||
|
||||
# "#123"
|
||||
caption = re.sub(r"#\d{1,3}\b", "", caption)
|
||||
# "#12345.."
|
||||
caption = re.sub(r"#\d{5,}\b", "", caption)
|
||||
# "123456.."
|
||||
caption = re.sub(r"\b\d{6,}\b", "", caption)
|
||||
# filenames:
|
||||
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
|
||||
|
||||
#
|
||||
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
|
||||
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
|
||||
|
||||
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, caption)) > 3:
|
||||
caption = re.sub(regex2, " ", caption)
|
||||
|
||||
caption = ftfy.fix_text(caption)
|
||||
caption = html.unescape(html.unescape(caption))
|
||||
|
||||
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
|
||||
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
|
||||
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
|
||||
|
||||
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
|
||||
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
|
||||
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
|
||||
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
|
||||
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
|
||||
|
||||
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
|
||||
|
||||
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
|
||||
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
|
||||
caption = re.sub(r"\s+", " ", caption)
|
||||
|
||||
caption.strip()
|
||||
|
||||
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
|
||||
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
|
||||
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
|
||||
caption = re.sub(r"^\.\S+$", "", caption)
|
||||
|
||||
return caption.strip()
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_frames: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device],
|
||||
generator: torch.Generator,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents: torch.FloatTensor, num_frames: int, decode_chunk_size: int = 14):
|
||||
# [batch, channels, frames, height, width] -> [batch, frames, channels, height, width]
|
||||
latents = latents.permute(0, 2, 1, 3, 4)
|
||||
# [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
|
||||
latents = latents.flatten(0, 1)
|
||||
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
|
||||
accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())
|
||||
|
||||
# decode decode_chunk_size frames at a time to avoid OOM
|
||||
frames = []
|
||||
for i in range(0, latents.shape[0], decode_chunk_size):
|
||||
num_frames_in = latents[i : i + decode_chunk_size].shape[0]
|
||||
decode_kwargs = {}
|
||||
if accepts_num_frames:
|
||||
# we only pass num_frames_in if it's expected
|
||||
decode_kwargs["num_frames"] = num_frames_in
|
||||
|
||||
frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
|
||||
frames.append(frame)
|
||||
frames = torch.cat(frames, dim=0)
|
||||
|
||||
# [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
|
||||
frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
frames = frames.float()
|
||||
return frames
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: str = "",
|
||||
num_inference_steps: int = 20,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.5,
|
||||
min_guidance_scale: Optional[float] = None,
|
||||
max_guidance_scale: Optional[float] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
clean_caption: bool = True,
|
||||
max_sequence_length: int = 120,
|
||||
decode_chunk_size: int = 4,
|
||||
**kwargs,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
|
||||
timesteps are used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size):
|
||||
The width in pixels of the generated image.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
|
||||
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated attention mask for negative text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
clean_caption (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
|
||||
be installed. If the dependencies are not installed, the embeddings will be created from the raw
|
||||
prompt.
|
||||
use_resolution_binning (`bool` defaults to `True`):
|
||||
If set to `True`, the requested height and width are first mapped to the closest resolutions using
|
||||
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
|
||||
the requested resolution. Useful for generating non-square images.
|
||||
max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images
|
||||
"""
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
num_frames = self.transformer.config.sample_size[0]
|
||||
height = height or self.transformer.config.sample_size[1] * self.vae_scale_factor
|
||||
width = width or self.transformer.config.sample_size[2] * self.vae_scale_factor
|
||||
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_steps,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_frames,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 8. Prepare guidance scale
|
||||
# TODO: Hacky for testing, make this cleaner
|
||||
if min_guidance_scale is not None and max_guidance_scale is not None:
|
||||
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
|
||||
guidance_scale = guidance_scale.to(device, latents.dtype)
|
||||
guidance_scale = guidance_scale.repeat(batch_size * num_images_per_prompt, 1)
|
||||
guidance_scale = _append_dims(guidance_scale, latents.ndim).transpose(1, 2)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
current_timestep = t
|
||||
if not torch.is_tensor(current_timestep):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = latent_model_input.device.type == "mps"
|
||||
if isinstance(current_timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
|
||||
elif len(current_timestep.shape) == 0:
|
||||
current_timestep = current_timestep[None].to(latent_model_input.device)
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
current_timestep = current_timestep.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
timestep=current_timestep,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# learned sigma
|
||||
if self.transformer.config.out_channels // 2 == latent_channels:
|
||||
noise_pred = noise_pred.chunk(2, dim=1)[0]
|
||||
else:
|
||||
noise_pred = noise_pred
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
|
||||
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
|
||||
else:
|
||||
frames = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (frames,)
|
||||
|
||||
return OpenSoraPipelineOutput(frames=frames)
|
||||
@@ -227,6 +227,21 @@ class Transformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class Transformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class UNet1DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
from diffusers import Transformer3DModel, AutoencoderTiny, DPMSolverMultistepScheduler, OpenSoraPipeline
|
||||
|
||||
channels, num_frames, height, width, text_dim = 4, 2, 4, 4, 32
|
||||
|
||||
model = Transformer3DModel(
|
||||
in_channels=channels,
|
||||
out_channels=channels*2,
|
||||
cross_attention_dim=1408,
|
||||
caption_channels=text_dim,
|
||||
num_embeds_ada_norm=1000,
|
||||
sample_size=(num_frames, height, width),
|
||||
)
|
||||
|
||||
x = torch.randn(1, channels, num_frames, height, width)
|
||||
y = torch.randn(1, 77, 32)
|
||||
t = torch.ones(1)
|
||||
|
||||
# with torch.no_grad():
|
||||
# out = model(x, y, t)
|
||||
# print(out.sample.shape) # torch.Size([1, 8, 2, 4, 4])
|
||||
|
||||
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
|
||||
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")
|
||||
scheduler = DPMSolverMultistepScheduler.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="scheduler")
|
||||
|
||||
pipe = OpenSoraPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=model,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
prompt = ""
|
||||
out = pipe(prompt, num_inference_steps=1, min_guidance_scale=1.0, max_guidance_scale=3.0)
|
||||
@@ -1144,20 +1144,24 @@ class PipelineTesterMixin:
|
||||
self.assertLess(
|
||||
max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
|
||||
)
|
||||
offloaded_modules = [
|
||||
v
|
||||
offloaded_modules = {
|
||||
k: v
|
||||
for k, v in pipe.components.items()
|
||||
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
|
||||
]
|
||||
(
|
||||
self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
|
||||
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
|
||||
}
|
||||
self.assertTrue(
|
||||
all(v.device.type == "cpu" for v in offloaded_modules.values()),
|
||||
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
|
||||
)
|
||||
|
||||
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
|
||||
(
|
||||
self.assertTrue(all(isinstance(v, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks)),
|
||||
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.CpuOffload)]}",
|
||||
offloaded_modules_with_incorrect_hooks = {}
|
||||
for k, v in offloaded_modules.items():
|
||||
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
|
||||
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
|
||||
|
||||
self.assertTrue(
|
||||
len(offloaded_modules_with_incorrect_hooks) == 0,
|
||||
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
@@ -1189,22 +1193,23 @@ class PipelineTesterMixin:
|
||||
self.assertLess(
|
||||
max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
|
||||
)
|
||||
offloaded_modules = [
|
||||
v
|
||||
offloaded_modules = {
|
||||
k: v
|
||||
for k, v in pipe.components.items()
|
||||
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
|
||||
]
|
||||
(
|
||||
self.assertTrue(all(v.device.type == "meta" for v in offloaded_modules)),
|
||||
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}",
|
||||
}
|
||||
self.assertTrue(
|
||||
all(v.device.type == "meta" for v in offloaded_modules.values()),
|
||||
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
|
||||
)
|
||||
offloaded_modules_with_incorrect_hooks = {}
|
||||
for k, v in offloaded_modules.items():
|
||||
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
|
||||
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
|
||||
|
||||
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
|
||||
(
|
||||
self.assertTrue(
|
||||
all(isinstance(v, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks)
|
||||
),
|
||||
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.AlignDevicesHook)]}",
|
||||
self.assertTrue(
|
||||
len(offloaded_modules_with_incorrect_hooks) == 0,
|
||||
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
|
||||
)
|
||||
|
||||
@unittest.skipIf(
|
||||
|
||||
Reference in New Issue
Block a user