Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a31df720fe | |||
| b09b90e24c | |||
| 058b47553e | |||
| 7f58a76f48 | |||
| 09b7bfce91 | |||
| 5d8b1987ec | |||
| acd1962769 | |||
| 5b1b80a5b6 | |||
| 8581d9bce4 | |||
| c101066227 | |||
| d4c7ab7bf1 | |||
| ea9dc3fa90 | |||
| b4220e97b1 | |||
| dc85b578c2 | |||
| 0d927c7542 | |||
| 5b93338235 | |||
| 7c1c705f60 | |||
| 9e72016468 | |||
| 3e9716f22b | |||
| 87bfbc320d | |||
| a517f665a4 | |||
| 16748d1eba | |||
| c9081a8abd | |||
| 0eb68d9ddb | |||
| 9941b3f124 | |||
| 16b9f98b48 | |||
| fee93c81eb | |||
| 5308cce994 | |||
| 318556b20e | |||
| 6620eda357 | |||
| 1f0705adcf | |||
| 5e96333cb2 | |||
| da95a28ff6 | |||
| d66d554dc2 | |||
| c7df846dec |
@@ -77,7 +77,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
|
||||
|
||||
## Quickstart
|
||||
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 16000+ checkpoints):
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 19000+ checkpoints):
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -219,7 +219,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
|
||||
- https://github.com/deep-floyd/IF
|
||||
- https://github.com/bentoml/BentoML
|
||||
- https://github.com/bmaltais/kohya_ss
|
||||
- +7000 other amazing GitHub repositories 💪
|
||||
- +8000 other amazing GitHub repositories 💪
|
||||
|
||||
Thank you for using us ❤️.
|
||||
|
||||
|
||||
@@ -228,6 +228,8 @@
|
||||
title: UNet3DConditionModel
|
||||
- local: api/models/unet-motion
|
||||
title: UNetMotionModel
|
||||
- local: api/models/uvit2d
|
||||
title: UViT2DModel
|
||||
- local: api/models/vq
|
||||
title: VQModel
|
||||
- local: api/models/autoencoderkl
|
||||
|
||||
@@ -30,8 +30,8 @@ To learn more about how to load single file weights, see the [Load different Sta
|
||||
|
||||
## FromOriginalVAEMixin
|
||||
|
||||
[[autodoc]] loaders.single_file.FromOriginalVAEMixin
|
||||
[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin
|
||||
|
||||
## FromOriginalControlnetMixin
|
||||
|
||||
[[autodoc]] loaders.single_file.FromOriginalControlnetMixin
|
||||
[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin
|
||||
@@ -24,4 +24,4 @@ The abstract from the paper is:
|
||||
|
||||
## PriorTransformerOutput
|
||||
|
||||
[[autodoc]] models.prior_transformer.PriorTransformerOutput
|
||||
[[autodoc]] models.transformers.prior_transformer.PriorTransformerOutput
|
||||
|
||||
@@ -38,4 +38,4 @@ It is assumed one of the input classes is the masked latent pixel. The predicted
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.transformer_2d.Transformer2DModelOutput
|
||||
[[autodoc]] models.transformers.transformer_2d.Transformer2DModelOutput
|
||||
|
||||
@@ -16,8 +16,8 @@ A Transformer model for video-like data.
|
||||
|
||||
## TransformerTemporalModel
|
||||
|
||||
[[autodoc]] models.transformer_temporal.TransformerTemporalModel
|
||||
[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModel
|
||||
|
||||
## TransformerTemporalModelOutput
|
||||
|
||||
[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
|
||||
[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModelOutput
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNetMotionModel
|
||||
|
||||
## UNet3DConditionOutput
|
||||
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
|
||||
[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet1DModel
|
||||
|
||||
## UNet1DOutput
|
||||
[[autodoc]] models.unet_1d.UNet1DOutput
|
||||
[[autodoc]] models.unets.unet_1d.UNet1DOutput
|
||||
|
||||
@@ -22,10 +22,10 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet2DConditionModel
|
||||
|
||||
## UNet2DConditionOutput
|
||||
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
|
||||
[[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput
|
||||
|
||||
## FlaxUNet2DConditionModel
|
||||
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel
|
||||
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionModel
|
||||
|
||||
## FlaxUNet2DConditionOutput
|
||||
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
|
||||
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet2DModel
|
||||
|
||||
## UNet2DOutput
|
||||
[[autodoc]] models.unet_2d.UNet2DOutput
|
||||
[[autodoc]] models.unets.unet_2d.UNet2DOutput
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet3DConditionModel
|
||||
|
||||
## UNet3DConditionOutput
|
||||
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
|
||||
[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# UVit2DModel
|
||||
|
||||
The [U-ViT](https://hf.co/papers/2301.11093) model is a vision transformer (ViT) based UNet. This model incorporates elements from ViT (considers all inputs such as time, conditions and noisy image patches as tokens) and a UNet (long skip connections between the shallow and deep layers). The skip connection is important for predicting pixel-level features. An additional 3x3 convolutional block is applied prior to the final output to improve image quality.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Currently, applying diffusion models in pixel space of high resolution images is difficult. Instead, existing approaches focus on diffusion in lower dimensional spaces (latent diffusion), or have multiple super-resolution levels of generation referred to as cascades. The downside is that these approaches add additional complexity to the diffusion framework. This paper aims to improve denoising diffusion for high resolution images while keeping the model as simple as possible. The paper is centered around the research question: How can one train a standard denoising diffusion models on high resolution images, and still obtain performance comparable to these alternate approaches? The four main findings are: 1) the noise schedule should be adjusted for high resolution images, 2) It is sufficient to scale only a particular part of the architecture, 3) dropout should be added at specific locations in the architecture, and 4) downsampling is an effective strategy to avoid high resolution feature maps. Combining these simple yet effective techniques, we achieve state-of-the-art on image generation among diffusion models without sampling modifiers on ImageNet.*
|
||||
|
||||
## UVit2DModel
|
||||
|
||||
[[autodoc]] UVit2DModel
|
||||
|
||||
## UVit2DConvEmbed
|
||||
|
||||
[[autodoc]] models.unets.uvit_2d.UVit2DConvEmbed
|
||||
|
||||
## UVitBlock
|
||||
|
||||
[[autodoc]] models.unets.uvit_2d.UVitBlock
|
||||
|
||||
## ConvNextBlock
|
||||
|
||||
[[autodoc]] models.unets.uvit_2d.ConvNextBlock
|
||||
|
||||
## ConvMlmLayer
|
||||
|
||||
[[autodoc]] models.unets.uvit_2d.ConvMlmLayer
|
||||
@@ -25,6 +25,7 @@ The abstract of the paper is the following:
|
||||
| Pipeline | Tasks | Demo
|
||||
|---|---|:---:|
|
||||
| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* |
|
||||
| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |
|
||||
|
||||
## Available checkpoints
|
||||
|
||||
@@ -32,6 +33,8 @@ Motion Adapter checkpoints can be found under [guoyww](https://huggingface.co/gu
|
||||
|
||||
## Usage example
|
||||
|
||||
### AnimateDiffPipeline
|
||||
|
||||
AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in Stable Diffusion UNet.
|
||||
|
||||
The following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
|
||||
@@ -98,6 +101,114 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you
|
||||
|
||||
</Tip>
|
||||
|
||||
### AnimateDiffVideoToVideoPipeline
|
||||
|
||||
AnimateDiff can also be used to generate visually similar videos or enable style/character/background or other edits starting from an initial video, allowing you to seamlessly explore creative possibilities.
|
||||
|
||||
```python
|
||||
import imageio
|
||||
import requests
|
||||
import torch
|
||||
from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
|
||||
from diffusers.utils import export_to_gif
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
# Load the motion adapter
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
|
||||
# load SD 1.5 based finetuned model
|
||||
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
|
||||
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to("cuda")
|
||||
scheduler = DDIMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
clip_sample=False,
|
||||
timestep_spacing="linspace",
|
||||
beta_schedule="linear",
|
||||
steps_offset=1,
|
||||
)
|
||||
pipe.scheduler = scheduler
|
||||
|
||||
# enable memory savings
|
||||
pipe.enable_vae_slicing()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# helper function to load videos
|
||||
def load_video(file_path: str):
|
||||
images = []
|
||||
|
||||
if file_path.startswith(('http://', 'https://')):
|
||||
# If the file_path is a URL
|
||||
response = requests.get(file_path)
|
||||
response.raise_for_status()
|
||||
content = BytesIO(response.content)
|
||||
vid = imageio.get_reader(content)
|
||||
else:
|
||||
# Assuming it's a local file path
|
||||
vid = imageio.get_reader(file_path)
|
||||
|
||||
for frame in vid:
|
||||
pil_image = Image.fromarray(frame)
|
||||
images.append(pil_image)
|
||||
|
||||
return images
|
||||
|
||||
video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
|
||||
|
||||
output = pipe(
|
||||
video = video,
|
||||
prompt="panda playing a guitar, on a boat, in the ocean, high quality",
|
||||
negative_prompt="bad quality, worse quality",
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=25,
|
||||
strength=0.5,
|
||||
generator=torch.Generator("cpu").manual_seed(42),
|
||||
)
|
||||
frames = output.frames[0]
|
||||
export_to_gif(frames, "animation.gif")
|
||||
```
|
||||
|
||||
Here are some sample outputs:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th align=center>Source Video</th>
|
||||
<th align=center>Output Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=center>
|
||||
raccoon playing a guitar
|
||||
<br />
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
|
||||
alt="racoon playing a guitar"
|
||||
style="width: 300px;" />
|
||||
</td>
|
||||
<td align=center>
|
||||
panda playing a guitar
|
||||
<br/>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-output-1.gif"
|
||||
alt="panda playing a guitar"
|
||||
style="width: 300px;" />
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=center>
|
||||
closeup of margot robbie, fireworks in the background, high quality
|
||||
<br />
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-2.gif"
|
||||
alt="closeup of margot robbie, fireworks in the background, high quality"
|
||||
style="width: 300px;" />
|
||||
</td>
|
||||
<td align=center>
|
||||
closeup of tony stark, robert downey jr, fireworks
|
||||
<br/>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-output-2.gif"
|
||||
alt="closeup of tony stark, robert downey jr, fireworks"
|
||||
style="width: 300px;" />
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Using Motion LoRAs
|
||||
|
||||
Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. These LoRAs are responsible for adding specific types of motion to the animations.
|
||||
@@ -300,16 +411,14 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
## AnimateDiffPipeline
|
||||
|
||||
[[autodoc]] AnimateDiffPipeline
|
||||
- all
|
||||
- __call__
|
||||
- enable_freeu
|
||||
- disable_freeu
|
||||
- enable_free_init
|
||||
- disable_free_init
|
||||
- enable_vae_slicing
|
||||
- disable_vae_slicing
|
||||
- enable_vae_tiling
|
||||
- disable_vae_tiling
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AnimateDiffVideoToVideoPipeline
|
||||
|
||||
[[autodoc]] AnimateDiffVideoToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AnimateDiffPipelineOutput
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ accelerate launch train_text_to_image_lora.py \
|
||||
|
||||
Many of the basic and important parameters are described in the [Text-to-image](text2image#script-parameters) training guide, so this guide just focuses on the LoRA relevant parameters:
|
||||
|
||||
- `--rank`: the number of low-rank matrices to train
|
||||
- `--rank`: the inner dimension of the low-rank matrices to train; a higher rank means more trainable parameters
|
||||
- `--learning_rate`: the default learning rate is 1e-4, but with LoRA, you can use a higher learning rate
|
||||
|
||||
## Training script
|
||||
|
||||
@@ -206,3 +206,13 @@ pipe.fuse_lora(adapter_names=["pixel", "toy"])
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
```
|
||||
|
||||
## Saving a pipeline after fusing the adapters
|
||||
|
||||
To properly save a pipeline after it's been loaded with the adapters, it should be serialized like so:
|
||||
|
||||
```python
|
||||
pipe.fuse_lora(lora_scale=1.0)
|
||||
pipe.unload_lora_weights()
|
||||
pipe.save_pretrained("path-to-pipeline")
|
||||
```
|
||||
|
||||
@@ -11,4 +11,6 @@
|
||||
- sections:
|
||||
- local: tutorials/tutorial_overview
|
||||
title: 概要
|
||||
- local: tutorials/autopipeline
|
||||
title: AutoPipeline
|
||||
title: チュートリアル
|
||||
@@ -0,0 +1,168 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AutoPipeline
|
||||
|
||||
Diffusersは様々なタスクをこなすことができ、テキストから画像、画像から画像、画像の修復など、複数のタスクに対して同じように事前学習された重みを再利用することができます。しかし、ライブラリや拡散モデルに慣れていない場合、どのタスクにどのパイプラインを使えばいいのかがわかりにくいかもしれません。例えば、 [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) チェックポイントをテキストから画像に変換するために使用している場合、それぞれ[`StableDiffusionImg2ImgPipeline`]クラスと[`StableDiffusionInpaintPipeline`]クラスでチェックポイントをロードすることで、画像から画像や画像の修復にも使えることを知らない可能性もあります。
|
||||
|
||||
`AutoPipeline` クラスは、🤗 Diffusers の様々なパイプラインをよりシンプルするために設計されています。この汎用的でタスク重視のパイプラインによってタスクそのものに集中することができます。`AutoPipeline` は、使用するべき正しいパイプラインクラスを自動的に検出するため、特定のパイプラインクラス名を知らなくても、タスクのチェックポイントを簡単にロードできます。
|
||||
|
||||
<Tip>
|
||||
|
||||
どのタスクがサポートされているかは、[AutoPipeline](../api/pipelines/auto_pipeline) のリファレンスをご覧ください。現在、text-to-image、image-to-image、inpaintingをサポートしています。
|
||||
|
||||
</Tip>
|
||||
|
||||
このチュートリアルでは、`AutoPipeline` を使用して、事前に学習された重みが与えられたときに、特定のタスクを読み込むためのパイプラインクラスを自動的に推測する方法を示します。
|
||||
|
||||
## タスクに合わせてAutoPipeline を選択する
|
||||
まずはチェックポイントを選ぶことから始めましょう。例えば、 [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) チェックポイントでテキストから画像への変換したいなら、[`AutoPipelineForText2Image`]を使います:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
prompt = "peasant and dragon combat, wood cutting style, viking era, bevel with rune"
|
||||
|
||||
image = pipeline(prompt, num_inference_steps=25).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png" alt="generated image of peasant fighting dragon in wood cutting style"/>
|
||||
</div>
|
||||
|
||||
[`AutoPipelineForText2Image`] を具体的に見ていきましょう:
|
||||
|
||||
1. [`model_index.json`](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json) ファイルから `"stable-diffusion"` クラスを自動的に検出します。
|
||||
2. `"stable-diffusion"` のクラス名に基づいて、テキストから画像へ変換する [`StableDiffusionPipeline`] を読み込みます。
|
||||
|
||||
同様に、画像から画像へ変換する場合、[`AutoPipelineForImage2Image`] は `model_index.json` ファイルから `"stable-diffusion"` チェックポイントを検出し、対応する [`StableDiffusionImg2ImgPipeline`] を読み込みます。また、入力画像にノイズの量やバリエーションの追加を決めるための強さなど、パイプラインクラスに固有の追加引数を渡すこともできます:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
prompt = "a portrait of a dog wearing a pearl earring"
|
||||
|
||||
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0f/1665_Girl_with_a_Pearl_Earring.jpg/800px-1665_Girl_with_a_Pearl_Earring.jpg"
|
||||
|
||||
response = requests.get(url)
|
||||
image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
image.thumbnail((768, 768))
|
||||
|
||||
image = pipeline(prompt, image, num_inference_steps=200, strength=0.75, guidance_scale=10.5).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png" alt="generated image of a vermeer portrait of a dog wearing a pearl earring"/>
|
||||
</div>
|
||||
|
||||
また、画像の修復を行いたい場合は、 [`AutoPipelineForInpainting`] が、同様にベースとなる[`StableDiffusionInpaintPipeline`]クラスを読み込みます:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = load_image(img_url).convert("RGB")
|
||||
mask_image = load_image(mask_url).convert("RGB")
|
||||
|
||||
prompt = "A majestic tiger sitting on a bench"
|
||||
image = pipeline(prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png" alt="generated image of a tiger sitting on a bench"/>
|
||||
</div>
|
||||
|
||||
サポートされていないチェックポイントを読み込もうとすると、エラーになります:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"openai/shap-e-img2img", torch_dtype=torch.float16, use_safetensors=True
|
||||
)
|
||||
"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
|
||||
```
|
||||
|
||||
## 複数のパイプラインを使用する
|
||||
|
||||
いくつかのワークフローや多くのパイプラインを読み込む場合、不要なメモリを使ってしまう再読み込みをするよりも、チェックポイントから同じコンポーネントを再利用する方がメモリ効率が良いです。たとえば、テキストから画像への変換にチェックポイントを使い、画像から画像への変換にまたチェックポイントを使いたい場合、[from_pipe()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe) メソッドを使用します。このメソッドは、以前読み込まれたパイプラインのコンポーネントを使うことで追加のメモリを消費することなく、新しいパイプラインを作成します。
|
||||
|
||||
[from_pipe()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe) メソッドは、元のパイプラインクラスを検出し、実行したいタスクに対応する新しいパイプラインクラスにマッピングします。例えば、テキストから画像への`"stable-diffusion"` クラスのパイプラインを読み込む場合:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
|
||||
import torch
|
||||
|
||||
pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
)
|
||||
print(type(pipeline_text2img))
|
||||
"<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'>"
|
||||
```
|
||||
|
||||
そして、[from_pipe()] (https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe)は、もとの`"stable-diffusion"` パイプラインのクラスである [`StableDiffusionImg2ImgPipeline`] にマップします:
|
||||
|
||||
```py
|
||||
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)
|
||||
print(type(pipeline_img2img))
|
||||
"<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'>"
|
||||
```
|
||||
元のパイプラインにオプションとして引数(セーフティチェッカーの無効化など)を渡した場合、この引数も新しいパイプラインに渡されます:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
|
||||
import torch
|
||||
|
||||
pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
requires_safety_checker=False,
|
||||
).to("cuda")
|
||||
|
||||
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)
|
||||
print(pipeline_img2img.config.requires_safety_checker)
|
||||
"False"
|
||||
```
|
||||
|
||||
新しいパイプラインの動作を変更したい場合は、元のパイプラインの引数や設定を上書きすることができます。例えば、セーフティチェッカーをオンに戻し、`strength` 引数を追加します:
|
||||
|
||||
```py
|
||||
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img, requires_safety_checker=True, strength=0.3)
|
||||
print(pipeline_img2img.config.requires_safety_checker)
|
||||
"True"
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
@@ -27,8 +27,8 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
|
||||
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
|
||||
MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
|
||||
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
|
||||
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
|
||||
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) |
|
||||
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
@@ -44,9 +44,9 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
|
||||
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
|
||||
| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
|
||||
FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
|
||||
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
|
||||
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
@@ -55,14 +55,17 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
|
||||
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
|
||||
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
|
||||
| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
|
||||
| Instaflow Pipeline | Implementation of [InstaFlow! One-Step Stable Diffusion with Rectified Flow](https://arxiv.org/abs/2309.06380) | [Instaflow Pipeline](#instaflow-pipeline) | - | [Ayush Mangal](https://github.com/ayushtues) |
|
||||
| Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
|
||||
| Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender_A_Video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
|
||||
| StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
|
||||
| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
|
||||
| InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |
|
||||
| UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) |
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="filename_in_the_community_folder")
|
||||
```
|
||||
@@ -3228,6 +3231,43 @@ output_image.save("./output.png")
|
||||
|
||||
```
|
||||
|
||||
### Instaflow Pipeline
|
||||
InstaFlow is an ultra-fast, one-step image generator that achieves image quality close to Stable Diffusion, significantly reducing the demand of computational resources. This efficiency is made possible through a recent [Rectified Flow](https://github.com/gnobitab/RectifiedFlow) technique, which trains probability flows with straight trajectories, hence inherently requiring only a single step for fast inference.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("XCLIU/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float16, custom_pipeline="instaflow_one_step")
|
||||
pipe.to("cuda") ### if GPU is not available, comment this line
|
||||
prompt = "A hyper-realistic photo of a cute cat."
|
||||
|
||||
images = pipe(prompt=prompt,
|
||||
num_inference_steps=1,
|
||||
guidance_scale=0.0).images
|
||||
images[0].save("./image.png")
|
||||
```
|
||||

|
||||
|
||||
You can also combine it with LORA out of the box, like https://huggingface.co/artificialguybr/logo-redmond-1-5v-logo-lora-for-liberteredmond-sd-1-5, to unlock cool use cases in single step!
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("XCLIU/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float16, custom_pipeline="instaflow_one_step")
|
||||
pipe.to("cuda") ### if GPU is not available, comment this line
|
||||
pipe.load_lora_weights("artificialguybr/logo-redmond-1-5v-logo-lora-for-liberteredmond-sd-1-5")
|
||||
prompt = "logo, A logo for a fitness app, dynamic running figure, energetic colors (red, orange) ),LogoRedAF ,"
|
||||
images = pipe(prompt=prompt,
|
||||
num_inference_steps=1,
|
||||
guidance_scale=0.0).images
|
||||
images[0].save("./image.png")
|
||||
```
|
||||

|
||||
|
||||
### Null-Text Inversion pipeline
|
||||
|
||||
This pipeline provides null-text inversion for editing real images. It enables null-text optimization, and DDIM reconstruction via w, w/o null-text optimization. No prompt-to-prompt code is implemented as there is a Prompt2PromptPipeline.
|
||||
@@ -3265,8 +3305,10 @@ pipeline = NullTextPipeline.from_pretrained(model_path, scheduler = scheduler, t
|
||||
inverted_latent, uncond = pipeline.invert(input_image, invert_prompt, num_inner_steps=10, early_stop_epsilon= 1e-5, num_inference_steps = steps)
|
||||
pipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_steps=steps).images[0].save(input_image+".output.jpg")
|
||||
```
|
||||
|
||||
### Rerender_A_Video
|
||||
|
||||
```
|
||||
This is the Diffusers implementation of zero-shot video-to-video translation pipeline [Rerender_A_Video](https://github.com/williamyang1991/Rerender_A_Video) (without Ebsynth postprocessing). To run the code, please install gmflow. Then modify the path in `examples/community/rerender_a_video.py`:
|
||||
|
||||
```py
|
||||
@@ -3334,7 +3376,6 @@ generator = torch.manual_seed(0)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
output_frames = pipe(
|
||||
"a beautiful woman in CG style, best quality, extremely detailed",
|
||||
|
||||
frames,
|
||||
control_frames,
|
||||
num_inference_steps=20,
|
||||
@@ -3355,7 +3396,7 @@ export_to_video(
|
||||
|
||||
### StyleAligned Pipeline
|
||||
|
||||
This pipeline is the implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133).
|
||||
This pipeline is the implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133). You can find more results [here](https://github.com/huggingface/diffusers/pull/6489#issuecomment-1881209354).
|
||||
|
||||
> Large-scale Text-to-Image (T2I) models have rapidly gained prominence across creative fields, generating visually compelling outputs from textual prompts. However, controlling these models to ensure consistent style remains challenging, with existing methods necessitating fine-tuning and manual intervention to disentangle content and style. In this paper, we introduce StyleAligned, a novel technique designed to establish style alignment among a series of generated images. By employing minimal `attention sharing' during the diffusion process, our method maintains style consistency across images within T2I models. This approach allows for the creation of style-consistent images using a reference style through a straightforward inversion operation. Our method's evaluation across diverse styles and text prompts demonstrates high-quality synthesis and fidelity, underscoring its efficacy in achieving consistent style across various inputs.
|
||||
|
||||
@@ -3441,6 +3482,7 @@ IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings
|
||||
You need to install `insightface` and all its requirements to use this model.
|
||||
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
|
||||
You have to disable PEFT BACKEND in order to load weights.
|
||||
You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).
|
||||
|
||||
```py
|
||||
import diffusers
|
||||
@@ -3494,3 +3536,102 @@ images = pipeline(
|
||||
for i in range(num_images):
|
||||
images[i].save(f"c{i}.png")
|
||||
```
|
||||
|
||||
### InstantID Pipeline
|
||||
|
||||
InstantID is a new state-of-the-art tuning-free method to achieve ID-Preserving generation with only single image, supporting various downstream tasks. For any usgae question, please refer to the [official implementation](https://github.com/InstantID/InstantID).
|
||||
|
||||
```py
|
||||
# !pip install opencv-python transformers accelerate insightface
|
||||
import diffusers
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.models import ControlNetModel
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from insightface.app import FaceAnalysis
|
||||
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
|
||||
|
||||
# prepare 'antelopev2' under ./models
|
||||
# https://github.com/deepinsight/insightface/issues/1896#issuecomment-1023867304
|
||||
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
|
||||
# prepare models under ./checkpoints
|
||||
# https://huggingface.co/InstantX/InstantID
|
||||
from huggingface_hub import hf_hub_download
|
||||
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
|
||||
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
|
||||
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
|
||||
|
||||
face_adapter = f'./checkpoints/ip-adapter.bin'
|
||||
controlnet_path = f'./checkpoints/ControlNetModel'
|
||||
|
||||
# load IdentityNet
|
||||
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
||||
|
||||
base_model = 'wangqixun/YamerMIX_v8'
|
||||
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
||||
base_model,
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipe.cuda()
|
||||
|
||||
# load adapter
|
||||
pipe.load_ip_adapter_instantid(face_adapter)
|
||||
|
||||
# load an image
|
||||
face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
|
||||
|
||||
# prepare face emb
|
||||
face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
|
||||
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
|
||||
face_emb = face_info['embedding']
|
||||
face_kps = draw_kps(face_image, face_info['kps'])
|
||||
|
||||
# prompt
|
||||
prompt = "film noir style, ink sketch|vector, male man, highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic"
|
||||
negative_prompt = "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, vibrant, colorful"
|
||||
|
||||
# generate image
|
||||
pipe.set_ip_adapter_scale(0.8)
|
||||
image = pipe(
|
||||
prompt,
|
||||
image_embeds=face_emb,
|
||||
image=face_kps,
|
||||
controlnet_conditioning_scale=0.8,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### UFOGen Scheduler
|
||||
|
||||
[UFOGen](https://arxiv.org/abs/2311.09257) is a generative model designed for fast one-step text-to-image generation, trained via adversarial training starting from an initial pretrained diffusion model such as Stable Diffusion. `scheduling_ufogen.py` implements a onestep and multistep sampling algorithm for UFOGen models compatible with pipelines like `StableDiffusionPipeline`. A usage example is as follows:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
from scheduling_ufogen import UFOGenScheduler
|
||||
|
||||
# NOTE: currently, I am not aware of any publicly available UFOGen model checkpoints trained from SD v1.5.
|
||||
ufogen_model_id_or_path = "/path/to/ufogen/model"
|
||||
pipe = StableDiffusionPipeline(
|
||||
ufogen_model_id_or_path,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
# You can initialize a UFOGenScheduler as follows:
|
||||
pipe.scheduler = UFOGenScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
prompt = "Three cats having dinner at a table at new years eve, cinematic shot, 8k."
|
||||
|
||||
# Onestep sampling
|
||||
onestep_image = pipe(prompt, num_inference_steps=1).images[0]
|
||||
|
||||
# Multistep sampling
|
||||
multistep_image = pipe(prompt, num_inference_steps=4).images[0]
|
||||
```
|
||||
|
||||
@@ -0,0 +1,707 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
deprecate,
|
||||
logging,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class InstaFlowPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Rectified Flow and Euler discretization.
|
||||
This customized pipeline is based on StableDiffusionPipeline from the official Diffusers library (0.21.4)
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`~transformers.CLIPTextModel`]):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer ([`~transformers.CLIPTokenizer`]):
|
||||
A `CLIPTokenizer` to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
if torch.is_tensor(image):
|
||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
||||
else:
|
||||
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
def decode_latents(self, latents):
|
||||
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
||||
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def merge_dW_to_unet(pipe, dW_dict, alpha=1.0):
|
||||
_tmp_sd = pipe.unet.state_dict()
|
||||
for key in dW_dict.keys():
|
||||
_tmp_sd[key] += dW_dict[key] * alpha
|
||||
pipe.unet.load_state_dict(_tmp_sd, strict=False)
|
||||
return pipe
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.7):
|
||||
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps = [(1.0 - i / num_inference_steps) * 1000.0 for i in range(num_inference_steps)]
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
dt = 1.0 / num_inference_steps
|
||||
|
||||
# 7. Denoising loop of Euler discretization from t = 0 to t = 1
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
vec_t = torch.ones((latent_model_input.shape[0],), device=latents.device) * t
|
||||
|
||||
v_pred = self.unet(latent_model_input, vec_t, encoder_hidden_states=prompt_embeds).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
v_pred_neg, v_pred_text = v_pred.chunk(2)
|
||||
v_pred = v_pred_neg + guidance_scale * (v_pred_text - v_pred_neg)
|
||||
|
||||
latents = latents + dt * v_pred
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -26,7 +26,7 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.models.unet_motion_model import MotionAdapter
|
||||
from diffusers.models.unets.unet_motion_model import MotionAdapter
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import (
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,525 @@
|
||||
# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from diffusers.utils import BaseOutput
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UFOGen
|
||||
class UFOGenSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
|
||||
`pred_original_sample` can be used to preview progress or for guidance.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
pred_original_sample: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
|
||||
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
||||
to that part of the diffusion process.
|
||||
|
||||
|
||||
Args:
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
|
||||
|
||||
|
||||
Args:
|
||||
betas (`torch.FloatTensor`):
|
||||
the betas that the scheduler is being initialized with.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: rescaled betas with zero terminal SNR
|
||||
"""
|
||||
# Convert betas to alphas_bar_sqrt
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
|
||||
alphas = torch.cat([alphas_bar[0:1], alphas])
|
||||
betas = 1 - alphas
|
||||
|
||||
return betas
|
||||
|
||||
|
||||
class UFOGenScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
`UFOGenScheduler` implements multistep and onestep sampling for a UFOGen model, introduced in
|
||||
[UFOGen: You Forward Once Large Scale Text-to-Image Generation via Diffusion GANs](https://arxiv.org/abs/2311.09257)
|
||||
by Yanwu Xu, Yang Zhao, Zhisheng Xiao, and Tingbo Hou. UFOGen is a varianet of the denoising diffusion GAN (DDGAN)
|
||||
model designed for one-step sampling.
|
||||
|
||||
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
beta_start (`float`, defaults to 0.0001):
|
||||
The starting `beta` value of inference.
|
||||
beta_end (`float`, defaults to 0.02):
|
||||
The final `beta` value.
|
||||
beta_schedule (`str`, defaults to `"linear"`):
|
||||
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
||||
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
|
||||
clip_sample (`bool`, defaults to `True`):
|
||||
Clip the predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, defaults to 1.0):
|
||||
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
set_alpha_to_one (`bool`, defaults to `True`):
|
||||
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
|
||||
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
otherwise it uses the alpha value at step 0.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://imagen.research.google/video/paper.pdf) paper).
|
||||
thresholding (`bool`, defaults to `False`):
|
||||
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
|
||||
as Stable Diffusion.
|
||||
dynamic_thresholding_ratio (`float`, defaults to 0.995):
|
||||
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, defaults to 1.0):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
timestep_spacing (`str`, defaults to `"leading"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
An offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
|
||||
Diffusion.
|
||||
rescale_betas_zero_snr (`bool`, defaults to `False`):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
denoising_step_size (`int`, defaults to 250):
|
||||
The denoising step size parameter from the UFOGen paper. The number of steps used for training is roughly
|
||||
`math.ceil(num_train_timesteps / denoising_step_size)`.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
clip_sample_range: float = 1.0,
|
||||
sample_max_value: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
denoising_step_size: int = 250,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
elif beta_schedule == "sigmoid":
|
||||
# GeoDiff sigmoid schedule
|
||||
betas = torch.linspace(-6, 6, num_train_timesteps)
|
||||
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
# Rescale for zero SNR
|
||||
if rescale_betas_zero_snr:
|
||||
self.betas = rescale_zero_terminal_snr(self.betas)
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
||||
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
||||
# whether we use the final alpha of the "non-previous" one.
|
||||
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
||||
|
||||
# standard deviation of the initial noise distribution
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# setable values
|
||||
self.custom_timesteps = False
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
|
||||
|
||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
The current timestep in the diffusion chain.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
return sample
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
|
||||
`num_inference_steps` must be `None`.
|
||||
|
||||
"""
|
||||
if num_inference_steps is not None and timesteps is not None:
|
||||
raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
|
||||
|
||||
if timesteps is not None:
|
||||
for i in range(1, len(timesteps)):
|
||||
if timesteps[i] >= timesteps[i - 1]:
|
||||
raise ValueError("`custom_timesteps` must be in descending order.")
|
||||
|
||||
if timesteps[0] >= self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`timesteps` must start before `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps}."
|
||||
)
|
||||
|
||||
timesteps = np.array(timesteps, dtype=np.int64)
|
||||
self.custom_timesteps = True
|
||||
else:
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.custom_timesteps = False
|
||||
|
||||
# TODO: For now, handle special case when num_inference_steps == 1 separately
|
||||
if num_inference_steps == 1:
|
||||
# Set the timestep schedule to num_train_timesteps - 1 rather than 0
|
||||
# (that is, the one-step timestep schedule is always trailing rather than leading or linspace)
|
||||
timesteps = np.array([self.config.num_train_timesteps - 1], dtype=np.int64)
|
||||
else:
|
||||
# TODO: For now, retain the DDPM timestep spacing logic
|
||||
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
||||
if self.config.timestep_spacing == "linspace":
|
||||
timesteps = (
|
||||
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
|
||||
.round()[::-1]
|
||||
.copy()
|
||||
.astype(np.int64)
|
||||
)
|
||||
elif self.config.timestep_spacing == "leading":
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
||||
timesteps += self.config.steps_offset
|
||||
elif self.config.timestep_spacing == "trailing":
|
||||
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
|
||||
timesteps -= 1
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
||||
)
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
|
||||
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
|
||||
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
|
||||
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
|
||||
photorealism as well as better image-text alignment, especially when using very large guidance weights."
|
||||
|
||||
https://arxiv.org/abs/2205.11487
|
||||
"""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, *remaining_dims = sample.shape
|
||||
|
||||
if dtype not in (torch.float32, torch.float64):
|
||||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||||
|
||||
# Flatten sample for doing quantile calculation along each image
|
||||
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
||||
|
||||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||||
|
||||
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
|
||||
s = torch.clamp(
|
||||
s, min=1, max=self.config.sample_max_value
|
||||
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
|
||||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||||
|
||||
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
||||
sample = sample.to(dtype)
|
||||
|
||||
return sample
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UFOGenSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] or `tuple`.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_ddpm.UFOGenSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
# 0. Resolve timesteps
|
||||
t = timestep
|
||||
prev_t = self.previous_timestep(t)
|
||||
|
||||
# 1. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.final_alpha_cumprod
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
# beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
# current_alpha_t = alpha_prod_t / alpha_prod_t_prev
|
||||
# current_beta_t = 1 - current_alpha_t
|
||||
|
||||
# 2. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
elif self.config.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
|
||||
" `v_prediction` for UFOGenScheduler."
|
||||
)
|
||||
|
||||
# 3. Clip or threshold "predicted x_0"
|
||||
if self.config.thresholding:
|
||||
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||
elif self.config.clip_sample:
|
||||
pred_original_sample = pred_original_sample.clamp(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
# 4. Single-step or multi-step sampling
|
||||
# Noise is not used on the final timestep of the timestep schedule.
|
||||
# This also means that noise is not used for one-step sampling.
|
||||
if t != self.timesteps[-1]:
|
||||
# TODO: is this correct?
|
||||
# Sample prev sample x_{t - 1} ~ q(x_{t - 1} | x_0 = G(x_t, t))
|
||||
device = model_output.device
|
||||
noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
|
||||
sqrt_alpha_prod_t_prev = alpha_prod_t_prev**0.5
|
||||
sqrt_one_minus_alpha_prod_t_prev = (1 - alpha_prod_t_prev) ** 0.5
|
||||
pred_prev_sample = sqrt_alpha_prod_t_prev * pred_original_sample + sqrt_one_minus_alpha_prod_t_prev * noise
|
||||
else:
|
||||
# Simply return the pred_original_sample. If `prediction_type == "sample"`, this is equivalent to returning
|
||||
# the output of the GAN generator U-Net on the initial noisy latents x_T ~ N(0, I).
|
||||
pred_prev_sample = pred_original_sample
|
||||
|
||||
if not return_dict:
|
||||
return (pred_prev_sample,)
|
||||
|
||||
return UFOGenSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
|
||||
def get_velocity(
|
||||
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as sample
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
|
||||
timesteps = timesteps.to(sample.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(sample.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return velocity
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
|
||||
def previous_timestep(self, timestep):
|
||||
if self.custom_timesteps:
|
||||
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
|
||||
if index == self.timesteps.shape[0] - 1:
|
||||
prev_t = torch.tensor(-1)
|
||||
else:
|
||||
prev_t = self.timesteps[index + 1]
|
||||
else:
|
||||
num_inference_steps = (
|
||||
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
|
||||
)
|
||||
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
|
||||
|
||||
return prev_t
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
from diffusers import StableDiffusionControlNetPipeline
|
||||
from diffusers.models import ControlNetModel
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.utils import logging
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
from diffusers.models.unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
|
||||
@@ -907,10 +907,12 @@ def main():
|
||||
|
||||
if args.snr_gamma is not None:
|
||||
snr = jnp.array(compute_snr(timesteps))
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
|
||||
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma)
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
snr_loss_weights = snr_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
snr_loss_weights = snr_loss_weights / (snr + 1)
|
||||
|
||||
loss = loss * snr_loss_weights
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -753,7 +753,7 @@ def main(args):
|
||||
num_new_images = args.num_class_images - cur_class_images
|
||||
logger.info(f"Number of class images to sample: {num_new_images}.")
|
||||
|
||||
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
||||
sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
|
||||
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
|
||||
|
||||
sample_dataloader = accelerator.prepare(sample_dataloader)
|
||||
|
||||
@@ -781,12 +781,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -631,12 +631,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -664,12 +664,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -811,12 +811,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
# Consistency Training
|
||||
|
||||
`train_cm_ct_unconditional.py` trains a consistency model (CM) from scratch following the consistency training (CT) algorithm introduced in [Consistency Models](https://arxiv.org/abs/2303.01469) and refined in [Improved Techniques for Training Consistency Models](https://arxiv.org/abs/2310.14189). Both unconditional and class-conditional training are supported.
|
||||
|
||||
A usage example is as follows:
|
||||
|
||||
```bash
|
||||
accelerate launch examples/research_projects/consistency_training/train_cm_ct_unconditional.py \
|
||||
--dataset_name="cifar10" \
|
||||
--dataset_image_column_name="img" \
|
||||
--output_dir="/path/to/output/dir" \
|
||||
--mixed_precision=fp16 \
|
||||
--resolution=32 \
|
||||
--max_train_steps=1000 --max_train_samples=10000 \
|
||||
--dataloader_num_workers=8 \
|
||||
--noise_precond_type="cm" --input_precond_type="cm" \
|
||||
--train_batch_size=4 \
|
||||
--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--use_8bit_adam \
|
||||
--use_ema \
|
||||
--validation_steps=100 --eval_batch_size=4 \
|
||||
--checkpointing_steps=100 --checkpoints_total_limit=10 \
|
||||
--class_conditional --num_classes=10 \
|
||||
```
|
||||
@@ -0,0 +1,6 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProc
|
||||
from diffusers.models.autoencoders import AutoencoderKL
|
||||
from diffusers.models.lora import LoRACompatibleConv
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
from diffusers.models.unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
@@ -36,7 +36,7 @@ from diffusers.models.unet_2d_blocks import (
|
||||
UpBlock2D,
|
||||
Upsample2D,
|
||||
)
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
|
||||
|
||||
|
||||
@@ -1041,11 +1041,6 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# manually for max memory savings
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
@@ -740,6 +740,10 @@ def main(args):
|
||||
# Resize.
|
||||
combined_im = train_resize(combined_im)
|
||||
|
||||
# Flipping.
|
||||
if not args.no_hflip and random.random() < 0.5:
|
||||
combined_im = train_flip(combined_im)
|
||||
|
||||
# Cropping.
|
||||
if not args.random_crop:
|
||||
y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
|
||||
@@ -749,11 +753,6 @@ def main(args):
|
||||
y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
|
||||
combined_im = crop(combined_im, y1, x1, h, w)
|
||||
|
||||
# Flipping.
|
||||
if random.random() < 0.5:
|
||||
x1 = combined_im.shape[2] - x1
|
||||
combined_im = train_flip(combined_im)
|
||||
|
||||
crop_top_left = (y1, x1)
|
||||
crop_top_lefts.append(crop_top_left)
|
||||
combined_im = normalize(combined_im)
|
||||
|
||||
@@ -848,12 +848,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -943,12 +943,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -759,12 +759,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -35,7 +35,7 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig
|
||||
from peft import LoraConfig, set_peft_model_state_dict
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import crop
|
||||
@@ -51,8 +51,13 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import cast_training_params, compute_snr
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -629,14 +634,6 @@ def main(args):
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
@@ -693,18 +690,34 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
lora_state_dict, _ = LoraLoaderMixin.lora_state_dict(input_dir)
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
|
||||
LoraLoaderMixin.load_lora_into_text_encoder(
|
||||
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
|
||||
)
|
||||
_set_state_dict_into_text_encoder(
|
||||
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32. This is again needed since the base models
|
||||
# are in `weight_dtype`. More details:
|
||||
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet_]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one_, text_encoder_two_])
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
@@ -725,6 +738,13 @@ def main(args):
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
@@ -1062,12 +1082,13 @@ def main(args):
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -1087,12 +1087,13 @@ def main(args):
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -10,7 +10,7 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import VQModel
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.models.uvit_2d import UVit2DModel
|
||||
from diffusers.models.unets.uvit_2d import UVit2DModel
|
||||
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
|
||||
from diffusers.schedulers import AmusedScheduler
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from tqdm import tqdm
|
||||
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
|
||||
from diffusers.models.autoencoders.vae import Encoder
|
||||
from diffusers.models.embeddings import TimestepEmbedding
|
||||
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
|
||||
from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
|
||||
|
||||
|
||||
args = ArgumentParser()
|
||||
|
||||
@@ -6,7 +6,7 @@ from accelerate import load_checkpoint_and_dispatch
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
|
||||
from diffusers.models.prior_transformer import PriorTransformer
|
||||
from diffusers.models.transformers.prior_transformer import PriorTransformer
|
||||
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
|
||||
from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch
|
||||
from accelerate import load_checkpoint_and_dispatch
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.prior_transformer import PriorTransformer
|
||||
from diffusers.models.transformers.prior_transformer import PriorTransformer
|
||||
from diffusers.models.vq_model import VQModel
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import tempfile
|
||||
import torch
|
||||
from accelerate import load_checkpoint_and_dispatch
|
||||
|
||||
from diffusers.models.prior_transformer import PriorTransformer
|
||||
from diffusers.models.transformers.prior_transformer import PriorTransformer
|
||||
from diffusers.pipelines.shap_e import ShapERenderer
|
||||
|
||||
|
||||
|
||||
@@ -153,6 +153,7 @@ else:
|
||||
"LCMScheduler",
|
||||
"PNDMScheduler",
|
||||
"RePaintScheduler",
|
||||
"SASolverScheduler",
|
||||
"SchedulerMixin",
|
||||
"ScoreSdeVeScheduler",
|
||||
"UnCLIPScheduler",
|
||||
@@ -207,6 +208,7 @@ else:
|
||||
"AmusedInpaintPipeline",
|
||||
"AmusedPipeline",
|
||||
"AnimateDiffPipeline",
|
||||
"AnimateDiffVideoToVideoPipeline",
|
||||
"AudioLDM2Pipeline",
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
@@ -381,7 +383,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
|
||||
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
|
||||
_import_structure["schedulers"].extend(
|
||||
@@ -530,6 +532,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LCMScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SASolverScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
UnCLIPScheduler,
|
||||
@@ -567,6 +570,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AmusedInpaintPipeline,
|
||||
AmusedPipeline,
|
||||
AnimateDiffPipeline,
|
||||
AnimateDiffVideoToVideoPipeline,
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
@@ -709,7 +713,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .models.controlnet_flax import FlaxControlNetModel
|
||||
from .models.modeling_flax_utils import FlaxModelMixin
|
||||
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.vae_flax import FlaxAutoencoderKL
|
||||
from .pipelines import FlaxDiffusionPipeline
|
||||
from .schedulers import (
|
||||
|
||||
@@ -16,7 +16,7 @@ import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from ...models.unet_1d import UNet1DModel
|
||||
from ...models.unets.unet_1d import UNet1DModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...utils.dummy_pt_objects import DDPMScheduler
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
|
||||
@@ -54,12 +54,13 @@ if is_transformers_available():
|
||||
_import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"]
|
||||
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
|
||||
|
||||
_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
|
||||
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
|
||||
_import_structure["utils"] = ["AttnProcsLayers"]
|
||||
|
||||
if is_transformers_available():
|
||||
_import_structure["single_file"].extend(["FromSingleFileMixin"])
|
||||
_import_structure["single_file"] = ["FromSingleFileMixin"]
|
||||
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
|
||||
@@ -69,7 +70,8 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
|
||||
from .autoencoder import FromOriginalVAEMixin
|
||||
from .controlnet import FromOriginalControlNetMixin
|
||||
from .unet import UNet2DConditionLoadersMixin
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from .single_file_utils import (
|
||||
create_diffusers_vae_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
"""
|
||||
Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
|
||||
a VAE from SDXL or a Stable Diffusion v2 model or higher.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
class_name = cls.__name__
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
|
||||
vae = component["vae"]
|
||||
if torch_dtype is not None:
|
||||
vae = vae.to(torch_dtype)
|
||||
|
||||
return vae
|
||||
@@ -0,0 +1,127 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from .single_file_utils import (
|
||||
create_diffusers_controlnet_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class FromOriginalControlNetMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
class_name = cls.__name__
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
upcast_attention = kwargs.pop("upcast_attention", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
|
||||
component = create_diffusers_controlnet_model_from_ldm(
|
||||
class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
|
||||
)
|
||||
controlnet = component["controlnet"]
|
||||
if torch_dtype is not None:
|
||||
controlnet = controlnet.to(torch_dtype)
|
||||
|
||||
return controlnet
|
||||
@@ -11,39 +11,132 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..utils import deprecate, is_accelerate_available, is_transformers_available, logging
|
||||
from ..utils import is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
create_diffusers_unet_model_from_ldm,
|
||||
create_diffusers_vae_model_from_ldm,
|
||||
create_scheduler_from_ldm,
|
||||
create_text_encoders_and_tokenizers_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
infer_model_type,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
pass
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Pipelines that support the SDXL Refiner checkpoint
|
||||
REFINER_PIPELINES = [
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
def build_sub_model_components(
|
||||
pipeline_components,
|
||||
pipeline_class_name,
|
||||
component_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
local_files_only=False,
|
||||
load_safety_checker=False,
|
||||
model_type=None,
|
||||
image_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
if component_name in pipeline_components:
|
||||
return {}
|
||||
|
||||
if component_name == "unet":
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
unet_components = create_diffusers_unet_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size
|
||||
)
|
||||
return unet_components
|
||||
|
||||
if component_name == "vae":
|
||||
vae_components = create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size
|
||||
)
|
||||
return vae_components
|
||||
|
||||
if component_name == "scheduler":
|
||||
scheduler_type = kwargs.get("scheduler_type", "ddim")
|
||||
prediction_type = kwargs.get("prediction_type", None)
|
||||
|
||||
scheduler_components = create_scheduler_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
scheduler_type=scheduler_type,
|
||||
prediction_type=prediction_type,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
return scheduler_components
|
||||
|
||||
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
|
||||
text_encoder_components = create_text_encoders_and_tokenizers_from_ldm(
|
||||
original_config,
|
||||
checkpoint,
|
||||
model_type=model_type,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return text_encoder_components
|
||||
|
||||
if component_name == "safety_checker":
|
||||
if load_safety_checker:
|
||||
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
else:
|
||||
safety_checker = None
|
||||
return {"safety_checker": safety_checker}
|
||||
|
||||
if component_name == "feature_extractor":
|
||||
if load_safety_checker:
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
else:
|
||||
feature_extractor = None
|
||||
return {"feature_extractor": feature_extractor}
|
||||
|
||||
return
|
||||
|
||||
|
||||
def set_additional_components(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
model_type=None,
|
||||
):
|
||||
components = {}
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
{
|
||||
"requires_aesthetics_score": is_refiner,
|
||||
"force_zeros_for_empty_prompt": False if is_refiner else True,
|
||||
}
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_ckpt(cls, *args, **kwargs):
|
||||
deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
|
||||
deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
|
||||
return cls.from_single_file(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
@@ -58,8 +151,7 @@ class FromSingleFileMixin:
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
@@ -85,42 +177,6 @@ class FromSingleFileMixin:
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
extract_ema (`bool`, *optional*, defaults to `False`):
|
||||
Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
|
||||
higher quality images for inference. Non-EMA weights are usually better for continuing finetuning.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
prediction_type (`str`, *optional*):
|
||||
The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
|
||||
the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
|
||||
num_in_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of input channels. If `None`, it is automatically inferred.
|
||||
scheduler_type (`str`, *optional*, defaults to `"pndm"`):
|
||||
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
||||
"ddim"]`.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
||||
Whether to load the safety checker or not.
|
||||
text_encoder ([`~transformers.CLIPTextModel`], *optional*, defaults to `None`):
|
||||
An instance of `CLIPTextModel` to use, specifically the
|
||||
[clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this
|
||||
parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed.
|
||||
vae (`AutoencoderKL`, *optional*, defaults to `None`):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
|
||||
this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
|
||||
tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
|
||||
An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
|
||||
of `CLIPTokenizer` by itself if needed.
|
||||
original_config_file (`str`):
|
||||
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be
|
||||
automatically inferred by looking for a key that only exists in SD2.0 models.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
@@ -143,484 +199,80 @@ class FromSingleFileMixin:
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config_files = kwargs.pop("config_files", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scheduler_type = kwargs.pop("scheduler_type", "pndm")
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
||||
prediction_type = kwargs.pop("prediction_type", None)
|
||||
text_encoder = kwargs.pop("text_encoder", None)
|
||||
text_encoder_2 = kwargs.pop("text_encoder_2", None)
|
||||
vae = kwargs.pop("vae", None)
|
||||
controlnet = kwargs.pop("controlnet", None)
|
||||
adapter = kwargs.pop("adapter", None)
|
||||
tokenizer = kwargs.pop("tokenizer", None)
|
||||
tokenizer_2 = kwargs.pop("tokenizer_2", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
class_name = cls.__name__
|
||||
|
||||
pipeline_name = cls.__name__
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# TODO: For now we only support stable diffusion
|
||||
stable_unclip = None
|
||||
model_type = None
|
||||
|
||||
if pipeline_name in [
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
]:
|
||||
from ..models.controlnet import ControlNetModel
|
||||
from ..pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
|
||||
# list/tuple or a single instance of ControlNetModel or MultiControlNetModel
|
||||
if not (
|
||||
isinstance(controlnet, (ControlNetModel, MultiControlNetModel))
|
||||
or isinstance(controlnet, (list, tuple))
|
||||
and isinstance(controlnet[0], ControlNetModel)
|
||||
):
|
||||
raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
|
||||
elif "StableDiffusion" in pipeline_name:
|
||||
# Model type will be inferred from the checkpoint.
|
||||
pass
|
||||
elif pipeline_name == "StableUnCLIPPipeline":
|
||||
model_type = "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "txt2img"
|
||||
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
|
||||
model_type = "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "img2img"
|
||||
elif pipeline_name == "PaintByExamplePipeline":
|
||||
model_type = "PaintByExample"
|
||||
elif pipeline_name == "LDMTextToImagePipeline":
|
||||
model_type = "LDMTextToImage"
|
||||
else:
|
||||
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
||||
|
||||
# remove huggingface url
|
||||
has_valid_url_prefix = False
|
||||
valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
|
||||
for prefix in valid_url_prefixes:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
has_valid_url_prefix = True
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
if not has_valid_url_prefix:
|
||||
raise ValueError(
|
||||
f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
|
||||
)
|
||||
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
pipeline_class=cls,
|
||||
model_type=model_type,
|
||||
stable_unclip=stable_unclip,
|
||||
controlnet=controlnet,
|
||||
adapter=adapter,
|
||||
from_safetensors=from_safetensors,
|
||||
extract_ema=extract_ema,
|
||||
image_size=image_size,
|
||||
scheduler_type=scheduler_type,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
load_safety_checker=load_safety_checker,
|
||||
prediction_type=prediction_type,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
config_files=config_files,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config=None,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
model_type = kwargs.pop("model_type", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
load_safety_checker = (kwargs.pop("load_safety_checker", False)) or (
|
||||
passed_class_obj.get("safety_checker", None) is not None
|
||||
)
|
||||
|
||||
init_kwargs = {}
|
||||
for name in expected_modules:
|
||||
if name in passed_class_obj:
|
||||
init_kwargs[name] = passed_class_obj[name]
|
||||
else:
|
||||
components = build_sub_model_components(
|
||||
init_kwargs,
|
||||
class_name,
|
||||
name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
model_type=model_type,
|
||||
image_size=image_size,
|
||||
load_safety_checker=load_safety_checker,
|
||||
local_files_only=local_files_only,
|
||||
**kwargs,
|
||||
)
|
||||
if not components:
|
||||
continue
|
||||
init_kwargs.update(components)
|
||||
|
||||
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
|
||||
if additional_components:
|
||||
init_kwargs.update(additional_components)
|
||||
|
||||
init_kwargs.update(passed_pipe_kwargs)
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into an [`AutoencoderKL`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
|
||||
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
|
||||
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
|
||||
a VAE from SDXL or a Stable Diffusion v2 model or higher.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
from ..models import AutoencoderKL
|
||||
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_vae_diffusers_config,
|
||||
)
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scaling_factor = kwargs.pop("scaling_factor", None)
|
||||
kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if from_safetensors:
|
||||
from safetensors import safe_open
|
||||
|
||||
checkpoint = {}
|
||||
with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
checkpoint[key] = f.get_tensor(key)
|
||||
else:
|
||||
checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = yaml.safe_load(config_file)
|
||||
|
||||
# default to sd-v1-5
|
||||
image_size = image_size or 512
|
||||
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
if scaling_factor is None:
|
||||
if (
|
||||
"model" in original_config
|
||||
and "params" in original_config["model"]
|
||||
and "scale_factor" in original_config["model"]["params"]
|
||||
):
|
||||
vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
|
||||
else:
|
||||
vae_scaling_factor = 0.18215 # default SD scaling factor
|
||||
|
||||
vae_config["scaling_factor"] = vae_scaling_factor
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
|
||||
else:
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
vae.to(dtype=torch_dtype)
|
||||
|
||||
return vae
|
||||
|
||||
|
||||
class FromOriginalControlnetMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
use_linear_projection = kwargs.pop("use_linear_projection", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
image_size = image_size or 512
|
||||
|
||||
controlnet = download_controlnet_from_original_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
original_config_file=config_file,
|
||||
image_size=image_size,
|
||||
extract_ema=extract_ema,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
from_safetensors=from_safetensors,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
controlnet.to(dtype=torch_dtype)
|
||||
|
||||
return controlnet
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -453,3 +453,91 @@ class TextualInversionLoaderMixin:
|
||||
self.enable_sequential_cpu_offload()
|
||||
|
||||
# / Unsafe Code >
|
||||
|
||||
def unload_textual_inversion(
|
||||
self,
|
||||
tokens: Optional[Union[str, List[str]]] = None,
|
||||
):
|
||||
r"""
|
||||
Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]
|
||||
|
||||
Example:
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
|
||||
# Example 1
|
||||
pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
|
||||
pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
|
||||
|
||||
# Remove all token embeddings
|
||||
pipeline.unload_textual_inversion()
|
||||
|
||||
# Example 2
|
||||
pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
|
||||
pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
|
||||
|
||||
# Remove just one token
|
||||
pipeline.unload_textual_inversion("<moe-bius>")
|
||||
```
|
||||
"""
|
||||
|
||||
tokenizer = getattr(self, "tokenizer", None)
|
||||
text_encoder = getattr(self, "text_encoder", None)
|
||||
|
||||
# Get textual inversion tokens and ids
|
||||
token_ids = []
|
||||
last_special_token_id = None
|
||||
|
||||
if tokens:
|
||||
if isinstance(tokens, str):
|
||||
tokens = [tokens]
|
||||
for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
|
||||
if not added_token.special:
|
||||
if added_token.content in tokens:
|
||||
token_ids.append(added_token_id)
|
||||
else:
|
||||
last_special_token_id = added_token_id
|
||||
if len(token_ids) == 0:
|
||||
raise ValueError("No tokens to remove found")
|
||||
else:
|
||||
tokens = []
|
||||
for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
|
||||
if not added_token.special:
|
||||
token_ids.append(added_token_id)
|
||||
tokens.append(added_token.content)
|
||||
else:
|
||||
last_special_token_id = added_token_id
|
||||
|
||||
# Delete from tokenizer
|
||||
for token_id, token_to_remove in zip(token_ids, tokens):
|
||||
del tokenizer._added_tokens_decoder[token_id]
|
||||
del tokenizer._added_tokens_encoder[token_to_remove]
|
||||
|
||||
# Make all token ids sequential in tokenizer
|
||||
key_id = 1
|
||||
for token_id in tokenizer.added_tokens_decoder:
|
||||
if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
|
||||
token = tokenizer._added_tokens_decoder[token_id]
|
||||
tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
|
||||
del tokenizer._added_tokens_decoder[token_id]
|
||||
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
|
||||
key_id += 1
|
||||
tokenizer._update_trie()
|
||||
|
||||
# Delete from text encoder
|
||||
text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
|
||||
temp_text_embedding_weights = text_encoder.get_input_embeddings().weight
|
||||
text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1]
|
||||
to_append = []
|
||||
for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0]):
|
||||
if i not in token_ids:
|
||||
to_append.append(temp_text_embedding_weights[i].unsqueeze(0))
|
||||
if len(to_append) > 0:
|
||||
to_append = torch.cat(to_append, dim=0)
|
||||
text_embedding_weights = torch.cat([text_embedding_weights, to_append], dim=0)
|
||||
text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim)
|
||||
text_embeddings_filtered.weight.data = text_embedding_weights
|
||||
text_encoder.set_input_embeddings(text_embeddings_filtered)
|
||||
|
||||
@@ -16,6 +16,7 @@ import os
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
@@ -503,8 +504,9 @@ class UNet2DConditionLoadersMixin:
|
||||
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
|
||||
|
||||
# Save the model
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
||||
save_path = Path(save_directory, weight_name).as_posix()
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
self.lora_scale = lora_scale
|
||||
|
||||
@@ -35,23 +35,23 @@ if is_torch_available():
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
_import_structure["prior_transformer"] = ["PriorTransformer"]
|
||||
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
|
||||
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
||||
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["uvit_2d"] = ["UVit2DModel"]
|
||||
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
_import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
|
||||
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
||||
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
|
||||
_import_structure["vq_model"] = ["VQModel"]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
|
||||
|
||||
@@ -66,26 +66,31 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ConsistencyDecoderVAE,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_condition import UNet3DConditionModel
|
||||
from .unet_kandinsky3 import Kandinsky3UNet
|
||||
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
||||
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
||||
from .uvit_2d import UVit2DModel
|
||||
from .transformers import (
|
||||
DualTransformer2DModel,
|
||||
PriorTransformer,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
TransformerTemporalModel,
|
||||
)
|
||||
from .unets import (
|
||||
Kandinsky3UNet,
|
||||
MotionAdapter,
|
||||
UNet1DModel,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
UNet3DConditionModel,
|
||||
UNetMotionModel,
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
)
|
||||
from .vq_model import VQModel
|
||||
|
||||
if is_flax_available():
|
||||
from .controlnet_flax import FlaxControlNetModel
|
||||
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .unets import FlaxUNet2DConditionModel
|
||||
from .vae_flax import FlaxAutoencoderKL
|
||||
|
||||
else:
|
||||
|
||||
@@ -157,7 +157,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
self.use_slicing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -181,7 +181,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -216,7 +216,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -448,7 +448,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
@@ -472,7 +472,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
|
||||
@@ -17,13 +17,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalVAEMixin
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
|
||||
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@@ -162,7 +161,7 @@ class TemporalDecoder(nn.Module):
|
||||
return sample
|
||||
|
||||
|
||||
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
@@ -242,7 +241,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -266,7 +265,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -31,7 +31,7 @@ from ..attention_processor import (
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unet_2d import UNet2DModel
|
||||
from ..unets.unet_2d import UNet2DModel
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
self.use_slicing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -246,7 +246,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...utils import BaseOutput, is_torch_version
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import SpatialNorm
|
||||
from ..unet_2d_blocks import (
|
||||
from ..unets.unet_2d_blocks import (
|
||||
AutoencoderTinyBlock,
|
||||
UNetMidBlock2D,
|
||||
get_down_block,
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalControlnetMixin
|
||||
from ..loaders import FromOriginalControlNetMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -30,8 +30,14 @@ from .attention_processor import (
|
||||
)
|
||||
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
get_down_block,
|
||||
)
|
||||
from .unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -102,7 +108,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
return embedding
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
@@ -509,7 +515,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
return controlnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -533,7 +539,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -568,7 +574,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -584,7 +590,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from .modeling_flax_utils import FlaxModelMixin
|
||||
from .unet_2d_blocks_flax import (
|
||||
from .unets.unet_2d_blocks_flax import (
|
||||
FlaxCrossAttnDownBlock2D,
|
||||
FlaxDownBlock2D,
|
||||
FlaxUNetMidBlock2DCrossAttn,
|
||||
@@ -329,14 +329,14 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
|
||||
conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
plain tuple.
|
||||
train (`bool`, *optional*, defaults to `False`):
|
||||
Use deterministic functions and disable dropout when not training.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
channel_order = self.controlnet_conditioning_channel_order
|
||||
|
||||
@@ -11,145 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
||||
from ..utils import deprecate
|
||||
from .transformers.dual_transformer_2d import DualTransformer2DModel
|
||||
|
||||
|
||||
class DualTransformer2DModel(nn.Module):
|
||||
"""
|
||||
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
||||
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
||||
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
||||
up to but not more than steps than `num_embeds_ada_norm`.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.transformers = nn.ModuleList(
|
||||
[
|
||||
Transformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout,
|
||||
norm_num_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_bias=attention_bias,
|
||||
sample_size=sample_size,
|
||||
num_vector_embeds=num_vector_embeds,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
)
|
||||
|
||||
# Variables that can be set by a pipeline:
|
||||
|
||||
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
||||
self.mix_ratio = 0.5
|
||||
|
||||
# The shape of `encoder_hidden_states` is expected to be
|
||||
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
||||
self.condition_lengths = [77, 257]
|
||||
|
||||
# Which transformer to use to encode which condition.
|
||||
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
||||
self.transformer_index_for_condition = [1, 0]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=None,
|
||||
attention_mask=None,
|
||||
cross_attention_kwargs=None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Optional attention mask to be applied in Attention.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
input_states = hidden_states
|
||||
|
||||
encoded_states = []
|
||||
tokens_start = 0
|
||||
# attention_mask is not used yet
|
||||
for i in range(2):
|
||||
# for each of the two transformers, pass the corresponding condition tokens
|
||||
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
||||
transformer_index = self.transformer_index_for_condition[i]
|
||||
encoded_state = self.transformers[transformer_index](
|
||||
input_states,
|
||||
encoder_hidden_states=condition_state,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
encoded_states.append(encoded_state - input_states)
|
||||
tokens_start += self.condition_lengths[i]
|
||||
|
||||
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
||||
output_states = output_states + input_states
|
||||
|
||||
if not return_dict:
|
||||
return (output_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output_states)
|
||||
class DualTransformer2DModel(DualTransformer2DModel):
|
||||
deprecation_message = "Importing `DualTransformer2DModel` from `diffusers.models.dual_transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel`, instead."
|
||||
deprecate("DualTransformer2DModel", "0.29", deprecation_message)
|
||||
|
||||
@@ -32,6 +32,7 @@ from .. import __version__
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
SAFETENSORS_FILE_EXTENSION,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
_add_variant,
|
||||
@@ -41,7 +42,7 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.hub_utils import PushToHubMixin
|
||||
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -102,10 +103,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
try:
|
||||
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
else:
|
||||
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
||||
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
else:
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
except Exception as e:
|
||||
try:
|
||||
with open(checkpoint_file) as f:
|
||||
@@ -375,6 +377,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
||||
|
||||
if push_to_hub:
|
||||
# Create a new empty model card and eventually tag it
|
||||
model_card = load_or_create_model_card(repo_id, token=token)
|
||||
model_card = populate_model_card(model_card)
|
||||
model_card.save(os.path.join(save_directory, "README.md"))
|
||||
|
||||
self._upload_folder(
|
||||
save_directory,
|
||||
repo_id,
|
||||
|
||||
@@ -1,380 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
from ..utils import deprecate
|
||||
from .transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput
|
||||
from .attention import BasicTransformerBlock
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
class PriorTransformerOutput(PriorTransformerOutput):
|
||||
deprecation_message = "Importing `PriorTransformerOutput` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformerOutput`, instead."
|
||||
deprecate("PriorTransformerOutput", "0.29", deprecation_message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriorTransformerOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`PriorTransformer`].
|
||||
|
||||
Args:
|
||||
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
||||
"""
|
||||
|
||||
predicted_image_embedding: torch.FloatTensor
|
||||
|
||||
|
||||
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
||||
"""
|
||||
A Prior Transformer model.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
||||
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
|
||||
num_embeddings (`int`, *optional*, defaults to 77):
|
||||
The number of embeddings of the model input `hidden_states`
|
||||
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
||||
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
||||
additional_embeddings`.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
|
||||
The activation function to use to create timestep embeddings.
|
||||
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
|
||||
passing to Transformer blocks. Set it to `None` if normalization is not needed.
|
||||
embedding_proj_norm_type (`str`, *optional*, defaults to None):
|
||||
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
|
||||
needed.
|
||||
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
|
||||
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
|
||||
`encoder_hidden_states` is `None`.
|
||||
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
|
||||
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
|
||||
product between the text embedding and image embedding as proposed in the unclip paper
|
||||
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
|
||||
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
|
||||
If None, will be set to `num_attention_heads * attention_head_dim`
|
||||
embedding_proj_dim (`int`, *optional*, default to None):
|
||||
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
|
||||
clip_embed_dim (`int`, *optional*, default to None):
|
||||
The dimension of the output. If None, will be set to `embedding_dim`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
num_layers: int = 20,
|
||||
embedding_dim: int = 768,
|
||||
num_embeddings=77,
|
||||
additional_embeddings=4,
|
||||
dropout: float = 0.0,
|
||||
time_embed_act_fn: str = "silu",
|
||||
norm_in_type: Optional[str] = None, # layer
|
||||
embedding_proj_norm_type: Optional[str] = None, # layer
|
||||
encoder_hid_proj_type: Optional[str] = "linear", # linear
|
||||
added_emb_type: Optional[str] = "prd", # prd
|
||||
time_embed_dim: Optional[int] = None,
|
||||
embedding_proj_dim: Optional[int] = None,
|
||||
clip_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.additional_embeddings = additional_embeddings
|
||||
|
||||
time_embed_dim = time_embed_dim or inner_dim
|
||||
embedding_proj_dim = embedding_proj_dim or embedding_dim
|
||||
clip_embed_dim = clip_embed_dim or embedding_dim
|
||||
|
||||
self.time_proj = Timesteps(inner_dim, True, 0)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
if embedding_proj_norm_type is None:
|
||||
self.embedding_proj_norm = None
|
||||
elif embedding_proj_norm_type == "layer":
|
||||
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
|
||||
|
||||
if encoder_hid_proj_type is None:
|
||||
self.encoder_hidden_states_proj = None
|
||||
elif encoder_hid_proj_type == "linear":
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
|
||||
|
||||
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
||||
|
||||
if added_emb_type == "prd":
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
elif added_emb_type is None:
|
||||
self.prd_embedding = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn="gelu",
|
||||
attention_bias=True,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if norm_in_type == "layer":
|
||||
self.norm_in = nn.LayerNorm(inner_dim)
|
||||
elif norm_in_type is None:
|
||||
self.norm_in = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
|
||||
|
||||
causal_attention_mask = torch.full(
|
||||
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
||||
)
|
||||
causal_attention_mask.triu_(1)
|
||||
causal_attention_mask = causal_attention_mask[None, ...]
|
||||
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
||||
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
proj_embedding: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`PriorTransformer`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The currently predicted image embeddings.
|
||||
timestep (`torch.LongTensor`):
|
||||
Current denoising step.
|
||||
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
Projected embedding vector the denoising process is conditioned on.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
||||
Hidden states of the text embeddings the denoising process is conditioned on.
|
||||
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
|
||||
Text mask for the text embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(hidden_states.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
timesteps_projected = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might be fp16, so we need to cast here.
|
||||
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
||||
time_embeddings = self.time_embedding(timesteps_projected)
|
||||
|
||||
if self.embedding_proj_norm is not None:
|
||||
proj_embedding = self.embedding_proj_norm(proj_embedding)
|
||||
|
||||
proj_embeddings = self.embedding_proj(proj_embedding)
|
||||
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
|
||||
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
||||
|
||||
additional_embeds = []
|
||||
additional_embeddings_len = 0
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
additional_embeds.append(encoder_hidden_states)
|
||||
additional_embeddings_len += encoder_hidden_states.shape[1]
|
||||
|
||||
if len(proj_embeddings.shape) == 2:
|
||||
proj_embeddings = proj_embeddings[:, None, :]
|
||||
|
||||
if len(hidden_states.shape) == 2:
|
||||
hidden_states = hidden_states[:, None, :]
|
||||
|
||||
additional_embeds = additional_embeds + [
|
||||
proj_embeddings,
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states,
|
||||
]
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
additional_embeds.append(prd_embedding)
|
||||
|
||||
hidden_states = torch.cat(
|
||||
additional_embeds,
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
|
||||
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
|
||||
if positional_embeddings.shape[1] < hidden_states.shape[1]:
|
||||
positional_embeddings = F.pad(
|
||||
positional_embeddings,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
additional_embeddings_len,
|
||||
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
|
||||
),
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + positional_embeddings
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
|
||||
if self.norm_in is not None:
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
hidden_states = hidden_states[:, -1]
|
||||
else:
|
||||
hidden_states = hidden_states[:, additional_embeddings_len:]
|
||||
|
||||
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (predicted_image_embedding,)
|
||||
|
||||
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
||||
|
||||
def post_process_latents(self, prior_latents):
|
||||
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
||||
return prior_latents
|
||||
class PriorTransformer(PriorTransformer):
|
||||
deprecation_message = "Importing `PriorTransformer` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformer`, instead."
|
||||
deprecate("PriorTransformer", "0.29", deprecation_message)
|
||||
|
||||
@@ -11,428 +11,60 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from ..utils import deprecate
|
||||
from .transformers.t5_film_transformer import (
|
||||
DecoderLayer,
|
||||
NewGELUActivation,
|
||||
T5DenseGatedActDense,
|
||||
T5FilmDecoder,
|
||||
T5FiLMLayer,
|
||||
T5LayerCrossAttention,
|
||||
T5LayerFFCond,
|
||||
T5LayerNorm,
|
||||
T5LayerSelfAttentionCond,
|
||||
)
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .attention_processor import Attention
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .modeling_utils import ModelMixin
|
||||
class T5FilmDecoder(T5FilmDecoder):
|
||||
deprecation_message = "Importing `T5FilmDecoder` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FilmDecoder`, instead."
|
||||
deprecate("T5FilmDecoder", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class T5FilmDecoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
T5 style decoder with FiLM conditioning.
|
||||
class DecoderLayer(DecoderLayer):
|
||||
deprecation_message = "Importing `DecoderLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import DecoderLayer`, instead."
|
||||
deprecate("DecoderLayer", "0.29", deprecation_message)
|
||||
|
||||
Args:
|
||||
input_dims (`int`, *optional*, defaults to `128`):
|
||||
The number of input dimensions.
|
||||
targets_length (`int`, *optional*, defaults to `256`):
|
||||
The length of the targets.
|
||||
d_model (`int`, *optional*, defaults to `768`):
|
||||
Size of the input hidden states.
|
||||
num_layers (`int`, *optional*, defaults to `12`):
|
||||
The number of `DecoderLayer`'s to use.
|
||||
num_heads (`int`, *optional*, defaults to `12`):
|
||||
The number of attention heads to use.
|
||||
d_kv (`int`, *optional*, defaults to `64`):
|
||||
Size of the key-value projection vectors.
|
||||
d_ff (`int`, *optional*, defaults to `2048`):
|
||||
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
|
||||
dropout_rate (`float`, *optional*, defaults to `0.1`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int = 128,
|
||||
targets_length: int = 256,
|
||||
max_decoder_noise_time: float = 2000.0,
|
||||
d_model: int = 768,
|
||||
num_layers: int = 12,
|
||||
num_heads: int = 12,
|
||||
d_kv: int = 64,
|
||||
d_ff: int = 2048,
|
||||
dropout_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
class T5LayerSelfAttentionCond(T5LayerSelfAttentionCond):
|
||||
deprecation_message = "Importing `T5LayerSelfAttentionCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerSelfAttentionCond`, instead."
|
||||
deprecate("T5LayerSelfAttentionCond", "0.29", deprecation_message)
|
||||
|
||||
self.conditioning_emb = nn.Sequential(
|
||||
nn.Linear(d_model, d_model * 4, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(d_model * 4, d_model * 4, bias=False),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.position_encoding = nn.Embedding(targets_length, d_model)
|
||||
self.position_encoding.weight.requires_grad = False
|
||||
class T5LayerCrossAttention(T5LayerCrossAttention):
|
||||
deprecation_message = "Importing `T5LayerCrossAttention` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerCrossAttention`, instead."
|
||||
deprecate("T5LayerCrossAttention", "0.29", deprecation_message)
|
||||
|
||||
self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
class T5LayerFFCond(T5LayerFFCond):
|
||||
deprecation_message = "Importing `T5LayerFFCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerFFCond`, instead."
|
||||
deprecate("T5LayerFFCond", "0.29", deprecation_message)
|
||||
|
||||
self.decoders = nn.ModuleList()
|
||||
for lyr_num in range(num_layers):
|
||||
# FiLM conditional T5 decoder
|
||||
lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
|
||||
self.decoders.append(lyr)
|
||||
|
||||
self.decoder_norm = T5LayerNorm(d_model)
|
||||
class T5DenseGatedActDense(T5DenseGatedActDense):
|
||||
deprecation_message = "Importing `T5DenseGatedActDense` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5DenseGatedActDense`, instead."
|
||||
deprecate("T5DenseGatedActDense", "0.29", deprecation_message)
|
||||
|
||||
self.post_dropout = nn.Dropout(p=dropout_rate)
|
||||
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
|
||||
|
||||
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
|
||||
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
|
||||
return mask.unsqueeze(-3)
|
||||
class T5LayerNorm(T5LayerNorm):
|
||||
deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerNorm`, instead."
|
||||
deprecate("T5LayerNorm", "0.29", deprecation_message)
|
||||
|
||||
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
|
||||
batch, _, _ = decoder_input_tokens.shape
|
||||
assert decoder_noise_time.shape == (batch,)
|
||||
|
||||
# decoder_noise_time is in [0, 1), so rescale to expected timing range.
|
||||
time_steps = get_timestep_embedding(
|
||||
decoder_noise_time * self.config.max_decoder_noise_time,
|
||||
embedding_dim=self.config.d_model,
|
||||
max_period=self.config.max_decoder_noise_time,
|
||||
).to(dtype=self.dtype)
|
||||
class NewGELUActivation(NewGELUActivation):
|
||||
deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import NewGELUActivation`, instead."
|
||||
deprecate("NewGELUActivation", "0.29", deprecation_message)
|
||||
|
||||
conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
|
||||
|
||||
assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
|
||||
|
||||
seq_length = decoder_input_tokens.shape[1]
|
||||
|
||||
# If we want to use relative positions for audio context, we can just offset
|
||||
# this sequence by the length of encodings_and_masks.
|
||||
decoder_positions = torch.broadcast_to(
|
||||
torch.arange(seq_length, device=decoder_input_tokens.device),
|
||||
(batch, seq_length),
|
||||
)
|
||||
|
||||
position_encodings = self.position_encoding(decoder_positions)
|
||||
|
||||
inputs = self.continuous_inputs_projection(decoder_input_tokens)
|
||||
inputs += position_encodings
|
||||
y = self.dropout(inputs)
|
||||
|
||||
# decoder: No padding present.
|
||||
decoder_mask = torch.ones(
|
||||
decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
|
||||
)
|
||||
|
||||
# Translate encoding masks to encoder-decoder masks.
|
||||
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
|
||||
|
||||
# cross attend style: concat encodings
|
||||
encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
|
||||
encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
|
||||
|
||||
for lyr in self.decoders:
|
||||
y = lyr(
|
||||
y,
|
||||
conditioning_emb=conditioning_emb,
|
||||
encoder_hidden_states=encoded,
|
||||
encoder_attention_mask=encoder_decoder_mask,
|
||||
)[0]
|
||||
|
||||
y = self.decoder_norm(y)
|
||||
y = self.post_dropout(y)
|
||||
|
||||
spec_out = self.spec_out(y)
|
||||
return spec_out
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
r"""
|
||||
T5 decoder layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
self.layer = nn.ModuleList()
|
||||
|
||||
# cond self attention: layer 0
|
||||
self.layer.append(
|
||||
T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
|
||||
)
|
||||
|
||||
# cross attention: layer 1
|
||||
self.layer.append(
|
||||
T5LayerCrossAttention(
|
||||
d_model=d_model,
|
||||
d_kv=d_kv,
|
||||
num_heads=num_heads,
|
||||
dropout_rate=dropout_rate,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
)
|
||||
)
|
||||
|
||||
# Film Cond MLP + dropout: last layer
|
||||
self.layer.append(
|
||||
T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
conditioning_emb: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_decoder_position_bias=None,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
hidden_states = self.layer[0](
|
||||
hidden_states,
|
||||
conditioning_emb=conditioning_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
|
||||
encoder_hidden_states.dtype
|
||||
)
|
||||
|
||||
hidden_states = self.layer[1](
|
||||
hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
|
||||
# Apply Film Conditional Feed Forward layer
|
||||
hidden_states = self.layer[-1](hidden_states, conditioning_emb)
|
||||
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
class T5LayerSelfAttentionCond(nn.Module):
|
||||
r"""
|
||||
T5 style self-attention layer with conditioning.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.layer_norm = T5LayerNorm(d_model)
|
||||
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
||||
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
conditioning_emb: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
# pre_self_attention_layer_norm
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
if conditioning_emb is not None:
|
||||
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
|
||||
|
||||
# Self-attention block
|
||||
attention_output = self.attention(normed_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.dropout(attention_output)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5LayerCrossAttention(nn.Module):
|
||||
r"""
|
||||
T5 style cross-attention layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
|
||||
super().__init__()
|
||||
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
||||
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
key_value_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
attention_output = self.attention(
|
||||
normed_hidden_states,
|
||||
encoder_hidden_states=key_value_states,
|
||||
attention_mask=attention_mask.squeeze(1),
|
||||
)
|
||||
layer_output = hidden_states + self.dropout(attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class T5LayerFFCond(nn.Module):
|
||||
r"""
|
||||
T5 style feed-forward conditional layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
|
||||
super().__init__()
|
||||
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
|
||||
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
||||
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
|
||||
) -> torch.FloatTensor:
|
||||
forwarded_states = self.layer_norm(hidden_states)
|
||||
if conditioning_emb is not None:
|
||||
forwarded_states = self.film(forwarded_states, conditioning_emb)
|
||||
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
hidden_states = hidden_states + self.dropout(forwarded_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5DenseGatedActDense(nn.Module):
|
||||
r"""
|
||||
T5 style feed-forward layer with gated activations and dropout.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.wo = nn.Linear(d_ff, d_model, bias=False)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.act = NewGELUActivation()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_gelu = self.act(self.wi_0(hidden_states))
|
||||
hidden_linear = self.wi_1(hidden_states)
|
||||
hidden_states = hidden_gelu * hidden_linear
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
r"""
|
||||
T5 style layer normalization module.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
Size of the input hidden states.
|
||||
eps (`float`, `optional`, defaults to `1e-6`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
"""
|
||||
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||
# half-precision inputs is done in fp32
|
||||
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class NewGELUActivation(nn.Module):
|
||||
"""
|
||||
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
||||
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
||||
|
||||
|
||||
class T5FiLMLayer(nn.Module):
|
||||
"""
|
||||
T5 style FiLM Layer.
|
||||
|
||||
Args:
|
||||
in_features (`int`):
|
||||
Number of input features.
|
||||
out_features (`int`):
|
||||
Number of output features.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int):
|
||||
super().__init__()
|
||||
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
emb = self.scale_bias(conditioning_emb)
|
||||
scale, shift = torch.chunk(emb, 2, -1)
|
||||
x = x * (1 + scale) + shift
|
||||
return x
|
||||
class T5FiLMLayer(T5FiLMLayer):
|
||||
deprecation_message = "Importing `T5FiLMLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FiLMLayer`, instead."
|
||||
deprecate("T5FiLMLayer", "0.29", deprecation_message)
|
||||
|
||||
@@ -11,449 +11,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
|
||||
from .attention import BasicTransformerBlock
|
||||
from .embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from .modeling_utils import ModelMixin
|
||||
from .normalization import AdaLayerNormSingle
|
||||
from ..utils import deprecate
|
||||
from .transformers.transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
||||
distributions for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
class Transformer2DModelOutput(Transformer2DModelOutput):
|
||||
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput`, instead."
|
||||
deprecate("Transformer2DModelOutput", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A 2D Transformer model for image-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*):
|
||||
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||
added to the hidden states.
|
||||
|
||||
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
attention_type: str = "default",
|
||||
caption_channels: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
self.is_input_patches = in_channels is not None and patch_size is not None
|
||||
|
||||
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
||||
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
||||
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
||||
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
||||
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
||||
)
|
||||
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
||||
norm_type = "ada_norm"
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is None."
|
||||
)
|
||||
elif self.is_input_vectorized and self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
||||
" sure that either `num_vector_embeds` or `num_patches` is None."
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
||||
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
if self.is_input_continuous:
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if use_linear_projection:
|
||||
self.proj_in = linear_cls(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
self.num_vector_embeds = num_vector_embeds
|
||||
self.num_latent_pixels = self.height * self.width
|
||||
|
||||
self.latent_image_embedding = ImagePositionalEmbeddings(
|
||||
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
||||
)
|
||||
elif self.is_input_patches:
|
||||
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
|
||||
self.patch_size = patch_size
|
||||
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
|
||||
interpolation_scale = max(interpolation_scale, 1)
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
only_cross_attention=only_cross_attention,
|
||||
double_self_attention=double_self_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
attention_type=attention_type,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
if self.is_input_continuous:
|
||||
# TODO: should use out_channels for continuous projections
|
||||
if use_linear_projection:
|
||||
self.proj_out = linear_cls(inner_dim, in_channels)
|
||||
else:
|
||||
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
elif self.is_input_patches and norm_type != "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
elif self.is_input_patches and norm_type == "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
|
||||
# 5. PixArt-Alpha blocks.
|
||||
self.adaln_single = None
|
||||
self.use_additional_conditions = False
|
||||
if norm_type == "ada_norm_single":
|
||||
self.use_additional_conditions = self.config.sample_size == 128
|
||||
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
||||
# additional conditions until we find better name
|
||||
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
|
||||
|
||||
self.caption_projection = None
|
||||
if caption_channels is not None:
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
elif self.is_input_patches:
|
||||
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
if self.adaln_single is not None:
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
||||
)
|
||||
batch_size = hidden_states.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = hidden_states.shape[0]
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
else:
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
|
||||
if self.is_input_patches:
|
||||
if self.config.norm_type != "ada_norm_single":
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
elif self.config.norm_type == "ada_norm_single":
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
class Transformer2DModel(Transformer2DModel):
|
||||
deprecation_message = "Importing `Transformer2DModel` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModel`, instead."
|
||||
deprecate("Transformer2DModel", "0.29", deprecation_message)
|
||||
|
||||
@@ -11,369 +11,24 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .resnet import AlphaBlender
|
||||
from ..utils import deprecate
|
||||
from .transformers.transformer_temporal import (
|
||||
TransformerSpatioTemporalModel,
|
||||
TransformerTemporalModel,
|
||||
TransformerTemporalModelOutput,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerTemporalModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`TransformerTemporalModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
||||
The hidden states output conditioned on `encoder_hidden_states` input.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
class TransformerTemporalModelOutput(TransformerTemporalModelOutput):
|
||||
deprecation_message = "Importing `TransformerTemporalModelOutput` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModelOutput`, instead."
|
||||
deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
||||
activation functions.
|
||||
norm_elementwise_affine (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Configure if each `TransformerBlock` should contain two self-attention layers.
|
||||
positional_embeddings: (`str`, *optional*):
|
||||
The type of positional embeddings to apply to the sequence input before passing use.
|
||||
num_positional_embeddings: (`int`, *optional*):
|
||||
The maximum length of the sequence over which to apply positional embeddings.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
norm_elementwise_affine: bool = True,
|
||||
double_self_attention: bool = True,
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
double_self_attention=double_self_attention,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=num_positional_embeddings,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
class_labels: torch.LongTensor = None,
|
||||
num_frames: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> TransformerTemporalModelOutput:
|
||||
"""
|
||||
The [`TransformerTemporal`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
num_frames (`int`, *optional*, defaults to 1):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, channel, height, width = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states[None, None, :]
|
||||
.reshape(batch_size, height, width, num_frames, channel)
|
||||
.permute(0, 3, 4, 1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
class TransformerTemporalModel(TransformerTemporalModel):
|
||||
deprecation_message = "Importing `TransformerTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModel`, instead."
|
||||
deprecate("TransformerTemporalModel", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class TransformerSpatioTemporalModel(nn.Module):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
out_channels (`int`, *optional*):
|
||||
The number of channels in the output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: int = 320,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_mix_inner_dim = inner_dim
|
||||
self.temporal_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
TemporalBasicTransformerBlock(
|
||||
inner_dim,
|
||||
time_mix_inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_embed_dim = in_channels * 4
|
||||
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
||||
self.time_proj = Timesteps(in_channels, True, 0)
|
||||
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
# TODO: should use out_channels for continuous projections
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input hidden_states.
|
||||
num_frames (`int`):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
|
||||
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
||||
images, 0 indicates that the input contains video frames.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, _, height, width = hidden_states.shape
|
||||
num_frames = image_only_indicator.shape[-1]
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
time_context = encoder_hidden_states
|
||||
time_context_first_timestep = time_context[None, :].reshape(
|
||||
batch_size, num_frames, -1, time_context.shape[-1]
|
||||
)[:, 0]
|
||||
time_context = time_context_first_timestep[None, :].broadcast_to(
|
||||
height * width, batch_size, 1, time_context.shape[-1]
|
||||
)
|
||||
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
||||
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
|
||||
num_frames_emb = num_frames_emb.reshape(-1)
|
||||
t_emb = self.time_proj(num_frames_emb)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
|
||||
emb = self.time_pos_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
# 2. Blocks
|
||||
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
block,
|
||||
hidden_states,
|
||||
None,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states_mix = hidden_states
|
||||
hidden_states_mix = hidden_states_mix + emb
|
||||
|
||||
hidden_states_mix = temporal_block(
|
||||
hidden_states_mix,
|
||||
num_frames=num_frames,
|
||||
encoder_hidden_states=time_context,
|
||||
)
|
||||
hidden_states = self.time_mixer(
|
||||
x_spatial=hidden_states,
|
||||
x_temporal=hidden_states_mix,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
class TransformerSpatioTemporalModel(TransformerSpatioTemporalModel):
|
||||
deprecation_message = "Importing `TransformerSpatioTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerSpatioTemporalModel`, instead."
|
||||
deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
@@ -0,0 +1,155 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
||||
|
||||
|
||||
class DualTransformer2DModel(nn.Module):
|
||||
"""
|
||||
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
||||
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
||||
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
||||
up to but not more than steps than `num_embeds_ada_norm`.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.transformers = nn.ModuleList(
|
||||
[
|
||||
Transformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout,
|
||||
norm_num_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_bias=attention_bias,
|
||||
sample_size=sample_size,
|
||||
num_vector_embeds=num_vector_embeds,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
)
|
||||
|
||||
# Variables that can be set by a pipeline:
|
||||
|
||||
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
||||
self.mix_ratio = 0.5
|
||||
|
||||
# The shape of `encoder_hidden_states` is expected to be
|
||||
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
||||
self.condition_lengths = [77, 257]
|
||||
|
||||
# Which transformer to use to encode which condition.
|
||||
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
||||
self.transformer_index_for_condition = [1, 0]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=None,
|
||||
attention_mask=None,
|
||||
cross_attention_kwargs=None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Optional attention mask to be applied in Attention.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
input_states = hidden_states
|
||||
|
||||
encoded_states = []
|
||||
tokens_start = 0
|
||||
# attention_mask is not used yet
|
||||
for i in range(2):
|
||||
# for each of the two transformers, pass the corresponding condition tokens
|
||||
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
||||
transformer_index = self.transformer_index_for_condition[i]
|
||||
encoded_state = self.transformers[transformer_index](
|
||||
input_states,
|
||||
encoder_hidden_states=condition_state,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
encoded_states.append(encoded_state - input_states)
|
||||
tokens_start += self.condition_lengths[i]
|
||||
|
||||
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
||||
output_states = output_states + input_states
|
||||
|
||||
if not return_dict:
|
||||
return (output_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output_states)
|
||||
@@ -0,0 +1,380 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriorTransformerOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`PriorTransformer`].
|
||||
|
||||
Args:
|
||||
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
||||
"""
|
||||
|
||||
predicted_image_embedding: torch.FloatTensor
|
||||
|
||||
|
||||
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
||||
"""
|
||||
A Prior Transformer model.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
||||
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
|
||||
num_embeddings (`int`, *optional*, defaults to 77):
|
||||
The number of embeddings of the model input `hidden_states`
|
||||
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
||||
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
||||
additional_embeddings`.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
|
||||
The activation function to use to create timestep embeddings.
|
||||
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
|
||||
passing to Transformer blocks. Set it to `None` if normalization is not needed.
|
||||
embedding_proj_norm_type (`str`, *optional*, defaults to None):
|
||||
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
|
||||
needed.
|
||||
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
|
||||
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
|
||||
`encoder_hidden_states` is `None`.
|
||||
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
|
||||
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
|
||||
product between the text embedding and image embedding as proposed in the unclip paper
|
||||
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
|
||||
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
|
||||
If None, will be set to `num_attention_heads * attention_head_dim`
|
||||
embedding_proj_dim (`int`, *optional*, default to None):
|
||||
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
|
||||
clip_embed_dim (`int`, *optional*, default to None):
|
||||
The dimension of the output. If None, will be set to `embedding_dim`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
num_layers: int = 20,
|
||||
embedding_dim: int = 768,
|
||||
num_embeddings=77,
|
||||
additional_embeddings=4,
|
||||
dropout: float = 0.0,
|
||||
time_embed_act_fn: str = "silu",
|
||||
norm_in_type: Optional[str] = None, # layer
|
||||
embedding_proj_norm_type: Optional[str] = None, # layer
|
||||
encoder_hid_proj_type: Optional[str] = "linear", # linear
|
||||
added_emb_type: Optional[str] = "prd", # prd
|
||||
time_embed_dim: Optional[int] = None,
|
||||
embedding_proj_dim: Optional[int] = None,
|
||||
clip_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.additional_embeddings = additional_embeddings
|
||||
|
||||
time_embed_dim = time_embed_dim or inner_dim
|
||||
embedding_proj_dim = embedding_proj_dim or embedding_dim
|
||||
clip_embed_dim = clip_embed_dim or embedding_dim
|
||||
|
||||
self.time_proj = Timesteps(inner_dim, True, 0)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
if embedding_proj_norm_type is None:
|
||||
self.embedding_proj_norm = None
|
||||
elif embedding_proj_norm_type == "layer":
|
||||
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
|
||||
|
||||
if encoder_hid_proj_type is None:
|
||||
self.encoder_hidden_states_proj = None
|
||||
elif encoder_hid_proj_type == "linear":
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
|
||||
|
||||
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
||||
|
||||
if added_emb_type == "prd":
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
elif added_emb_type is None:
|
||||
self.prd_embedding = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn="gelu",
|
||||
attention_bias=True,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if norm_in_type == "layer":
|
||||
self.norm_in = nn.LayerNorm(inner_dim)
|
||||
elif norm_in_type is None:
|
||||
self.norm_in = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
|
||||
|
||||
causal_attention_mask = torch.full(
|
||||
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
||||
)
|
||||
causal_attention_mask.triu_(1)
|
||||
causal_attention_mask = causal_attention_mask[None, ...]
|
||||
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
||||
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
proj_embedding: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`PriorTransformer`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The currently predicted image embeddings.
|
||||
timestep (`torch.LongTensor`):
|
||||
Current denoising step.
|
||||
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
Projected embedding vector the denoising process is conditioned on.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
||||
Hidden states of the text embeddings the denoising process is conditioned on.
|
||||
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
|
||||
Text mask for the text embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(hidden_states.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
timesteps_projected = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might be fp16, so we need to cast here.
|
||||
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
||||
time_embeddings = self.time_embedding(timesteps_projected)
|
||||
|
||||
if self.embedding_proj_norm is not None:
|
||||
proj_embedding = self.embedding_proj_norm(proj_embedding)
|
||||
|
||||
proj_embeddings = self.embedding_proj(proj_embedding)
|
||||
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
|
||||
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
||||
|
||||
additional_embeds = []
|
||||
additional_embeddings_len = 0
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
additional_embeds.append(encoder_hidden_states)
|
||||
additional_embeddings_len += encoder_hidden_states.shape[1]
|
||||
|
||||
if len(proj_embeddings.shape) == 2:
|
||||
proj_embeddings = proj_embeddings[:, None, :]
|
||||
|
||||
if len(hidden_states.shape) == 2:
|
||||
hidden_states = hidden_states[:, None, :]
|
||||
|
||||
additional_embeds = additional_embeds + [
|
||||
proj_embeddings,
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states,
|
||||
]
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
additional_embeds.append(prd_embedding)
|
||||
|
||||
hidden_states = torch.cat(
|
||||
additional_embeds,
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
|
||||
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
|
||||
if positional_embeddings.shape[1] < hidden_states.shape[1]:
|
||||
positional_embeddings = F.pad(
|
||||
positional_embeddings,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
additional_embeddings_len,
|
||||
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
|
||||
),
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + positional_embeddings
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
|
||||
if self.norm_in is not None:
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
hidden_states = hidden_states[:, -1]
|
||||
else:
|
||||
hidden_states = hidden_states[:, additional_embeddings_len:]
|
||||
|
||||
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (predicted_image_embedding,)
|
||||
|
||||
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
||||
|
||||
def post_process_latents(self, prior_latents):
|
||||
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
||||
return prior_latents
|
||||
@@ -0,0 +1,438 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import get_timestep_embedding
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class T5FilmDecoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
T5 style decoder with FiLM conditioning.
|
||||
|
||||
Args:
|
||||
input_dims (`int`, *optional*, defaults to `128`):
|
||||
The number of input dimensions.
|
||||
targets_length (`int`, *optional*, defaults to `256`):
|
||||
The length of the targets.
|
||||
d_model (`int`, *optional*, defaults to `768`):
|
||||
Size of the input hidden states.
|
||||
num_layers (`int`, *optional*, defaults to `12`):
|
||||
The number of `DecoderLayer`'s to use.
|
||||
num_heads (`int`, *optional*, defaults to `12`):
|
||||
The number of attention heads to use.
|
||||
d_kv (`int`, *optional*, defaults to `64`):
|
||||
Size of the key-value projection vectors.
|
||||
d_ff (`int`, *optional*, defaults to `2048`):
|
||||
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
|
||||
dropout_rate (`float`, *optional*, defaults to `0.1`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int = 128,
|
||||
targets_length: int = 256,
|
||||
max_decoder_noise_time: float = 2000.0,
|
||||
d_model: int = 768,
|
||||
num_layers: int = 12,
|
||||
num_heads: int = 12,
|
||||
d_kv: int = 64,
|
||||
d_ff: int = 2048,
|
||||
dropout_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conditioning_emb = nn.Sequential(
|
||||
nn.Linear(d_model, d_model * 4, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(d_model * 4, d_model * 4, bias=False),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.position_encoding = nn.Embedding(targets_length, d_model)
|
||||
self.position_encoding.weight.requires_grad = False
|
||||
|
||||
self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
self.decoders = nn.ModuleList()
|
||||
for lyr_num in range(num_layers):
|
||||
# FiLM conditional T5 decoder
|
||||
lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
|
||||
self.decoders.append(lyr)
|
||||
|
||||
self.decoder_norm = T5LayerNorm(d_model)
|
||||
|
||||
self.post_dropout = nn.Dropout(p=dropout_rate)
|
||||
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
|
||||
|
||||
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
|
||||
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
|
||||
return mask.unsqueeze(-3)
|
||||
|
||||
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
|
||||
batch, _, _ = decoder_input_tokens.shape
|
||||
assert decoder_noise_time.shape == (batch,)
|
||||
|
||||
# decoder_noise_time is in [0, 1), so rescale to expected timing range.
|
||||
time_steps = get_timestep_embedding(
|
||||
decoder_noise_time * self.config.max_decoder_noise_time,
|
||||
embedding_dim=self.config.d_model,
|
||||
max_period=self.config.max_decoder_noise_time,
|
||||
).to(dtype=self.dtype)
|
||||
|
||||
conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
|
||||
|
||||
assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
|
||||
|
||||
seq_length = decoder_input_tokens.shape[1]
|
||||
|
||||
# If we want to use relative positions for audio context, we can just offset
|
||||
# this sequence by the length of encodings_and_masks.
|
||||
decoder_positions = torch.broadcast_to(
|
||||
torch.arange(seq_length, device=decoder_input_tokens.device),
|
||||
(batch, seq_length),
|
||||
)
|
||||
|
||||
position_encodings = self.position_encoding(decoder_positions)
|
||||
|
||||
inputs = self.continuous_inputs_projection(decoder_input_tokens)
|
||||
inputs += position_encodings
|
||||
y = self.dropout(inputs)
|
||||
|
||||
# decoder: No padding present.
|
||||
decoder_mask = torch.ones(
|
||||
decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
|
||||
)
|
||||
|
||||
# Translate encoding masks to encoder-decoder masks.
|
||||
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
|
||||
|
||||
# cross attend style: concat encodings
|
||||
encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
|
||||
encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
|
||||
|
||||
for lyr in self.decoders:
|
||||
y = lyr(
|
||||
y,
|
||||
conditioning_emb=conditioning_emb,
|
||||
encoder_hidden_states=encoded,
|
||||
encoder_attention_mask=encoder_decoder_mask,
|
||||
)[0]
|
||||
|
||||
y = self.decoder_norm(y)
|
||||
y = self.post_dropout(y)
|
||||
|
||||
spec_out = self.spec_out(y)
|
||||
return spec_out
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
r"""
|
||||
T5 decoder layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
self.layer = nn.ModuleList()
|
||||
|
||||
# cond self attention: layer 0
|
||||
self.layer.append(
|
||||
T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
|
||||
)
|
||||
|
||||
# cross attention: layer 1
|
||||
self.layer.append(
|
||||
T5LayerCrossAttention(
|
||||
d_model=d_model,
|
||||
d_kv=d_kv,
|
||||
num_heads=num_heads,
|
||||
dropout_rate=dropout_rate,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
)
|
||||
)
|
||||
|
||||
# Film Cond MLP + dropout: last layer
|
||||
self.layer.append(
|
||||
T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
conditioning_emb: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_decoder_position_bias=None,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
hidden_states = self.layer[0](
|
||||
hidden_states,
|
||||
conditioning_emb=conditioning_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
|
||||
encoder_hidden_states.dtype
|
||||
)
|
||||
|
||||
hidden_states = self.layer[1](
|
||||
hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
|
||||
# Apply Film Conditional Feed Forward layer
|
||||
hidden_states = self.layer[-1](hidden_states, conditioning_emb)
|
||||
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
class T5LayerSelfAttentionCond(nn.Module):
|
||||
r"""
|
||||
T5 style self-attention layer with conditioning.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.layer_norm = T5LayerNorm(d_model)
|
||||
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
||||
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
conditioning_emb: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
# pre_self_attention_layer_norm
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
if conditioning_emb is not None:
|
||||
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
|
||||
|
||||
# Self-attention block
|
||||
attention_output = self.attention(normed_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.dropout(attention_output)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5LayerCrossAttention(nn.Module):
|
||||
r"""
|
||||
T5 style cross-attention layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
|
||||
super().__init__()
|
||||
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
||||
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
key_value_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
attention_output = self.attention(
|
||||
normed_hidden_states,
|
||||
encoder_hidden_states=key_value_states,
|
||||
attention_mask=attention_mask.squeeze(1),
|
||||
)
|
||||
layer_output = hidden_states + self.dropout(attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class T5LayerFFCond(nn.Module):
|
||||
r"""
|
||||
T5 style feed-forward conditional layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
|
||||
super().__init__()
|
||||
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
|
||||
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
||||
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
|
||||
) -> torch.FloatTensor:
|
||||
forwarded_states = self.layer_norm(hidden_states)
|
||||
if conditioning_emb is not None:
|
||||
forwarded_states = self.film(forwarded_states, conditioning_emb)
|
||||
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
hidden_states = hidden_states + self.dropout(forwarded_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5DenseGatedActDense(nn.Module):
|
||||
r"""
|
||||
T5 style feed-forward layer with gated activations and dropout.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.wo = nn.Linear(d_ff, d_model, bias=False)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.act = NewGELUActivation()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_gelu = self.act(self.wi_0(hidden_states))
|
||||
hidden_linear = self.wi_1(hidden_states)
|
||||
hidden_states = hidden_gelu * hidden_linear
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
r"""
|
||||
T5 style layer normalization module.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
Size of the input hidden states.
|
||||
eps (`float`, `optional`, defaults to `1e-6`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
"""
|
||||
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||
# half-precision inputs is done in fp32
|
||||
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class NewGELUActivation(nn.Module):
|
||||
"""
|
||||
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
||||
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
||||
|
||||
|
||||
class T5FiLMLayer(nn.Module):
|
||||
"""
|
||||
T5 style FiLM Layer.
|
||||
|
||||
Args:
|
||||
in_features (`int`):
|
||||
Number of input features.
|
||||
out_features (`int`):
|
||||
Number of output features.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int):
|
||||
super().__init__()
|
||||
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
emb = self.scale_bias(conditioning_emb)
|
||||
scale, shift = torch.chunk(emb, 2, -1)
|
||||
x = x * (1 + scale) + shift
|
||||
return x
|
||||
@@ -0,0 +1,458 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
||||
distributions for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A 2D Transformer model for image-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*):
|
||||
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||
added to the hidden states.
|
||||
|
||||
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
attention_type: str = "default",
|
||||
caption_channels: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
self.is_input_patches = in_channels is not None and patch_size is not None
|
||||
|
||||
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
||||
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
||||
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
||||
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
||||
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
||||
)
|
||||
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
||||
norm_type = "ada_norm"
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is None."
|
||||
)
|
||||
elif self.is_input_vectorized and self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
||||
" sure that either `num_vector_embeds` or `num_patches` is None."
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
||||
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
if self.is_input_continuous:
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if use_linear_projection:
|
||||
self.proj_in = linear_cls(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
self.num_vector_embeds = num_vector_embeds
|
||||
self.num_latent_pixels = self.height * self.width
|
||||
|
||||
self.latent_image_embedding = ImagePositionalEmbeddings(
|
||||
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
||||
)
|
||||
elif self.is_input_patches:
|
||||
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
|
||||
self.patch_size = patch_size
|
||||
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
|
||||
interpolation_scale = max(interpolation_scale, 1)
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
only_cross_attention=only_cross_attention,
|
||||
double_self_attention=double_self_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
attention_type=attention_type,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
if self.is_input_continuous:
|
||||
# TODO: should use out_channels for continuous projections
|
||||
if use_linear_projection:
|
||||
self.proj_out = linear_cls(inner_dim, in_channels)
|
||||
else:
|
||||
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
elif self.is_input_patches and norm_type != "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
elif self.is_input_patches and norm_type == "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
|
||||
# 5. PixArt-Alpha blocks.
|
||||
self.adaln_single = None
|
||||
self.use_additional_conditions = False
|
||||
if norm_type == "ada_norm_single":
|
||||
self.use_additional_conditions = self.config.sample_size == 128
|
||||
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
||||
# additional conditions until we find better name
|
||||
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
|
||||
|
||||
self.caption_projection = None
|
||||
if caption_channels is not None:
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
elif self.is_input_patches:
|
||||
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
if self.adaln_single is not None:
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
||||
)
|
||||
batch_size = hidden_states.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = hidden_states.shape[0]
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
else:
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
|
||||
if self.is_input_patches:
|
||||
if self.config.norm_type != "ada_norm_single":
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
elif self.config.norm_type == "ada_norm_single":
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -0,0 +1,379 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..resnet import AlphaBlender
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerTemporalModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`TransformerTemporalModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
||||
The hidden states output conditioned on `encoder_hidden_states` input.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
||||
activation functions.
|
||||
norm_elementwise_affine (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Configure if each `TransformerBlock` should contain two self-attention layers.
|
||||
positional_embeddings: (`str`, *optional*):
|
||||
The type of positional embeddings to apply to the sequence input before passing use.
|
||||
num_positional_embeddings: (`int`, *optional*):
|
||||
The maximum length of the sequence over which to apply positional embeddings.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
norm_elementwise_affine: bool = True,
|
||||
double_self_attention: bool = True,
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
double_self_attention=double_self_attention,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=num_positional_embeddings,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
class_labels: torch.LongTensor = None,
|
||||
num_frames: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> TransformerTemporalModelOutput:
|
||||
"""
|
||||
The [`TransformerTemporal`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
num_frames (`int`, *optional*, defaults to 1):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, channel, height, width = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states[None, None, :]
|
||||
.reshape(batch_size, height, width, num_frames, channel)
|
||||
.permute(0, 3, 4, 1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
|
||||
|
||||
class TransformerSpatioTemporalModel(nn.Module):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
out_channels (`int`, *optional*):
|
||||
The number of channels in the output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: int = 320,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_mix_inner_dim = inner_dim
|
||||
self.temporal_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
TemporalBasicTransformerBlock(
|
||||
inner_dim,
|
||||
time_mix_inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_embed_dim = in_channels * 4
|
||||
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
||||
self.time_proj = Timesteps(in_channels, True, 0)
|
||||
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
# TODO: should use out_channels for continuous projections
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input hidden_states.
|
||||
num_frames (`int`):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
|
||||
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
||||
images, 0 indicates that the input contains video frames.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, _, height, width = hidden_states.shape
|
||||
num_frames = image_only_indicator.shape[-1]
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
time_context = encoder_hidden_states
|
||||
time_context_first_timestep = time_context[None, :].reshape(
|
||||
batch_size, num_frames, -1, time_context.shape[-1]
|
||||
)[:, 0]
|
||||
time_context = time_context_first_timestep[None, :].broadcast_to(
|
||||
height * width, batch_size, 1, time_context.shape[-1]
|
||||
)
|
||||
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
||||
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
|
||||
num_frames_emb = num_frames_emb.reshape(-1)
|
||||
t_emb = self.time_proj(num_frames_emb)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
|
||||
emb = self.time_pos_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
# 2. Blocks
|
||||
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
block,
|
||||
hidden_states,
|
||||
None,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states_mix = hidden_states
|
||||
hidden_states_mix = hidden_states_mix + emb
|
||||
|
||||
hidden_states_mix = temporal_block(
|
||||
hidden_states_mix,
|
||||
num_frames=num_frames,
|
||||
encoder_hidden_states=time_context,
|
||||
)
|
||||
hidden_states = self.time_mixer(
|
||||
x_spatial=hidden_states,
|
||||
x_temporal=hidden_states_mix,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
@@ -12,244 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
||||
from ..utils import deprecate
|
||||
from .unets.unet_1d import UNet1DModel, UNet1DOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet1DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet1DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
class UNet1DOutput(UNet1DOutput):
|
||||
deprecation_message = "Importing `UNet1DOutput` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DOutput`, instead."
|
||||
deprecate("UNet1DOutput", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
|
||||
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
|
||||
extra_in_channels (`int`, *optional*, defaults to 0):
|
||||
Number of additional channels to be added to the input of the first down block. Useful for cases where the
|
||||
input data has more channels than what the model was initially designed for.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
|
||||
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
|
||||
Tuple of block output channels.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
|
||||
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
|
||||
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
|
||||
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
|
||||
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
|
||||
downsample_each_block (`int`, *optional*, defaults to `False`):
|
||||
Experimental feature for using a UNet without upsampling.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 65536,
|
||||
sample_rate: Optional[int] = None,
|
||||
in_channels: int = 2,
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
layers_per_block: int = 1,
|
||||
downsample_each_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.sample_size = sample_size
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=timestep_input_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
out_dim=block_out_channels[0],
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
self.out_block = None
|
||||
|
||||
# down
|
||||
output_channel = in_channels
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
if i == 0:
|
||||
input_channel += extra_in_channels
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_downsample=not is_final_block or downsample_each_block,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = get_mid_block(
|
||||
mid_block_type,
|
||||
in_channels=block_out_channels[-1],
|
||||
mid_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
embed_dim=block_out_channels[0],
|
||||
num_layers=layers_per_block,
|
||||
add_downsample=downsample_each_block,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
if out_block_type is None:
|
||||
final_upsample_channels = out_channels
|
||||
else:
|
||||
final_upsample_channels = block_out_channels[0]
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = (
|
||||
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
|
||||
)
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_upsample=not is_final_block,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.out_block = get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=block_out_channels[0],
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=block_out_channels[-1] // 4,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet1DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet1DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
timestep_embed = self.time_proj(timesteps)
|
||||
if self.config.use_timestep_embedding:
|
||||
timestep_embed = self.time_mlp(timestep_embed)
|
||||
else:
|
||||
timestep_embed = timestep_embed[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
|
||||
|
||||
# 2. down
|
||||
down_block_res_samples = ()
|
||||
for downsample_block in self.down_blocks:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 3. mid
|
||||
if self.mid_block:
|
||||
sample = self.mid_block(sample, timestep_embed)
|
||||
|
||||
# 4. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
res_samples = down_block_res_samples[-1:]
|
||||
down_block_res_samples = down_block_res_samples[:-1]
|
||||
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
|
||||
|
||||
# 5. post-process
|
||||
if self.out_block:
|
||||
sample = self.out_block(sample, timestep_embed)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet1DOutput(sample=sample)
|
||||
class UNet1DModel(UNet1DModel):
|
||||
deprecation_message = "Importing `UNet1DModel` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DModel`, instead."
|
||||
deprecate("UNet1DModel", "0.29", deprecation_message)
|
||||
|
||||
@@ -11,616 +11,112 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .activations import get_activation
|
||||
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
|
||||
|
||||
|
||||
class DownResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
conv_shortcut: bool = False,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_downsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_downsample = add_downsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
output_states = ()
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.downsample is not None:
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UpResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_upsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_upsample = add_upsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if res_hidden_states_tuple is not None:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.upsample is not None:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ValueFunctionMidBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
|
||||
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
|
||||
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
|
||||
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
x = self.res1(x, temb)
|
||||
x = self.down1(x)
|
||||
x = self.res2(x, temb)
|
||||
x = self.down2(x)
|
||||
return x
|
||||
|
||||
|
||||
class MidResTemporalBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
num_layers: int = 1,
|
||||
add_downsample: bool = False,
|
||||
add_upsample: bool = False,
|
||||
non_linearity: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.add_downsample = add_downsample
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
if self.upsample and self.downsample:
|
||||
raise ValueError("Block cannot downsample and upsample")
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
if self.downsample:
|
||||
self.downsample = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutConv1DBlock(nn.Module):
|
||||
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
|
||||
super().__init__()
|
||||
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
|
||||
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
|
||||
self.final_conv1d_act = get_activation(act_fn)
|
||||
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.final_conv1d_1(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_gn(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_act(hidden_states)
|
||||
hidden_states = self.final_conv1d_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutValueFunctionBlock(nn.Module):
|
||||
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
|
||||
super().__init__()
|
||||
self.final_block = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
||||
get_activation(act_fn),
|
||||
nn.Linear(fc_dim // 2, 1),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
|
||||
hidden_states = torch.cat((hidden_states, temb), dim=-1)
|
||||
for layer in self.final_block:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
_kernels = {
|
||||
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
||||
"lanczos3": [
|
||||
0.003689131001010537,
|
||||
0.015056144446134567,
|
||||
-0.03399861603975296,
|
||||
-0.066637322306633,
|
||||
0.13550527393817902,
|
||||
0.44638532400131226,
|
||||
0.44638532400131226,
|
||||
0.13550527393817902,
|
||||
-0.066637322306633,
|
||||
-0.03399861603975296,
|
||||
0.015056144446134567,
|
||||
0.003689131001010537,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel])
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv1d(hidden_states, weight, stride=2)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
class SelfAttention1d(nn.Module):
|
||||
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
|
||||
super().__init__()
|
||||
self.channels = in_channels
|
||||
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
|
||||
self.num_heads = n_head
|
||||
|
||||
self.query = nn.Linear(self.channels, self.channels)
|
||||
self.key = nn.Linear(self.channels, self.channels)
|
||||
self.value = nn.Linear(self.channels, self.channels)
|
||||
|
||||
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
|
||||
|
||||
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
||||
|
||||
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
||||
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
||||
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
||||
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
||||
return new_projection
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
batch, channel_dim, seq = hidden_states.shape
|
||||
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
query_states = self.transpose_for_scores(query_proj)
|
||||
key_states = self.transpose_for_scores(key_proj)
|
||||
value_states = self.transpose_for_scores(value_proj)
|
||||
|
||||
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
|
||||
|
||||
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
|
||||
attention_probs = torch.softmax(attention_scores, dim=-1)
|
||||
|
||||
# compute attention output
|
||||
hidden_states = torch.matmul(attention_probs, value_states)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResConvBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
|
||||
super().__init__()
|
||||
self.is_last = is_last
|
||||
self.has_conv_skip = in_channels != out_channels
|
||||
|
||||
if self.has_conv_skip:
|
||||
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
|
||||
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
|
||||
self.gelu_1 = nn.GELU()
|
||||
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
|
||||
|
||||
if not self.is_last:
|
||||
self.group_norm_2 = nn.GroupNorm(1, out_channels)
|
||||
self.gelu_2 = nn.GELU()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
|
||||
|
||||
hidden_states = self.conv_1(hidden_states)
|
||||
hidden_states = self.group_norm_1(hidden_states)
|
||||
hidden_states = self.gelu_1(hidden_states)
|
||||
hidden_states = self.conv_2(hidden_states)
|
||||
from ..utils import deprecate
|
||||
from .unets.unet_1d_blocks import (
|
||||
AttnDownBlock1D,
|
||||
AttnUpBlock1D,
|
||||
DownBlock1D,
|
||||
DownBlock1DNoSkip,
|
||||
DownResnetBlock1D,
|
||||
Downsample1d,
|
||||
MidResTemporalBlock1D,
|
||||
OutConv1DBlock,
|
||||
OutValueFunctionBlock,
|
||||
ResConvBlock,
|
||||
SelfAttention1d,
|
||||
UNetMidBlock1D,
|
||||
UpBlock1D,
|
||||
UpBlock1DNoSkip,
|
||||
UpResnetBlock1D,
|
||||
Upsample1d,
|
||||
ValueFunctionMidBlock1D,
|
||||
)
|
||||
|
||||
if not self.is_last:
|
||||
hidden_states = self.group_norm_2(hidden_states)
|
||||
hidden_states = self.gelu_2(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
return output
|
||||
class DownResnetBlock1D(DownResnetBlock1D):
|
||||
deprecation_message = "Importing `DownResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownResnetBlock1D`, instead."
|
||||
deprecate("DownResnetBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UNetMidBlock1D(nn.Module):
|
||||
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
class UpResnetBlock1D(UpResnetBlock1D):
|
||||
deprecation_message = "Importing `UpResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpResnetBlock1D`, instead."
|
||||
deprecate("UpResnetBlock1D", "0.29", deprecation_message)
|
||||
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
|
||||
# there is always at least one resnet
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
class ValueFunctionMidBlock1D(ValueFunctionMidBlock1D):
|
||||
deprecation_message = "Importing `ValueFunctionMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ValueFunctionMidBlock1D`, instead."
|
||||
deprecate("ValueFunctionMidBlock1D", "0.29", deprecation_message)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
class OutConv1DBlock(OutConv1DBlock):
|
||||
deprecation_message = "Importing `OutConv1DBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutConv1DBlock`, instead."
|
||||
deprecate("OutConv1DBlock", "0.29", deprecation_message)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
class OutValueFunctionBlock(OutValueFunctionBlock):
|
||||
deprecation_message = "Importing `OutValueFunctionBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutValueFunctionBlock`, instead."
|
||||
deprecate("OutValueFunctionBlock", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class AttnDownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
class Downsample1d(Downsample1d):
|
||||
deprecation_message = "Importing `Downsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Downsample1d`, instead."
|
||||
deprecate("Downsample1d", "0.29", deprecation_message)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
class Upsample1d(Upsample1d):
|
||||
deprecation_message = "Importing `Upsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Upsample1d`, instead."
|
||||
deprecate("Upsample1d", "0.29", deprecation_message)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = torch.cat([hidden_states, temb], dim=1)
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class AttnUpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
class SelfAttention1d(SelfAttention1d):
|
||||
deprecation_message = "Importing `SelfAttention1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import SelfAttention1d`, instead."
|
||||
deprecate("SelfAttention1d", "0.29", deprecation_message)
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
class ResConvBlock(ResConvBlock):
|
||||
deprecation_message = "Importing `ResConvBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ResConvBlock`, instead."
|
||||
deprecate("ResConvBlock", "0.29", deprecation_message)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
class UNetMidBlock1D(UNetMidBlock1D):
|
||||
deprecation_message = "Importing `UNetMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UNetMidBlock1D`, instead."
|
||||
deprecate("UNetMidBlock1D", "0.29", deprecation_message)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
class AttnDownBlock1D(AttnDownBlock1D):
|
||||
deprecation_message = "Importing `AttnDownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnDownBlock1D`, instead."
|
||||
deprecate("AttnDownBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
class DownBlock1D(DownBlock1D):
|
||||
deprecation_message = "Importing `DownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1D`, instead."
|
||||
deprecate("DownBlock1D", "0.29", deprecation_message)
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
class DownBlock1DNoSkip(DownBlock1DNoSkip):
|
||||
deprecation_message = "Importing `DownBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1DNoSkip`, instead."
|
||||
deprecate("DownBlock1DNoSkip", "0.29", deprecation_message)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
class AttnUpBlock1D(AttnUpBlock1D):
|
||||
deprecation_message = "Importing `AttnUpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnUpBlock1D`, instead."
|
||||
deprecate("AttnUpBlock1D", "0.29", deprecation_message)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
class UpBlock1D(UpBlock1D):
|
||||
deprecation_message = "Importing `UpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1D`, instead."
|
||||
deprecate("UpBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UpBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
class UpBlock1DNoSkip(UpBlock1DNoSkip):
|
||||
deprecation_message = "Importing `UpBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1DNoSkip`, instead."
|
||||
deprecate("UpBlock1DNoSkip", "0.29", deprecation_message)
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
|
||||
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
|
||||
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
|
||||
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
|
||||
class MidResTemporalBlock1D(MidResTemporalBlock1D):
|
||||
deprecation_message = "Importing `MidResTemporalBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import MidResTemporalBlock1D`, instead."
|
||||
deprecate("MidResTemporalBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
def get_down_block(
|
||||
@@ -630,42 +126,38 @@ def get_down_block(
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
add_downsample: bool,
|
||||
) -> DownBlockType:
|
||||
if down_block_type == "DownResnetBlock1D":
|
||||
return DownResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif down_block_type == "DownBlock1D":
|
||||
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "AttnDownBlock1D":
|
||||
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "DownBlock1DNoSkip":
|
||||
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
):
|
||||
deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_down_block`, instead."
|
||||
deprecate("get_down_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_down_block
|
||||
|
||||
return get_down_block(
|
||||
down_block_type=down_block_type,
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
|
||||
|
||||
def get_up_block(
|
||||
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
|
||||
) -> UpBlockType:
|
||||
if up_block_type == "UpResnetBlock1D":
|
||||
return UpResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
elif up_block_type == "UpBlock1D":
|
||||
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "AttnUpBlock1D":
|
||||
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "UpBlock1DNoSkip":
|
||||
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
):
|
||||
deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_up_block`, instead."
|
||||
deprecate("get_up_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_up_block
|
||||
|
||||
return get_up_block(
|
||||
up_block_type=up_block_type,
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
|
||||
|
||||
def get_mid_block(
|
||||
@@ -676,27 +168,36 @@ def get_mid_block(
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
add_downsample: bool,
|
||||
) -> MidBlockType:
|
||||
if mid_block_type == "MidResTemporalBlock1D":
|
||||
return MidResTemporalBlock1D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif mid_block_type == "ValueFunctionMidBlock1D":
|
||||
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
|
||||
elif mid_block_type == "UNetMidBlock1D":
|
||||
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{mid_block_type} does not exist.")
|
||||
):
|
||||
deprecation_message = "Importing `get_mid_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_mid_block`, instead."
|
||||
deprecate("get_mid_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_mid_block
|
||||
|
||||
return get_mid_block(
|
||||
mid_block_type=mid_block_type,
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
mid_channels=mid_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
|
||||
|
||||
def get_out_block(
|
||||
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
|
||||
) -> Optional[OutBlockType]:
|
||||
if out_block_type == "OutConv1DBlock":
|
||||
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
||||
elif out_block_type == "ValueFunction":
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
|
||||
return None
|
||||
):
|
||||
deprecation_message = "Importing `get_out_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_out_block`, instead."
|
||||
deprecate("get_out_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_out_block
|
||||
|
||||
return get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=embed_dim,
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=fc_dim,
|
||||
)
|
||||
|
||||
@@ -11,336 +11,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet2DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
from ..utils import deprecate
|
||||
from .unets.unet_2d import UNet2DModel, UNet2DOutput
|
||||
|
||||
|
||||
class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
class UNet2DOutput(UNet2DOutput):
|
||||
deprecation_message = "Importing `UNet2DOutput` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DOutput`, instead."
|
||||
deprecate("UNet2DOutput", "0.29", deprecation_message)
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
||||
1)`.
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
||||
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
|
||||
Tuple of downsample block types.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
||||
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
|
||||
Tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
|
||||
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
|
||||
downsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The downsample type for downsampling layers. Choose between "conv" and "resnet"
|
||||
upsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The upsample type for upsampling layers. Choose between "conv" and "resnet"
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
|
||||
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
|
||||
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
|
||||
given number of groups. If left as `None`, the group norm layer will only be created if
|
||||
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
|
||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to `None`):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
|
||||
conditioning with `class_embed_type` equal to `None`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
center_input_sample: bool = False,
|
||||
time_embedding_type: str = "positional",
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
||||
layers_per_block: int = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
downsample_padding: int = 1,
|
||||
downsample_type: str = "conv",
|
||||
upsample_type: str = "conv",
|
||||
dropout: float = 0.0,
|
||||
act_fn: str = "silu",
|
||||
attention_head_dim: Optional[int] = 8,
|
||||
norm_num_groups: int = 32,
|
||||
attn_norm_num_groups: Optional[int] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
add_attention: bool = True,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
num_train_timesteps: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
elif time_embedding_type == "learned":
|
||||
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
downsample_type=downsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_groups=attn_norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
upsample_type=upsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when doing class conditioning")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
elif self.class_embedding is None and class_labels is not None:
|
||||
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "skip_conv"):
|
||||
sample, res_samples, skip_sample = downsample_block(
|
||||
hidden_states=sample, temb=emb, skip_sample=skip_sample
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. up
|
||||
skip_sample = None
|
||||
for upsample_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
if hasattr(upsample_block, "skip_conv"):
|
||||
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
|
||||
else:
|
||||
sample = upsample_block(sample, res_samples, emb)
|
||||
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if skip_sample is not None:
|
||||
sample += skip_sample
|
||||
|
||||
if self.config.time_embedding_type == "fourier":
|
||||
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
|
||||
sample = sample / timesteps
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet2DOutput(sample=sample)
|
||||
class UNet2DModel(UNet2DModel):
|
||||
deprecation_message = "Importing `UNet2DModel` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DModel`, instead."
|
||||
deprecate("UNet2DModel", "0.29", deprecation_message)
|
||||
|
||||
+170
-3461
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,16 @@
|
||||
from ...utils import is_flax_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_condition import UNet3DConditionModel
|
||||
from .unet_kandinsky3 import Kandinsky3UNet
|
||||
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
||||
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
||||
from .uvit_2d import UVit2DModel
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
@@ -0,0 +1,255 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet1DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet1DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
|
||||
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
|
||||
extra_in_channels (`int`, *optional*, defaults to 0):
|
||||
Number of additional channels to be added to the input of the first down block. Useful for cases where the
|
||||
input data has more channels than what the model was initially designed for.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
|
||||
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
|
||||
Tuple of block output channels.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
|
||||
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
|
||||
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
|
||||
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
|
||||
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
|
||||
downsample_each_block (`int`, *optional*, defaults to `False`):
|
||||
Experimental feature for using a UNet without upsampling.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 65536,
|
||||
sample_rate: Optional[int] = None,
|
||||
in_channels: int = 2,
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
layers_per_block: int = 1,
|
||||
downsample_each_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.sample_size = sample_size
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=timestep_input_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
out_dim=block_out_channels[0],
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
self.out_block = None
|
||||
|
||||
# down
|
||||
output_channel = in_channels
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
if i == 0:
|
||||
input_channel += extra_in_channels
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_downsample=not is_final_block or downsample_each_block,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = get_mid_block(
|
||||
mid_block_type,
|
||||
in_channels=block_out_channels[-1],
|
||||
mid_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
embed_dim=block_out_channels[0],
|
||||
num_layers=layers_per_block,
|
||||
add_downsample=downsample_each_block,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
if out_block_type is None:
|
||||
final_upsample_channels = out_channels
|
||||
else:
|
||||
final_upsample_channels = block_out_channels[0]
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = (
|
||||
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
|
||||
)
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_upsample=not is_final_block,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.out_block = get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=block_out_channels[0],
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=block_out_channels[-1] // 4,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet1DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet1DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
timestep_embed = self.time_proj(timesteps)
|
||||
if self.config.use_timestep_embedding:
|
||||
timestep_embed = self.time_mlp(timestep_embed)
|
||||
else:
|
||||
timestep_embed = timestep_embed[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
|
||||
|
||||
# 2. down
|
||||
down_block_res_samples = ()
|
||||
for downsample_block in self.down_blocks:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 3. mid
|
||||
if self.mid_block:
|
||||
sample = self.mid_block(sample, timestep_embed)
|
||||
|
||||
# 4. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
res_samples = down_block_res_samples[-1:]
|
||||
down_block_res_samples = down_block_res_samples[:-1]
|
||||
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
|
||||
|
||||
# 5. post-process
|
||||
if self.out_block:
|
||||
sample = self.out_block(sample, timestep_embed)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet1DOutput(sample=sample)
|
||||
@@ -0,0 +1,702 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..activations import get_activation
|
||||
from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
|
||||
|
||||
|
||||
class DownResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
conv_shortcut: bool = False,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_downsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_downsample = add_downsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
output_states = ()
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.downsample is not None:
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UpResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_upsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_upsample = add_upsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if res_hidden_states_tuple is not None:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.upsample is not None:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ValueFunctionMidBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
|
||||
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
|
||||
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
|
||||
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
x = self.res1(x, temb)
|
||||
x = self.down1(x)
|
||||
x = self.res2(x, temb)
|
||||
x = self.down2(x)
|
||||
return x
|
||||
|
||||
|
||||
class MidResTemporalBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
num_layers: int = 1,
|
||||
add_downsample: bool = False,
|
||||
add_upsample: bool = False,
|
||||
non_linearity: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.add_downsample = add_downsample
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
if self.upsample and self.downsample:
|
||||
raise ValueError("Block cannot downsample and upsample")
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
if self.downsample:
|
||||
self.downsample = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutConv1DBlock(nn.Module):
|
||||
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
|
||||
super().__init__()
|
||||
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
|
||||
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
|
||||
self.final_conv1d_act = get_activation(act_fn)
|
||||
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.final_conv1d_1(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_gn(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_act(hidden_states)
|
||||
hidden_states = self.final_conv1d_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutValueFunctionBlock(nn.Module):
|
||||
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
|
||||
super().__init__()
|
||||
self.final_block = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
||||
get_activation(act_fn),
|
||||
nn.Linear(fc_dim // 2, 1),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
|
||||
hidden_states = torch.cat((hidden_states, temb), dim=-1)
|
||||
for layer in self.final_block:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
_kernels = {
|
||||
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
||||
"lanczos3": [
|
||||
0.003689131001010537,
|
||||
0.015056144446134567,
|
||||
-0.03399861603975296,
|
||||
-0.066637322306633,
|
||||
0.13550527393817902,
|
||||
0.44638532400131226,
|
||||
0.44638532400131226,
|
||||
0.13550527393817902,
|
||||
-0.066637322306633,
|
||||
-0.03399861603975296,
|
||||
0.015056144446134567,
|
||||
0.003689131001010537,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel])
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv1d(hidden_states, weight, stride=2)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
class SelfAttention1d(nn.Module):
|
||||
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
|
||||
super().__init__()
|
||||
self.channels = in_channels
|
||||
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
|
||||
self.num_heads = n_head
|
||||
|
||||
self.query = nn.Linear(self.channels, self.channels)
|
||||
self.key = nn.Linear(self.channels, self.channels)
|
||||
self.value = nn.Linear(self.channels, self.channels)
|
||||
|
||||
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
|
||||
|
||||
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
||||
|
||||
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
||||
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
||||
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
||||
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
||||
return new_projection
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
batch, channel_dim, seq = hidden_states.shape
|
||||
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
query_states = self.transpose_for_scores(query_proj)
|
||||
key_states = self.transpose_for_scores(key_proj)
|
||||
value_states = self.transpose_for_scores(value_proj)
|
||||
|
||||
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
|
||||
|
||||
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
|
||||
attention_probs = torch.softmax(attention_scores, dim=-1)
|
||||
|
||||
# compute attention output
|
||||
hidden_states = torch.matmul(attention_probs, value_states)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResConvBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
|
||||
super().__init__()
|
||||
self.is_last = is_last
|
||||
self.has_conv_skip = in_channels != out_channels
|
||||
|
||||
if self.has_conv_skip:
|
||||
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
|
||||
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
|
||||
self.gelu_1 = nn.GELU()
|
||||
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
|
||||
|
||||
if not self.is_last:
|
||||
self.group_norm_2 = nn.GroupNorm(1, out_channels)
|
||||
self.gelu_2 = nn.GELU()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
|
||||
|
||||
hidden_states = self.conv_1(hidden_states)
|
||||
hidden_states = self.group_norm_1(hidden_states)
|
||||
hidden_states = self.gelu_1(hidden_states)
|
||||
hidden_states = self.conv_2(hidden_states)
|
||||
|
||||
if not self.is_last:
|
||||
hidden_states = self.group_norm_2(hidden_states)
|
||||
hidden_states = self.gelu_2(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
return output
|
||||
|
||||
|
||||
class UNetMidBlock1D(nn.Module):
|
||||
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
|
||||
# there is always at least one resnet
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnDownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = torch.cat([hidden_states, temb], dim=1)
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class AttnUpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
|
||||
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
|
||||
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
|
||||
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
|
||||
|
||||
|
||||
def get_down_block(
|
||||
down_block_type: str,
|
||||
num_layers: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
add_downsample: bool,
|
||||
) -> DownBlockType:
|
||||
if down_block_type == "DownResnetBlock1D":
|
||||
return DownResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif down_block_type == "DownBlock1D":
|
||||
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "AttnDownBlock1D":
|
||||
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "DownBlock1DNoSkip":
|
||||
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(
|
||||
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
|
||||
) -> UpBlockType:
|
||||
if up_block_type == "UpResnetBlock1D":
|
||||
return UpResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
elif up_block_type == "UpBlock1D":
|
||||
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "AttnUpBlock1D":
|
||||
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "UpBlock1DNoSkip":
|
||||
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_mid_block(
|
||||
mid_block_type: str,
|
||||
num_layers: int,
|
||||
in_channels: int,
|
||||
mid_channels: int,
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
add_downsample: bool,
|
||||
) -> MidBlockType:
|
||||
if mid_block_type == "MidResTemporalBlock1D":
|
||||
return MidResTemporalBlock1D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif mid_block_type == "ValueFunctionMidBlock1D":
|
||||
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
|
||||
elif mid_block_type == "UNetMidBlock1D":
|
||||
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{mid_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_out_block(
|
||||
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
|
||||
) -> Optional[OutBlockType]:
|
||||
if out_block_type == "OutConv1DBlock":
|
||||
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
||||
elif out_block_type == "ValueFunction":
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
|
||||
return None
|
||||
@@ -0,0 +1,346 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet2DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
||||
1)`.
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
||||
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
|
||||
Tuple of downsample block types.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
||||
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
|
||||
Tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
|
||||
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
|
||||
downsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The downsample type for downsampling layers. Choose between "conv" and "resnet"
|
||||
upsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The upsample type for upsampling layers. Choose between "conv" and "resnet"
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
|
||||
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
|
||||
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
|
||||
given number of groups. If left as `None`, the group norm layer will only be created if
|
||||
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
|
||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to `None`):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
|
||||
conditioning with `class_embed_type` equal to `None`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
center_input_sample: bool = False,
|
||||
time_embedding_type: str = "positional",
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
||||
layers_per_block: int = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
downsample_padding: int = 1,
|
||||
downsample_type: str = "conv",
|
||||
upsample_type: str = "conv",
|
||||
dropout: float = 0.0,
|
||||
act_fn: str = "silu",
|
||||
attention_head_dim: Optional[int] = 8,
|
||||
norm_num_groups: int = 32,
|
||||
attn_norm_num_groups: Optional[int] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
add_attention: bool = True,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
num_train_timesteps: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
elif time_embedding_type == "learned":
|
||||
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
downsample_type=downsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_groups=attn_norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
upsample_type=upsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when doing class conditioning")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
elif self.class_embedding is None and class_labels is not None:
|
||||
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "skip_conv"):
|
||||
sample, res_samples, skip_sample = downsample_block(
|
||||
hidden_states=sample, temb=emb, skip_sample=skip_sample
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. up
|
||||
skip_sample = None
|
||||
for upsample_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
if hasattr(upsample_block, "skip_conv"):
|
||||
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
|
||||
else:
|
||||
sample = upsample_block(sample, res_samples, emb)
|
||||
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if skip_sample is not None:
|
||||
sample += skip_sample
|
||||
|
||||
if self.config.time_embedding_type == "fourier":
|
||||
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
|
||||
sample = sample / timesteps
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet2DOutput(sample=sample)
|
||||
File diff suppressed because it is too large
Load Diff
+2
-2
@@ -15,8 +15,8 @@
|
||||
import flax.linen as nn
|
||||
import jax.numpy as jnp
|
||||
|
||||
from .attention_flax import FlaxTransformer2DModel
|
||||
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
||||
from ..attention_flax import FlaxTransformer2DModel
|
||||
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
||||
|
||||
|
||||
class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
File diff suppressed because it is too large
Load Diff
+7
-7
@@ -19,10 +19,10 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
|
||||
from ..configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from .modeling_flax_utils import FlaxModelMixin
|
||||
from ...configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from ..modeling_flax_utils import FlaxModelMixin
|
||||
from .unet_2d_blocks_flax import (
|
||||
FlaxCrossAttnDownBlock2D,
|
||||
FlaxCrossAttnUpBlock2D,
|
||||
@@ -342,14 +342,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||
A tensor that if specified is added to the residual of the middle unet block.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
plain tuple.
|
||||
train (`bool`, *optional*, defaults to `False`):
|
||||
Use deterministic functions and disable dropout when not training.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# 1. time
|
||||
+7
-7
@@ -17,19 +17,19 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..utils import is_torch_version
|
||||
from ..utils.torch_utils import apply_freeu
|
||||
from .attention import Attention
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .resnet import (
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention import Attention
|
||||
from ..resnet import (
|
||||
Downsample2D,
|
||||
ResnetBlock2D,
|
||||
SpatioTemporalResBlock,
|
||||
TemporalConvLayer,
|
||||
Upsample2D,
|
||||
)
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import (
|
||||
from ..transformers.dual_transformer_2d import DualTransformer2DModel
|
||||
from ..transformers.transformer_2d import Transformer2DModel
|
||||
from ..transformers.transformer_temporal import (
|
||||
TransformerSpatioTemporalModel,
|
||||
TransformerTemporalModel,
|
||||
)
|
||||
+15
-15
@@ -20,20 +20,20 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput, deprecate, logging
|
||||
from .activations import get_activation
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput, deprecate, logging
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_temporal import TransformerTemporalModel
|
||||
from .unet_3d_blocks import (
|
||||
CrossAttnDownBlock3D,
|
||||
CrossAttnUpBlock3D,
|
||||
@@ -284,7 +284,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -308,7 +308,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -374,7 +374,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -449,7 +449,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, None, 0)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -469,7 +469,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
@@ -494,7 +494,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
@@ -503,7 +503,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
|
||||
def unload_lora(self):
|
||||
"""Unloads LoRA weights."""
|
||||
deprecate(
|
||||
+5
-5
@@ -19,11 +19,11 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import Attention, AttentionProcessor, AttnProcessor
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
+14
-14
@@ -17,19 +17,19 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import logging
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import logging
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_temporal import TransformerTemporalModel
|
||||
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_blocks import (
|
||||
@@ -524,7 +524,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -548,7 +548,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -583,7 +583,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
@@ -613,7 +613,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||
def disable_forward_chunking(self) -> None:
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
@@ -625,7 +625,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, None, 0)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self) -> None:
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -645,7 +645,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
@@ -670,7 +670,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
def disable_freeu(self) -> None:
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
+7
-7
@@ -4,12 +4,12 @@ from typing import Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
@@ -20,20 +20,20 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import PeftAdapterMixin
|
||||
from .attention import BasicTransformerBlock, SkipFFTransformerBlock
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from .modeling_utils import ModelMixin
|
||||
from .normalization import GlobalResponseNorm, RMSNorm
|
||||
from .resnet import Downsample2D, Upsample2D
|
||||
from ..embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import GlobalResponseNorm, RMSNorm
|
||||
from ..resnet import Downsample2D, Upsample2D
|
||||
|
||||
|
||||
class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
@@ -213,7 +213,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
return logits
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -237,7 +237,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -272,7 +272,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -109,7 +109,10 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
|
||||
_import_structure["animatediff"] = ["AnimateDiffPipeline"]
|
||||
_import_structure["animatediff"] = [
|
||||
"AnimateDiffPipeline",
|
||||
"AnimateDiffVideoToVideoPipeline",
|
||||
]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
"AudioLDM2Pipeline",
|
||||
@@ -341,7 +344,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ..utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
|
||||
from .animatediff import AnimateDiffPipeline
|
||||
from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import (
|
||||
AudioLDM2Pipeline,
|
||||
|
||||
@@ -11,7 +11,7 @@ from ...utils import (
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
_import_structure = {"pipeline_output": ["AnimateDiffPipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -21,7 +21,8 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline", "AnimateDiffPipelineOutput"]
|
||||
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
|
||||
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -31,7 +32,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
|
||||
else:
|
||||
from .pipeline_animatediff import AnimateDiffPipeline, AnimateDiffPipelineOutput
|
||||
from .pipeline_animatediff import AnimateDiffPipeline
|
||||
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -26,7 +25,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unet_motion_model import MotionAdapter
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
@@ -37,7 +36,6 @@ from ...schedulers import (
|
||||
)
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -46,6 +44,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -67,10 +66,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
@@ -79,6 +75,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -147,11 +152,6 @@ def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> tor
|
||||
return x_mixed
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnimateDiffPipelineOutput(BaseOutput):
|
||||
frames: Union[torch.Tensor, np.ndarray]
|
||||
|
||||
|
||||
class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
@@ -805,11 +805,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
||||
return AnimateDiffPipelineOutput(frames=latents)
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
@@ -0,0 +1,969 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import imageio
|
||||
>>> import requests
|
||||
>>> import torch
|
||||
>>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
|
||||
>>> from diffusers.utils import export_to_gif
|
||||
>>> from io import BytesIO
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
|
||||
>>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter).to("cuda")
|
||||
>>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
|
||||
|
||||
>>> def load_video(file_path: str):
|
||||
... images = []
|
||||
...
|
||||
... if file_path.startswith(('http://', 'https://')):
|
||||
... # If the file_path is a URL
|
||||
... response = requests.get(file_path)
|
||||
... response.raise_for_status()
|
||||
... content = BytesIO(response.content)
|
||||
... vid = imageio.get_reader(content)
|
||||
... else:
|
||||
... # Assuming it's a local file path
|
||||
... vid = imageio.get_reader(file_path)
|
||||
...
|
||||
... for frame in vid:
|
||||
... pil_image = Image.fromarray(frame)
|
||||
... images.append(pil_image)
|
||||
...
|
||||
... return images
|
||||
|
||||
>>> video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
|
||||
>>> output = pipe(video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5)
|
||||
>>> frames = output.frames[0]
|
||||
>>> export_to_gif(frames, "animation.gif")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
||||
batch_output = processor.postprocess(batch_vid, output_type)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
||||
must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for video-to-video generation.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
A [`~transformers.CLIPTokenizer`] to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
|
||||
motion_adapter ([`MotionAdapter`]):
|
||||
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
||||
_optional_components = ["feature_extractor", "image_encoder"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
motion_adapter: MotionAdapter,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
):
|
||||
super().__init__()
|
||||
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
motion_adapter=motion_adapter,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
image = self.vae.decode(latents).sample
|
||||
video = (
|
||||
image[None, :]
|
||||
.reshape(
|
||||
(
|
||||
batch_size,
|
||||
num_frames,
|
||||
-1,
|
||||
)
|
||||
+ image.shape[2:]
|
||||
)
|
||||
.permute(0, 2, 1, 3, 4)
|
||||
)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
strength,
|
||||
height,
|
||||
width,
|
||||
video=None,
|
||||
latents=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if video is not None and latents is not None:
|
||||
raise ValueError("Only one of `video` or `latents` should be provided")
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video,
|
||||
height,
|
||||
width,
|
||||
num_channels_latents,
|
||||
batch_size,
|
||||
timestep,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# video must be a list of list of images
|
||||
# the outer list denotes having multiple videos as input, whereas inner list means the frames of the video
|
||||
# as a list of images
|
||||
if not isinstance(video[0], list):
|
||||
video = [video]
|
||||
if latents is None:
|
||||
video = torch.cat(
|
||||
[self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0
|
||||
)
|
||||
video = video.to(device=device, dtype=dtype)
|
||||
num_frames = video.shape[1]
|
||||
else:
|
||||
num_frames = latents.shape[2]
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
video = video.float()
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
|
||||
]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
|
||||
# restore vae to original dtype
|
||||
if self.vae.config.force_upcast:
|
||||
self.vae.to(dtype)
|
||||
|
||||
init_latents = init_latents.to(dtype)
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
error_message = (
|
||||
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
||||
" images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
|
||||
)
|
||||
raise ValueError(error_message)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
if shape != latents.shape:
|
||||
# [B, C, F, H, W]
|
||||
raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
video: List[List[PipelineImageInput]] = None,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
strength: float = 0.8,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
video (`List[PipelineImageInput]`):
|
||||
The input video to condition the generation on. Must be a list of images/frames of the video.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated video.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated video.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
|
||||
expense of slower inference.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Higher strength leads to more differences between original video and generated video.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
|
||||
`(batch_size, num_channel, num_frames, height, width)`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*):
|
||||
Optional image input to work with IP Adapters.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
|
||||
`np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`AnimateDiffPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
||||
"""
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
num_videos_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
strength=strength,
|
||||
height=height,
|
||||
width=width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
video=video,
|
||||
latents=latents,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_videos_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
video=video,
|
||||
height=height,
|
||||
width=width,
|
||||
num_channels_latents=num_channels_latents,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
timestep=latent_timestep,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
return AnimateDiffPipelineOutput(frames=latents)
|
||||
|
||||
# 9. Post-processing
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return AnimateDiffPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnimateDiffPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for AnimateDiff pipelines.
|
||||
|
||||
Args:
|
||||
frames (`List[List[PIL.Image.Image]]` or `torch.Tensor` or `np.ndarray`):
|
||||
List of PIL Images of length `batch_size` or torch.Tensor or np.ndarray of shape
|
||||
`(batch_size, num_frames, height, width, num_channels)`.
|
||||
"""
|
||||
|
||||
frames: Union[List[List[PIL.Image.Image]], torch.Tensor, np.ndarray]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user