Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 97c611402f | |||
| 7c1c705f60 | |||
| 9e72016468 | |||
| 3e9716f22b | |||
| 87bfbc320d | |||
| a517f665a4 | |||
| 16748d1eba | |||
| c9081a8abd | |||
| 0eb68d9ddb | |||
| 9941b3f124 | |||
| 16b9f98b48 | |||
| fee93c81eb | |||
| 5308cce994 | |||
| 318556b20e | |||
| 6620eda357 | |||
| 1f0705adcf | |||
| 5e96333cb2 | |||
| da95a28ff6 | |||
| d66d554dc2 | |||
| c7df846dec | |||
| 8e7bbfbe5a | |||
| e2773c6255 | |||
| ac61eefc9f | |||
| f95615b823 | |||
| a9288b49c9 | |||
| c54419658b | |||
| 6382663dc8 | |||
| 58b8dce129 | |||
| a65ca8a059 | |||
| 5ca062e011 | |||
| 619e3ab6f6 | |||
| 9e2804f720 |
@@ -77,7 +77,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
|
||||
|
||||
## Quickstart
|
||||
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 16000+ checkpoints):
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 19000+ checkpoints):
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -219,7 +219,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
|
||||
- https://github.com/deep-floyd/IF
|
||||
- https://github.com/bentoml/BentoML
|
||||
- https://github.com/bmaltais/kohya_ss
|
||||
- +7000 other amazing GitHub repositories 💪
|
||||
- +8000 other amazing GitHub repositories 💪
|
||||
|
||||
Thank you for using us ❤️.
|
||||
|
||||
|
||||
@@ -228,6 +228,8 @@
|
||||
title: UNet3DConditionModel
|
||||
- local: api/models/unet-motion
|
||||
title: UNetMotionModel
|
||||
- local: api/models/uvit2d
|
||||
title: UViT2DModel
|
||||
- local: api/models/vq
|
||||
title: VQModel
|
||||
- local: api/models/autoencoderkl
|
||||
|
||||
@@ -30,8 +30,8 @@ To learn more about how to load single file weights, see the [Load different Sta
|
||||
|
||||
## FromOriginalVAEMixin
|
||||
|
||||
[[autodoc]] loaders.single_file.FromOriginalVAEMixin
|
||||
[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin
|
||||
|
||||
## FromOriginalControlnetMixin
|
||||
|
||||
[[autodoc]] loaders.single_file.FromOriginalControlnetMixin
|
||||
[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin
|
||||
@@ -33,6 +33,9 @@ model = AutoencoderKL.from_single_file(url)
|
||||
## AutoencoderKL
|
||||
|
||||
[[autodoc]] AutoencoderKL
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNetMotionModel
|
||||
|
||||
## UNet3DConditionOutput
|
||||
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
|
||||
[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet1DModel
|
||||
|
||||
## UNet1DOutput
|
||||
[[autodoc]] models.unet_1d.UNet1DOutput
|
||||
[[autodoc]] models.unets.unet_1d.UNet1DOutput
|
||||
|
||||
@@ -22,10 +22,10 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet2DConditionModel
|
||||
|
||||
## UNet2DConditionOutput
|
||||
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
|
||||
[[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput
|
||||
|
||||
## FlaxUNet2DConditionModel
|
||||
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel
|
||||
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionModel
|
||||
|
||||
## FlaxUNet2DConditionOutput
|
||||
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
|
||||
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet2DModel
|
||||
|
||||
## UNet2DOutput
|
||||
[[autodoc]] models.unet_2d.UNet2DOutput
|
||||
[[autodoc]] models.unets.unet_2d.UNet2DOutput
|
||||
|
||||
@@ -22,4 +22,4 @@ The abstract from the paper is:
|
||||
[[autodoc]] UNet3DConditionModel
|
||||
|
||||
## UNet3DConditionOutput
|
||||
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
|
||||
[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# UVit2DModel
|
||||
|
||||
The [U-ViT](https://hf.co/papers/2301.11093) model is a vision transformer (ViT) based UNet. This model incorporates elements from ViT (considers all inputs such as time, conditions and noisy image patches as tokens) and a UNet (long skip connections between the shallow and deep layers). The skip connection is important for predicting pixel-level features. An additional 3x3 convolutional block is applied prior to the final output to improve image quality.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Currently, applying diffusion models in pixel space of high resolution images is difficult. Instead, existing approaches focus on diffusion in lower dimensional spaces (latent diffusion), or have multiple super-resolution levels of generation referred to as cascades. The downside is that these approaches add additional complexity to the diffusion framework. This paper aims to improve denoising diffusion for high resolution images while keeping the model as simple as possible. The paper is centered around the research question: How can one train a standard denoising diffusion models on high resolution images, and still obtain performance comparable to these alternate approaches? The four main findings are: 1) the noise schedule should be adjusted for high resolution images, 2) It is sufficient to scale only a particular part of the architecture, 3) dropout should be added at specific locations in the architecture, and 4) downsampling is an effective strategy to avoid high resolution feature maps. Combining these simple yet effective techniques, we achieve state-of-the-art on image generation among diffusion models without sampling modifiers on ImageNet.*
|
||||
|
||||
## UVit2DModel
|
||||
|
||||
[[autodoc]] UVit2DModel
|
||||
|
||||
## UVit2DConvEmbed
|
||||
|
||||
[[autodoc]] models.unets.uvit_2d.UVit2DConvEmbed
|
||||
|
||||
## UVitBlock
|
||||
|
||||
[[autodoc]] models.unets.uvit_2d.UVitBlock
|
||||
|
||||
## ConvNextBlock
|
||||
|
||||
[[autodoc]] models.unets.uvit_2d.ConvNextBlock
|
||||
|
||||
## ConvMlmLayer
|
||||
|
||||
[[autodoc]] models.unets.uvit_2d.ConvMlmLayer
|
||||
@@ -25,6 +25,7 @@ The abstract of the paper is the following:
|
||||
| Pipeline | Tasks | Demo
|
||||
|---|---|:---:|
|
||||
| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* |
|
||||
| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |
|
||||
|
||||
## Available checkpoints
|
||||
|
||||
@@ -32,6 +33,8 @@ Motion Adapter checkpoints can be found under [guoyww](https://huggingface.co/gu
|
||||
|
||||
## Usage example
|
||||
|
||||
### AnimateDiffPipeline
|
||||
|
||||
AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in Stable Diffusion UNet.
|
||||
|
||||
The following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
|
||||
@@ -98,6 +101,114 @@ AnimateDiff tends to work better with finetuned Stable Diffusion models. If you
|
||||
|
||||
</Tip>
|
||||
|
||||
### AnimateDiffVideoToVideoPipeline
|
||||
|
||||
AnimateDiff can also be used to generate visually similar videos or enable style/character/background or other edits starting from an initial video, allowing you to seamlessly explore creative possibilities.
|
||||
|
||||
```python
|
||||
import imageio
|
||||
import requests
|
||||
import torch
|
||||
from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
|
||||
from diffusers.utils import export_to_gif
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
# Load the motion adapter
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
|
||||
# load SD 1.5 based finetuned model
|
||||
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
|
||||
pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to("cuda")
|
||||
scheduler = DDIMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
clip_sample=False,
|
||||
timestep_spacing="linspace",
|
||||
beta_schedule="linear",
|
||||
steps_offset=1,
|
||||
)
|
||||
pipe.scheduler = scheduler
|
||||
|
||||
# enable memory savings
|
||||
pipe.enable_vae_slicing()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# helper function to load videos
|
||||
def load_video(file_path: str):
|
||||
images = []
|
||||
|
||||
if file_path.startswith(('http://', 'https://')):
|
||||
# If the file_path is a URL
|
||||
response = requests.get(file_path)
|
||||
response.raise_for_status()
|
||||
content = BytesIO(response.content)
|
||||
vid = imageio.get_reader(content)
|
||||
else:
|
||||
# Assuming it's a local file path
|
||||
vid = imageio.get_reader(file_path)
|
||||
|
||||
for frame in vid:
|
||||
pil_image = Image.fromarray(frame)
|
||||
images.append(pil_image)
|
||||
|
||||
return images
|
||||
|
||||
video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
|
||||
|
||||
output = pipe(
|
||||
video = video,
|
||||
prompt="panda playing a guitar, on a boat, in the ocean, high quality",
|
||||
negative_prompt="bad quality, worse quality",
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=25,
|
||||
strength=0.5,
|
||||
generator=torch.Generator("cpu").manual_seed(42),
|
||||
)
|
||||
frames = output.frames[0]
|
||||
export_to_gif(frames, "animation.gif")
|
||||
```
|
||||
|
||||
Here are some sample outputs:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th align=center>Source Video</th>
|
||||
<th align=center>Output Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=center>
|
||||
raccoon playing a guitar
|
||||
<br />
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
|
||||
alt="racoon playing a guitar"
|
||||
style="width: 300px;" />
|
||||
</td>
|
||||
<td align=center>
|
||||
panda playing a guitar
|
||||
<br/>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-output-1.gif"
|
||||
alt="panda playing a guitar"
|
||||
style="width: 300px;" />
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align=center>
|
||||
closeup of margot robbie, fireworks in the background, high quality
|
||||
<br />
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-2.gif"
|
||||
alt="closeup of margot robbie, fireworks in the background, high quality"
|
||||
style="width: 300px;" />
|
||||
</td>
|
||||
<td align=center>
|
||||
closeup of tony stark, robert downey jr, fireworks
|
||||
<br/>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-output-2.gif"
|
||||
alt="closeup of tony stark, robert downey jr, fireworks"
|
||||
style="width: 300px;" />
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Using Motion LoRAs
|
||||
|
||||
Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. These LoRAs are responsible for adding specific types of motion to the animations.
|
||||
@@ -300,16 +411,14 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
## AnimateDiffPipeline
|
||||
|
||||
[[autodoc]] AnimateDiffPipeline
|
||||
- all
|
||||
- __call__
|
||||
- enable_freeu
|
||||
- disable_freeu
|
||||
- enable_free_init
|
||||
- disable_free_init
|
||||
- enable_vae_slicing
|
||||
- disable_vae_slicing
|
||||
- enable_vae_tiling
|
||||
- disable_vae_tiling
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AnimateDiffVideoToVideoPipeline
|
||||
|
||||
[[autodoc]] AnimateDiffVideoToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AnimateDiffPipelineOutput
|
||||
|
||||
|
||||
@@ -11,4 +11,6 @@
|
||||
- sections:
|
||||
- local: tutorials/tutorial_overview
|
||||
title: 概要
|
||||
- local: tutorials/autopipeline
|
||||
title: AutoPipeline
|
||||
title: チュートリアル
|
||||
@@ -0,0 +1,168 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AutoPipeline
|
||||
|
||||
Diffusersは様々なタスクをこなすことができ、テキストから画像、画像から画像、画像の修復など、複数のタスクに対して同じように事前学習された重みを再利用することができます。しかし、ライブラリや拡散モデルに慣れていない場合、どのタスクにどのパイプラインを使えばいいのかがわかりにくいかもしれません。例えば、 [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) チェックポイントをテキストから画像に変換するために使用している場合、それぞれ[`StableDiffusionImg2ImgPipeline`]クラスと[`StableDiffusionInpaintPipeline`]クラスでチェックポイントをロードすることで、画像から画像や画像の修復にも使えることを知らない可能性もあります。
|
||||
|
||||
`AutoPipeline` クラスは、🤗 Diffusers の様々なパイプラインをよりシンプルするために設計されています。この汎用的でタスク重視のパイプラインによってタスクそのものに集中することができます。`AutoPipeline` は、使用するべき正しいパイプラインクラスを自動的に検出するため、特定のパイプラインクラス名を知らなくても、タスクのチェックポイントを簡単にロードできます。
|
||||
|
||||
<Tip>
|
||||
|
||||
どのタスクがサポートされているかは、[AutoPipeline](../api/pipelines/auto_pipeline) のリファレンスをご覧ください。現在、text-to-image、image-to-image、inpaintingをサポートしています。
|
||||
|
||||
</Tip>
|
||||
|
||||
このチュートリアルでは、`AutoPipeline` を使用して、事前に学習された重みが与えられたときに、特定のタスクを読み込むためのパイプラインクラスを自動的に推測する方法を示します。
|
||||
|
||||
## タスクに合わせてAutoPipeline を選択する
|
||||
まずはチェックポイントを選ぶことから始めましょう。例えば、 [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) チェックポイントでテキストから画像への変換したいなら、[`AutoPipelineForText2Image`]を使います:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
prompt = "peasant and dragon combat, wood cutting style, viking era, bevel with rune"
|
||||
|
||||
image = pipeline(prompt, num_inference_steps=25).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png" alt="generated image of peasant fighting dragon in wood cutting style"/>
|
||||
</div>
|
||||
|
||||
[`AutoPipelineForText2Image`] を具体的に見ていきましょう:
|
||||
|
||||
1. [`model_index.json`](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json) ファイルから `"stable-diffusion"` クラスを自動的に検出します。
|
||||
2. `"stable-diffusion"` のクラス名に基づいて、テキストから画像へ変換する [`StableDiffusionPipeline`] を読み込みます。
|
||||
|
||||
同様に、画像から画像へ変換する場合、[`AutoPipelineForImage2Image`] は `model_index.json` ファイルから `"stable-diffusion"` チェックポイントを検出し、対応する [`StableDiffusionImg2ImgPipeline`] を読み込みます。また、入力画像にノイズの量やバリエーションの追加を決めるための強さなど、パイプラインクラスに固有の追加引数を渡すこともできます:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
import torch
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
prompt = "a portrait of a dog wearing a pearl earring"
|
||||
|
||||
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0f/1665_Girl_with_a_Pearl_Earring.jpg/800px-1665_Girl_with_a_Pearl_Earring.jpg"
|
||||
|
||||
response = requests.get(url)
|
||||
image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
image.thumbnail((768, 768))
|
||||
|
||||
image = pipeline(prompt, image, num_inference_steps=200, strength=0.75, guidance_scale=10.5).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png" alt="generated image of a vermeer portrait of a dog wearing a pearl earring"/>
|
||||
</div>
|
||||
|
||||
また、画像の修復を行いたい場合は、 [`AutoPipelineForInpainting`] が、同様にベースとなる[`StableDiffusionInpaintPipeline`]クラスを読み込みます:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = load_image(img_url).convert("RGB")
|
||||
mask_image = load_image(mask_url).convert("RGB")
|
||||
|
||||
prompt = "A majestic tiger sitting on a bench"
|
||||
image = pipeline(prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png" alt="generated image of a tiger sitting on a bench"/>
|
||||
</div>
|
||||
|
||||
サポートされていないチェックポイントを読み込もうとすると、エラーになります:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"openai/shap-e-img2img", torch_dtype=torch.float16, use_safetensors=True
|
||||
)
|
||||
"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
|
||||
```
|
||||
|
||||
## 複数のパイプラインを使用する
|
||||
|
||||
いくつかのワークフローや多くのパイプラインを読み込む場合、不要なメモリを使ってしまう再読み込みをするよりも、チェックポイントから同じコンポーネントを再利用する方がメモリ効率が良いです。たとえば、テキストから画像への変換にチェックポイントを使い、画像から画像への変換にまたチェックポイントを使いたい場合、[from_pipe()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe) メソッドを使用します。このメソッドは、以前読み込まれたパイプラインのコンポーネントを使うことで追加のメモリを消費することなく、新しいパイプラインを作成します。
|
||||
|
||||
[from_pipe()](https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe) メソッドは、元のパイプラインクラスを検出し、実行したいタスクに対応する新しいパイプラインクラスにマッピングします。例えば、テキストから画像への`"stable-diffusion"` クラスのパイプラインを読み込む場合:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
|
||||
import torch
|
||||
|
||||
pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
)
|
||||
print(type(pipeline_text2img))
|
||||
"<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'>"
|
||||
```
|
||||
|
||||
そして、[from_pipe()] (https://huggingface.co/docs/diffusers/v0.25.1/en/api/pipelines/auto_pipeline#diffusers.AutoPipelineForImage2Image.from_pipe)は、もとの`"stable-diffusion"` パイプラインのクラスである [`StableDiffusionImg2ImgPipeline`] にマップします:
|
||||
|
||||
```py
|
||||
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)
|
||||
print(type(pipeline_img2img))
|
||||
"<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'>"
|
||||
```
|
||||
元のパイプラインにオプションとして引数(セーフティチェッカーの無効化など)を渡した場合、この引数も新しいパイプラインに渡されます:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
|
||||
import torch
|
||||
|
||||
pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
requires_safety_checker=False,
|
||||
).to("cuda")
|
||||
|
||||
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)
|
||||
print(pipeline_img2img.config.requires_safety_checker)
|
||||
"False"
|
||||
```
|
||||
|
||||
新しいパイプラインの動作を変更したい場合は、元のパイプラインの引数や設定を上書きすることができます。例えば、セーフティチェッカーをオンに戻し、`strength` 引数を追加します:
|
||||
|
||||
```py
|
||||
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img, requires_safety_checker=True, strength=0.3)
|
||||
print(pipeline_img2img.config.requires_safety_checker)
|
||||
"True"
|
||||
```
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1279,7 +1279,7 @@ def main(args):
|
||||
for name, param in text_encoder_one.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param = param.to(dtype=torch.float32)
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_one.append(param)
|
||||
else:
|
||||
@@ -1288,7 +1288,7 @@ def main(args):
|
||||
for name, param in text_encoder_two.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param = param.to(dtype=torch.float32)
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_two.append(param)
|
||||
else:
|
||||
@@ -1725,19 +1725,19 @@ def main(args):
|
||||
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
|
||||
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
|
||||
num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
|
||||
|
||||
# flag used for textual inversion
|
||||
pivoted = False
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
# if performing any kind of optimization of text_encoder params
|
||||
if args.train_text_encoder or args.train_text_encoder_ti:
|
||||
if epoch == num_train_epochs_text_encoder:
|
||||
print("PIVOT HALFWAY", epoch)
|
||||
# stopping optimization of text_encoder params
|
||||
# re setting the optimizer to optimize only on unet params
|
||||
optimizer.param_groups[1]["lr"] = 0.0
|
||||
optimizer.param_groups[2]["lr"] = 0.0
|
||||
# this flag is used to reset the optimizer to optimize only on unet params
|
||||
pivoted = True
|
||||
|
||||
else:
|
||||
# still optimizng the text encoder
|
||||
# still optimizing the text encoder
|
||||
text_encoder_one.train()
|
||||
text_encoder_two.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
@@ -1747,6 +1747,12 @@ def main(args):
|
||||
|
||||
unet.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if pivoted:
|
||||
# stopping optimization of text_encoder params
|
||||
# re setting the optimizer to optimize only on unet params
|
||||
optimizer.param_groups[1]["lr"] = 0.0
|
||||
optimizer.param_groups[2]["lr"] = 0.0
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
prompts = batch["prompts"]
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
@@ -1885,8 +1891,7 @@ def main(args):
|
||||
|
||||
# every step, we reset the embeddings to the original embeddings.
|
||||
if args.train_text_encoder_ti:
|
||||
for idx, text_encoder in enumerate(text_encoders):
|
||||
embedding_handler.retract_embeddings()
|
||||
embedding_handler.retract_embeddings()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
@@ -27,8 +27,8 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
|
||||
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
|
||||
MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
|
||||
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
|
||||
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
|
||||
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) |
|
||||
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
@@ -44,9 +44,9 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
|
||||
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
|
||||
| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
|
||||
FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
|
||||
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
|
||||
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
|
||||
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
@@ -55,12 +55,16 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
|
||||
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
|
||||
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
|
||||
| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
|
||||
| Instaflow Pipeline | Implementation of [InstaFlow! One-Step Stable Diffusion with Rectified Flow](https://arxiv.org/abs/2309.06380) | [Instaflow Pipeline](#instaflow-pipeline) | - | [Ayush Mangal](https://github.com/ayushtues) |
|
||||
| Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
|
||||
| Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender_A_Video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
|
||||
| StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
|
||||
| InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="filename_in_the_community_folder")
|
||||
```
|
||||
@@ -2990,7 +2994,7 @@ pipe = DiffusionPipeline.from_pretrained(
|
||||
custom_pipeline="pipeline_animatediff_controlnet",
|
||||
).to(device="cuda", dtype=torch.float16)
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
||||
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
|
||||
model_id, subfolder="scheduler", beta_schedule="linear", clip_sample=False, timestep_spacing="linspace", steps_offset=1
|
||||
)
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
@@ -3226,6 +3230,43 @@ output_image.save("./output.png")
|
||||
|
||||
```
|
||||
|
||||
### Instaflow Pipeline
|
||||
InstaFlow is an ultra-fast, one-step image generator that achieves image quality close to Stable Diffusion, significantly reducing the demand of computational resources. This efficiency is made possible through a recent [Rectified Flow](https://github.com/gnobitab/RectifiedFlow) technique, which trains probability flows with straight trajectories, hence inherently requiring only a single step for fast inference.
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("XCLIU/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float16, custom_pipeline="instaflow_one_step")
|
||||
pipe.to("cuda") ### if GPU is not available, comment this line
|
||||
prompt = "A hyper-realistic photo of a cute cat."
|
||||
|
||||
images = pipe(prompt=prompt,
|
||||
num_inference_steps=1,
|
||||
guidance_scale=0.0).images
|
||||
images[0].save("./image.png")
|
||||
```
|
||||

|
||||
|
||||
You can also combine it with LORA out of the box, like https://huggingface.co/artificialguybr/logo-redmond-1-5v-logo-lora-for-liberteredmond-sd-1-5, to unlock cool use cases in single step!
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("XCLIU/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float16, custom_pipeline="instaflow_one_step")
|
||||
pipe.to("cuda") ### if GPU is not available, comment this line
|
||||
pipe.load_lora_weights("artificialguybr/logo-redmond-1-5v-logo-lora-for-liberteredmond-sd-1-5")
|
||||
prompt = "logo, A logo for a fitness app, dynamic running figure, energetic colors (red, orange) ),LogoRedAF ,"
|
||||
images = pipe(prompt=prompt,
|
||||
num_inference_steps=1,
|
||||
guidance_scale=0.0).images
|
||||
images[0].save("./image.png")
|
||||
```
|
||||

|
||||
|
||||
### Null-Text Inversion pipeline
|
||||
|
||||
This pipeline provides null-text inversion for editing real images. It enables null-text optimization, and DDIM reconstruction via w, w/o null-text optimization. No prompt-to-prompt code is implemented as there is a Prompt2PromptPipeline.
|
||||
@@ -3263,8 +3304,10 @@ pipeline = NullTextPipeline.from_pretrained(model_path, scheduler = scheduler, t
|
||||
inverted_latent, uncond = pipeline.invert(input_image, invert_prompt, num_inner_steps=10, early_stop_epsilon= 1e-5, num_inference_steps = steps)
|
||||
pipeline(prompt, uncond, inverted_latent, guidance_scale=7.5, num_inference_steps=steps).images[0].save(input_image+".output.jpg")
|
||||
```
|
||||
|
||||
### Rerender_A_Video
|
||||
|
||||
```
|
||||
This is the Diffusers implementation of zero-shot video-to-video translation pipeline [Rerender_A_Video](https://github.com/williamyang1991/Rerender_A_Video) (without Ebsynth postprocessing). To run the code, please install gmflow. Then modify the path in `examples/community/rerender_a_video.py`:
|
||||
|
||||
```py
|
||||
@@ -3332,7 +3375,6 @@ generator = torch.manual_seed(0)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
output_frames = pipe(
|
||||
"a beautiful woman in CG style, best quality, extremely detailed",
|
||||
|
||||
frames,
|
||||
control_frames,
|
||||
num_inference_steps=20,
|
||||
@@ -3353,7 +3395,7 @@ export_to_video(
|
||||
|
||||
### StyleAligned Pipeline
|
||||
|
||||
This pipeline is the implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133).
|
||||
This pipeline is the implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133). You can find more results [here](https://github.com/huggingface/diffusers/pull/6489#issuecomment-1881209354).
|
||||
|
||||
> Large-scale Text-to-Image (T2I) models have rapidly gained prominence across creative fields, generating visually compelling outputs from textual prompts. However, controlling these models to ensure consistent style remains challenging, with existing methods necessitating fine-tuning and manual intervention to disentangle content and style. In this paper, we introduce StyleAligned, a novel technique designed to establish style alignment among a series of generated images. By employing minimal `attention sharing' during the diffusion process, our method maintains style consistency across images within T2I models. This approach allows for the creation of style-consistent images using a reference style through a straightforward inversion operation. Our method's evaluation across diverse styles and text prompts demonstrates high-quality synthesis and fidelity, underscoring its efficacy in achieving consistent style across various inputs.
|
||||
|
||||
@@ -3409,11 +3451,37 @@ images = pipe(
|
||||
pipe.disable_style_aligned()
|
||||
```
|
||||
|
||||
### AnimateDiff Image-To-Video Pipeline
|
||||
|
||||
This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
|
||||
from diffusers.utils import export_to_gif, load_image
|
||||
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
|
||||
pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
|
||||
pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
|
||||
|
||||
image = load_image("snail.png")
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt="A snail moving on the ground",
|
||||
strength=0.8,
|
||||
latent_interpolation_method="slerp", # can be lerp, slerp, or your own callback
|
||||
)
|
||||
frames = output.frames[0]
|
||||
export_to_gif(frames, "animation.gif")
|
||||
```
|
||||
|
||||
### IP Adapter Face ID
|
||||
|
||||
IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
|
||||
You need to install `insightface` and all its requirements to use this model.
|
||||
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
|
||||
You have to disable PEFT BACKEND in order to load weights.
|
||||
You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).
|
||||
|
||||
```py
|
||||
import diffusers
|
||||
@@ -3467,3 +3535,73 @@ images = pipeline(
|
||||
for i in range(num_images):
|
||||
images[i].save(f"c{i}.png")
|
||||
```
|
||||
|
||||
### InstantID Pipeline
|
||||
|
||||
InstantID is a new state-of-the-art tuning-free method to achieve ID-Preserving generation with only single image, supporting various downstream tasks. For any usgae question, please refer to the [official implementation](https://github.com/InstantID/InstantID).
|
||||
|
||||
```py
|
||||
# !pip install opencv-python transformers accelerate insightface
|
||||
import diffusers
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.models import ControlNetModel
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from insightface.app import FaceAnalysis
|
||||
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
|
||||
|
||||
# prepare 'antelopev2' under ./models
|
||||
# https://github.com/deepinsight/insightface/issues/1896#issuecomment-1023867304
|
||||
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
||||
app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
|
||||
# prepare models under ./checkpoints
|
||||
# https://huggingface.co/InstantX/InstantID
|
||||
from huggingface_hub import hf_hub_download
|
||||
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
|
||||
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
|
||||
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
|
||||
|
||||
face_adapter = f'./checkpoints/ip-adapter.bin'
|
||||
controlnet_path = f'./checkpoints/ControlNetModel'
|
||||
|
||||
# load IdentityNet
|
||||
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
||||
|
||||
base_model = 'wangqixun/YamerMIX_v8'
|
||||
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
||||
base_model,
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipe.cuda()
|
||||
|
||||
# load adapter
|
||||
pipe.load_ip_adapter_instantid(face_adapter)
|
||||
|
||||
# load an image
|
||||
face_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
|
||||
|
||||
# prepare face emb
|
||||
face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
|
||||
face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1] # only use the maximum face
|
||||
face_emb = face_info['embedding']
|
||||
face_kps = draw_kps(face_image, face_info['kps'])
|
||||
|
||||
# prompt
|
||||
prompt = "film noir style, ink sketch|vector, male man, highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic"
|
||||
negative_prompt = "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, vibrant, colorful"
|
||||
|
||||
# generate image
|
||||
pipe.set_ip_adapter_scale(0.8)
|
||||
image = pipe(
|
||||
prompt,
|
||||
image_embeds=face_emb,
|
||||
image=face_kps,
|
||||
controlnet_conditioning_scale=0.8,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
@@ -0,0 +1,707 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
deprecate,
|
||||
logging,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class InstaFlowPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Rectified Flow and Euler discretization.
|
||||
This customized pipeline is based on StableDiffusionPipeline from the official Diffusers library (0.21.4)
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`~transformers.CLIPTextModel`]):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer ([`~transformers.CLIPTokenizer`]):
|
||||
A `CLIPTokenizer` to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
|
||||
about a model's potential harms.
|
||||
feature_extractor ([`~transformers.CLIPImageProcessor`]):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
prompt_embeds_tuple = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# concatenate for backwards comp
|
||||
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
if torch.is_tensor(image):
|
||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
||||
else:
|
||||
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
return image, has_nsfw_concept
|
||||
|
||||
def decode_latents(self, latents):
|
||||
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
||||
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def merge_dW_to_unet(pipe, dW_dict, alpha=1.0):
|
||||
_tmp_sd = pipe.unet.state_dict()
|
||||
for key in dW_dict.keys():
|
||||
_tmp_sd[key] += dW_dict[key] * alpha
|
||||
pipe.unet.load_state_dict(_tmp_sd, strict=False)
|
||||
return pipe
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.7):
|
||||
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps = [(1.0 - i / num_inference_steps) * 1000.0 for i in range(num_inference_steps)]
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
dt = 1.0 / num_inference_steps
|
||||
|
||||
# 7. Denoising loop of Euler discretization from t = 0 to t = 1
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
vec_t = torch.ones((latent_model_input.shape[0],), device=latents.device) * t
|
||||
|
||||
v_pred = self.unet(latent_model_input, vec_t, encoder_hidden_states=prompt_embeds).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
v_pred_neg, v_pred_text = v_pred.chunk(2)
|
||||
v_pred = v_pred_neg + guidance_scale * (v_pred_text - v_pred_neg)
|
||||
|
||||
latents = latents + dt * v_pred
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -26,7 +26,7 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.models.unet_motion_model import MotionAdapter
|
||||
from diffusers.models.unets.unet_motion_model import MotionAdapter
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import (
|
||||
|
||||
@@ -0,0 +1,989 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from types import FunctionType
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.models.unet_motion_model import MotionAdapter
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
|
||||
>>> from diffusers.utils import export_to_gif, load_image
|
||||
|
||||
>>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
|
||||
>>> pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
|
||||
>>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
|
||||
|
||||
>>> image = load_image("snail.png")
|
||||
>>> output = pipe(image=image, prompt="A snail moving on the ground", strength=0.8, latent_interpolation_method="slerp")
|
||||
>>> frames = output.frames[0]
|
||||
>>> export_to_gif(frames, "animation.gif")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def lerp(
|
||||
v0: torch.Tensor,
|
||||
v1: torch.Tensor,
|
||||
t: Union[float, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Linear Interpolation between two tensors.
|
||||
|
||||
Args:
|
||||
v0 (`torch.Tensor`): First tensor.
|
||||
v1 (`torch.Tensor`): Second tensor.
|
||||
t: (`float` or `torch.Tensor`): Interpolation factor.
|
||||
"""
|
||||
t_is_float = False
|
||||
input_device = v0.device
|
||||
v0 = v0.cpu().numpy()
|
||||
v1 = v1.cpu().numpy()
|
||||
|
||||
if isinstance(t, torch.Tensor):
|
||||
t = t.cpu().numpy()
|
||||
else:
|
||||
t_is_float = True
|
||||
t = np.array([t], dtype=v0.dtype)
|
||||
|
||||
t = t[..., None]
|
||||
v0 = v0[None, ...]
|
||||
v1 = v1[None, ...]
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
|
||||
if t_is_float and v0.ndim > 1:
|
||||
assert v2.shape[0] == 1
|
||||
v2 = np.squeeze(v2, axis=0)
|
||||
|
||||
v2 = torch.from_numpy(v2).to(input_device)
|
||||
return v2
|
||||
|
||||
|
||||
def slerp(
|
||||
v0: torch.Tensor,
|
||||
v1: torch.Tensor,
|
||||
t: Union[float, torch.Tensor],
|
||||
DOT_THRESHOLD: float = 0.9995,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Spherical Linear Interpolation between two tensors.
|
||||
|
||||
Args:
|
||||
v0 (`torch.Tensor`): First tensor.
|
||||
v1 (`torch.Tensor`): Second tensor.
|
||||
t: (`float` or `torch.Tensor`): Interpolation factor.
|
||||
DOT_THRESHOLD (`float`):
|
||||
Dot product threshold exceeding which linear interpolation will be used
|
||||
because input tensors are close to parallel.
|
||||
"""
|
||||
t_is_float = False
|
||||
input_device = v0.device
|
||||
v0 = v0.cpu().numpy()
|
||||
v1 = v1.cpu().numpy()
|
||||
|
||||
if isinstance(t, torch.Tensor):
|
||||
t = t.cpu().numpy()
|
||||
else:
|
||||
t_is_float = True
|
||||
t = np.array([t], dtype=v0.dtype)
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
# v0 and v1 are close to parallel, so use linear interpolation instead
|
||||
v2 = lerp(v0, v1, t)
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
s0 = s0[..., None]
|
||||
s1 = s1[..., None]
|
||||
v0 = v0[None, ...]
|
||||
v1 = v1[None, ...]
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if t_is_float and v0.ndim > 1:
|
||||
assert v2.shape[0] == 1
|
||||
v2 = np.squeeze(v2, axis=0)
|
||||
|
||||
v2 = torch.from_numpy(v2).to(input_device)
|
||||
return v2
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
||||
batch_output = processor.postprocess(batch_vid, output_type)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
||||
must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnimateDiffImgToVideoPipelineOutput(BaseOutput):
|
||||
frames: Union[torch.Tensor, np.ndarray]
|
||||
|
||||
|
||||
class AnimateDiffImgToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
A [`~transformers.CLIPTokenizer`] to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
|
||||
motion_adapter ([`MotionAdapter`]):
|
||||
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
||||
_optional_components = ["feature_extractor", "image_encoder"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
motion_adapter: MotionAdapter,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
):
|
||||
super().__init__()
|
||||
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
motion_adapter=motion_adapter,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
image = self.vae.decode(latents).sample
|
||||
video = (
|
||||
image[None, :]
|
||||
.reshape(
|
||||
(
|
||||
batch_size,
|
||||
num_frames,
|
||||
-1,
|
||||
)
|
||||
+ image.shape[2:]
|
||||
)
|
||||
.permute(0, 2, 1, 3, 4)
|
||||
)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
latent_interpolation_method=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if latent_interpolation_method is not None:
|
||||
if latent_interpolation_method not in ["lerp", "slerp"] and not isinstance(
|
||||
latent_interpolation_method, FunctionType
|
||||
):
|
||||
raise ValueError(
|
||||
"`latent_interpolation_method` must be one of `lerp`, `slerp` or a Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]"
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
strength,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
latent_interpolation_method="slerp",
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if image.shape[1] == 4:
|
||||
latents = image
|
||||
else:
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
image = image.float()
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
if self.vae.config.force_upcast:
|
||||
self.vae.to(dtype)
|
||||
|
||||
init_latents = init_latents.to(dtype)
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
if latent_interpolation_method == "lerp":
|
||||
|
||||
def latent_cls(v0, v1, index):
|
||||
return lerp(v0, v1, index / num_frames * (1 - strength))
|
||||
elif latent_interpolation_method == "slerp":
|
||||
|
||||
def latent_cls(v0, v1, index):
|
||||
return slerp(v0, v1, index / num_frames * (1 - strength))
|
||||
else:
|
||||
latent_cls = latent_interpolation_method
|
||||
|
||||
for i in range(num_frames):
|
||||
latents[:, :, i, :, :] = latent_cls(latents[:, :, i, :, :], init_latents, i)
|
||||
else:
|
||||
if shape != latents.shape:
|
||||
# [B, C, F, H, W]
|
||||
raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_frames: int = 16,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
strength: float = 0.8,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
latent_interpolation_method: Union[str, Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]] = "slerp",
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PipelineImageInput`):
|
||||
The input image to condition the generation on.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated video.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated video.
|
||||
num_frames (`int`, *optional*, defaults to 16):
|
||||
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
|
||||
amounts to 2 seconds of video.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
|
||||
expense of slower inference.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Higher strength leads to more differences between original image and generated video.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
|
||||
`(batch_size, num_channel, num_frames, height, width)`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*):
|
||||
Optional image input to work with IP Adapters.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
|
||||
`np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`AnimateDiffImgToVideoPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
latent_interpolation_method (`str` or `Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]]`, *optional*):
|
||||
Must be one of "lerp", "slerp" or a callable that takes in a random noisy latent, image latent and a frame index
|
||||
as input and returns an initial latent for sampling.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`AnimateDiffImgToVideoPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`AnimateDiffImgToVideoPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
num_videos_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
callback_steps=callback_steps,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
latent_interpolation_method=latent_interpolation_method,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_videos_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Preprocess image
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
image=image,
|
||||
strength=strength,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
num_channels_latents=num_channels_latents,
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
latent_interpolation_method=latent_interpolation_method,
|
||||
)
|
||||
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 8. Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
if output_type == "latent":
|
||||
return AnimateDiffImgToVideoPipelineOutput(frames=latents)
|
||||
|
||||
# 10. Post-processing
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# 11. Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return AnimateDiffImgToVideoPipelineOutput(frames=video)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,7 +8,7 @@ import torch
|
||||
from diffusers import StableDiffusionControlNetPipeline
|
||||
from diffusers.models import ControlNetModel
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.utils import logging
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
from diffusers.models.unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
|
||||
@@ -55,6 +55,9 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.26.0.dev0")
|
||||
|
||||
@@ -67,6 +70,57 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
|
||||
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
|
||||
|
||||
|
||||
def log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
global_step,
|
||||
is_final_validation=False,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
val_save_dir = os.path.join(args.output_dir, "validation_images")
|
||||
if not os.path.exists(val_save_dir):
|
||||
os.makedirs(val_save_dir)
|
||||
|
||||
original_image = (
|
||||
lambda image_url_or_path: load_image(image_url_or_path)
|
||||
if urlparse(image_url_or_path).scheme
|
||||
else Image.open(image_url_or_path).convert("RGB")
|
||||
)(args.val_image_url_or_path)
|
||||
|
||||
with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
|
||||
edited_images = []
|
||||
# Run inference
|
||||
for val_img_idx in range(args.num_validation_images):
|
||||
a_val_img = pipeline(
|
||||
args.validation_prompt,
|
||||
image=original_image,
|
||||
num_inference_steps=20,
|
||||
image_guidance_scale=1.5,
|
||||
guidance_scale=7,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
edited_images.append(a_val_img)
|
||||
# Save validation images
|
||||
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "wandb":
|
||||
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
|
||||
for edited_image in edited_images:
|
||||
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
|
||||
logger_name = "test" if is_final_validation else "validation"
|
||||
tracker.log({logger_name: wandb_table})
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(
|
||||
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
||||
):
|
||||
@@ -447,11 +501,6 @@ def main():
|
||||
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
import wandb
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -1111,11 +1160,6 @@ def main():
|
||||
### BEGIN: Perform validation every `validation_epochs` steps
|
||||
if global_step % args.validation_steps == 0:
|
||||
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
|
||||
# create pipeline
|
||||
if args.use_ema:
|
||||
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
|
||||
@@ -1135,44 +1179,16 @@ def main():
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
# Save validation images
|
||||
val_save_dir = os.path.join(args.output_dir, "validation_images")
|
||||
if not os.path.exists(val_save_dir):
|
||||
os.makedirs(val_save_dir)
|
||||
log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
global_step,
|
||||
is_final_validation=False,
|
||||
)
|
||||
|
||||
original_image = (
|
||||
lambda image_url_or_path: load_image(image_url_or_path)
|
||||
if urlparse(image_url_or_path).scheme
|
||||
else Image.open(image_url_or_path).convert("RGB")
|
||||
)(args.val_image_url_or_path)
|
||||
with torch.autocast(
|
||||
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
|
||||
):
|
||||
edited_images = []
|
||||
for val_img_idx in range(args.num_validation_images):
|
||||
a_val_img = pipeline(
|
||||
args.validation_prompt,
|
||||
image=original_image,
|
||||
num_inference_steps=20,
|
||||
image_guidance_scale=1.5,
|
||||
guidance_scale=7,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
edited_images.append(a_val_img)
|
||||
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "wandb":
|
||||
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
|
||||
for edited_image in edited_images:
|
||||
wandb_table.add_data(
|
||||
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
|
||||
)
|
||||
tracker.log({"validation": wandb_table})
|
||||
if args.use_ema:
|
||||
# Switch back to the original UNet parameters.
|
||||
ema_unet.restore(unet.parameters())
|
||||
@@ -1187,7 +1203,6 @@ def main():
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
if args.use_ema:
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
|
||||
@@ -1198,10 +1213,11 @@ def main():
|
||||
tokenizer=tokenizer_1,
|
||||
tokenizer_2=tokenizer_2,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
unet=unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
|
||||
if args.push_to_hub:
|
||||
@@ -1212,30 +1228,15 @@ def main():
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
if args.validation_prompt is not None:
|
||||
edited_images = []
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
with torch.autocast(str(accelerator.device).replace(":0", "")):
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
args.validation_prompt,
|
||||
image=original_image,
|
||||
num_inference_steps=20,
|
||||
image_guidance_scale=1.5,
|
||||
guidance_scale=7,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "wandb":
|
||||
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
|
||||
for edited_image in edited_images:
|
||||
wandb_table.add_data(
|
||||
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
|
||||
)
|
||||
tracker.log({"test": wandb_table})
|
||||
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
|
||||
log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
global_step,
|
||||
is_final_validation=True,
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProc
|
||||
from diffusers.models.autoencoders import AutoencoderKL
|
||||
from diffusers.models.lora import LoRACompatibleConv
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
from diffusers.models.unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
@@ -36,7 +36,7 @@ from diffusers.models.unet_2d_blocks import (
|
||||
UpBlock2D,
|
||||
Upsample2D,
|
||||
)
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
|
||||
|
||||
|
||||
@@ -1041,11 +1041,6 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# manually for max memory savings
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
@@ -740,6 +740,10 @@ def main(args):
|
||||
# Resize.
|
||||
combined_im = train_resize(combined_im)
|
||||
|
||||
# Flipping.
|
||||
if not args.no_flip and random.random() < 0.5:
|
||||
combined_im = train_flip(combined_im)
|
||||
|
||||
# Cropping.
|
||||
if not args.random_crop:
|
||||
y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
|
||||
@@ -749,11 +753,6 @@ def main(args):
|
||||
y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
|
||||
combined_im = crop(combined_im, y1, x1, h, w)
|
||||
|
||||
# Flipping.
|
||||
if random.random() < 0.5:
|
||||
x1 = combined_im.shape[2] - x1
|
||||
combined_im = train_flip(combined_im)
|
||||
|
||||
crop_top_left = (y1, x1)
|
||||
crop_top_lefts.append(crop_top_left)
|
||||
combined_im = normalize(combined_im)
|
||||
|
||||
@@ -183,6 +183,66 @@ The above command will also run inference as fine-tuning progresses and log the
|
||||
|
||||
* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
|
||||
|
||||
### Using DeepSpeed
|
||||
Using DeepSpeed one can reduce the consumption of GPU memory, enabling the training of models on GPUs with smaller memory sizes. DeepSpeed is capable of offloading model parameters to the machine's memory, or it can distribute parameters, gradients, and optimizer states across multiple GPUs. This allows for the training of larger models under the same hardware configuration.
|
||||
|
||||
First, you need to use the `accelerate config` command to choose to use DeepSpeed, or manually use the accelerate config file to set up DeepSpeed.
|
||||
|
||||
Here is an example of a config file for using DeepSpeed. For more detailed explanations of the configuration, you can refer to this [link](https://huggingface.co/docs/accelerate/usage_guides/deepspeed).
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: true
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_clipping: 1.0
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
You need to save the mentioned configuration as an `accelerate_config.yaml` file. Then, you need to input the path of your `accelerate_config.yaml` file into the `ACCELERATE_CONFIG_FILE` parameter. This way you can use DeepSpeed to train your SDXL model in LoRA. Additionally, you can use DeepSpeed to train other SD models in this way.
|
||||
|
||||
```shell
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
|
||||
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
|
||||
export ACCELERATE_CONFIG_FILE="your accelerate_config.yaml"
|
||||
|
||||
accelerate launch --config_file $ACCELERATE_CONFIG_FILE train_text_to_image_lora_sdxl.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--pretrained_vae_model_name_or_path=$VAE_NAME \
|
||||
--dataset_name=$DATASET_NAME --caption_column="text" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--num_train_epochs=2 \
|
||||
--checkpointing_steps=2 \
|
||||
--learning_rate=1e-04 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--mixed_precision="fp16" \
|
||||
--max_train_steps=20 \
|
||||
--validation_epochs=20 \
|
||||
--seed=1234 \
|
||||
--output_dir="sd-pokemon-model-lora-sdxl" \
|
||||
--validation_prompt="cute dragon creature"
|
||||
|
||||
```
|
||||
|
||||
|
||||
### Finetuning the text encoder and UNet
|
||||
|
||||
The script also allows you to finetune the `text_encoder` along with the `unet`.
|
||||
|
||||
@@ -652,13 +652,13 @@ def main(args):
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
@@ -666,7 +666,8 @@ def main(args):
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
if weights:
|
||||
weights.pop()
|
||||
|
||||
StableDiffusionXLPipeline.save_lora_weights(
|
||||
output_dir,
|
||||
|
||||
@@ -10,7 +10,7 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import VQModel
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.models.uvit_2d import UVit2DModel
|
||||
from diffusers.models.unets.uvit_2d import UVit2DModel
|
||||
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
|
||||
from diffusers.schedulers import AmusedScheduler
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from tqdm import tqdm
|
||||
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
|
||||
from diffusers.models.autoencoders.vae import Encoder
|
||||
from diffusers.models.embeddings import TimestepEmbedding
|
||||
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
|
||||
from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
|
||||
|
||||
|
||||
args = ArgumentParser()
|
||||
|
||||
@@ -153,6 +153,7 @@ else:
|
||||
"LCMScheduler",
|
||||
"PNDMScheduler",
|
||||
"RePaintScheduler",
|
||||
"SASolverScheduler",
|
||||
"SchedulerMixin",
|
||||
"ScoreSdeVeScheduler",
|
||||
"UnCLIPScheduler",
|
||||
@@ -207,6 +208,7 @@ else:
|
||||
"AmusedInpaintPipeline",
|
||||
"AmusedPipeline",
|
||||
"AnimateDiffPipeline",
|
||||
"AnimateDiffVideoToVideoPipeline",
|
||||
"AudioLDM2Pipeline",
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
@@ -381,7 +383,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
|
||||
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
|
||||
_import_structure["schedulers"].extend(
|
||||
@@ -530,6 +532,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LCMScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SASolverScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
UnCLIPScheduler,
|
||||
@@ -567,6 +570,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AmusedInpaintPipeline,
|
||||
AmusedPipeline,
|
||||
AnimateDiffPipeline,
|
||||
AnimateDiffVideoToVideoPipeline,
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
@@ -709,7 +713,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .models.controlnet_flax import FlaxControlNetModel
|
||||
from .models.modeling_flax_utils import FlaxModelMixin
|
||||
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .models.vae_flax import FlaxAutoencoderKL
|
||||
from .pipelines import FlaxDiffusionPipeline
|
||||
from .schedulers import (
|
||||
|
||||
@@ -16,7 +16,7 @@ import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from ...models.unet_1d import UNet1DModel
|
||||
from ...models.unets.unet_1d import UNet1DModel
|
||||
from ...pipelines import DiffusionPipeline
|
||||
from ...utils.dummy_pt_objects import DDPMScheduler
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
|
||||
@@ -54,12 +54,13 @@ if is_transformers_available():
|
||||
_import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"]
|
||||
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
|
||||
|
||||
_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
|
||||
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
|
||||
_import_structure["utils"] = ["AttnProcsLayers"]
|
||||
|
||||
if is_transformers_available():
|
||||
_import_structure["single_file"].extend(["FromSingleFileMixin"])
|
||||
_import_structure["single_file"] = ["FromSingleFileMixin"]
|
||||
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
|
||||
@@ -69,7 +70,8 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
|
||||
from .autoencoder import FromOriginalVAEMixin
|
||||
from .controlnet import FromOriginalControlNetMixin
|
||||
from .unet import UNet2DConditionLoadersMixin
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from .single_file_utils import (
|
||||
create_diffusers_vae_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
"""
|
||||
Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
|
||||
a VAE from SDXL or a Stable Diffusion v2 model or higher.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
class_name = cls.__name__
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
|
||||
vae = component["vae"]
|
||||
if torch_dtype is not None:
|
||||
vae = vae.to(torch_dtype)
|
||||
|
||||
return vae
|
||||
@@ -0,0 +1,127 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from .single_file_utils import (
|
||||
create_diffusers_controlnet_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class FromOriginalControlNetMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
class_name = cls.__name__
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
upcast_attention = kwargs.pop("upcast_attention", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
|
||||
component = create_diffusers_controlnet_model_from_ldm(
|
||||
class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
|
||||
)
|
||||
controlnet = component["controlnet"]
|
||||
if torch_dtype is not None:
|
||||
controlnet = controlnet.to(torch_dtype)
|
||||
|
||||
return controlnet
|
||||
@@ -11,39 +11,132 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..utils import deprecate, is_accelerate_available, is_transformers_available, logging
|
||||
from ..utils import is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
create_diffusers_unet_model_from_ldm,
|
||||
create_diffusers_vae_model_from_ldm,
|
||||
create_scheduler_from_ldm,
|
||||
create_text_encoders_and_tokenizers_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
infer_model_type,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
pass
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Pipelines that support the SDXL Refiner checkpoint
|
||||
REFINER_PIPELINES = [
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
def build_sub_model_components(
|
||||
pipeline_components,
|
||||
pipeline_class_name,
|
||||
component_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
local_files_only=False,
|
||||
load_safety_checker=False,
|
||||
model_type=None,
|
||||
image_size=None,
|
||||
**kwargs,
|
||||
):
|
||||
if component_name in pipeline_components:
|
||||
return {}
|
||||
|
||||
if component_name == "unet":
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
unet_components = create_diffusers_unet_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size
|
||||
)
|
||||
return unet_components
|
||||
|
||||
if component_name == "vae":
|
||||
vae_components = create_diffusers_vae_model_from_ldm(
|
||||
pipeline_class_name, original_config, checkpoint, image_size
|
||||
)
|
||||
return vae_components
|
||||
|
||||
if component_name == "scheduler":
|
||||
scheduler_type = kwargs.get("scheduler_type", "ddim")
|
||||
prediction_type = kwargs.get("prediction_type", None)
|
||||
|
||||
scheduler_components = create_scheduler_from_ldm(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
scheduler_type=scheduler_type,
|
||||
prediction_type=prediction_type,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
return scheduler_components
|
||||
|
||||
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
|
||||
text_encoder_components = create_text_encoders_and_tokenizers_from_ldm(
|
||||
original_config,
|
||||
checkpoint,
|
||||
model_type=model_type,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return text_encoder_components
|
||||
|
||||
if component_name == "safety_checker":
|
||||
if load_safety_checker:
|
||||
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
else:
|
||||
safety_checker = None
|
||||
return {"safety_checker": safety_checker}
|
||||
|
||||
if component_name == "feature_extractor":
|
||||
if load_safety_checker:
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
|
||||
)
|
||||
else:
|
||||
feature_extractor = None
|
||||
return {"feature_extractor": feature_extractor}
|
||||
|
||||
return
|
||||
|
||||
|
||||
def set_additional_components(
|
||||
pipeline_class_name,
|
||||
original_config,
|
||||
model_type=None,
|
||||
):
|
||||
components = {}
|
||||
if pipeline_class_name in REFINER_PIPELINES:
|
||||
model_type = infer_model_type(original_config, model_type=model_type)
|
||||
is_refiner = model_type == "SDXL-Refiner"
|
||||
components.update(
|
||||
{
|
||||
"requires_aesthetics_score": is_refiner,
|
||||
"force_zeros_for_empty_prompt": False if is_refiner else True,
|
||||
}
|
||||
)
|
||||
|
||||
return components
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_ckpt(cls, *args, **kwargs):
|
||||
deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
|
||||
deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
|
||||
return cls.from_single_file(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
@@ -58,8 +151,7 @@ class FromSingleFileMixin:
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
@@ -85,42 +177,6 @@ class FromSingleFileMixin:
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
extract_ema (`bool`, *optional*, defaults to `False`):
|
||||
Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
|
||||
higher quality images for inference. Non-EMA weights are usually better for continuing finetuning.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
prediction_type (`str`, *optional*):
|
||||
The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
|
||||
the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
|
||||
num_in_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of input channels. If `None`, it is automatically inferred.
|
||||
scheduler_type (`str`, *optional*, defaults to `"pndm"`):
|
||||
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
||||
"ddim"]`.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
||||
Whether to load the safety checker or not.
|
||||
text_encoder ([`~transformers.CLIPTextModel`], *optional*, defaults to `None`):
|
||||
An instance of `CLIPTextModel` to use, specifically the
|
||||
[clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this
|
||||
parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed.
|
||||
vae (`AutoencoderKL`, *optional*, defaults to `None`):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
|
||||
this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
|
||||
tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
|
||||
An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
|
||||
of `CLIPTokenizer` by itself if needed.
|
||||
original_config_file (`str`):
|
||||
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be
|
||||
automatically inferred by looking for a key that only exists in SD2.0 models.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
@@ -143,484 +199,80 @@ class FromSingleFileMixin:
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config_files = kwargs.pop("config_files", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scheduler_type = kwargs.pop("scheduler_type", "pndm")
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
||||
prediction_type = kwargs.pop("prediction_type", None)
|
||||
text_encoder = kwargs.pop("text_encoder", None)
|
||||
text_encoder_2 = kwargs.pop("text_encoder_2", None)
|
||||
vae = kwargs.pop("vae", None)
|
||||
controlnet = kwargs.pop("controlnet", None)
|
||||
adapter = kwargs.pop("adapter", None)
|
||||
tokenizer = kwargs.pop("tokenizer", None)
|
||||
tokenizer_2 = kwargs.pop("tokenizer_2", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", True)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
class_name = cls.__name__
|
||||
|
||||
pipeline_name = cls.__name__
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# TODO: For now we only support stable diffusion
|
||||
stable_unclip = None
|
||||
model_type = None
|
||||
|
||||
if pipeline_name in [
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
]:
|
||||
from ..models.controlnet import ControlNetModel
|
||||
from ..pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
|
||||
# list/tuple or a single instance of ControlNetModel or MultiControlNetModel
|
||||
if not (
|
||||
isinstance(controlnet, (ControlNetModel, MultiControlNetModel))
|
||||
or isinstance(controlnet, (list, tuple))
|
||||
and isinstance(controlnet[0], ControlNetModel)
|
||||
):
|
||||
raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
|
||||
elif "StableDiffusion" in pipeline_name:
|
||||
# Model type will be inferred from the checkpoint.
|
||||
pass
|
||||
elif pipeline_name == "StableUnCLIPPipeline":
|
||||
model_type = "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "txt2img"
|
||||
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
|
||||
model_type = "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "img2img"
|
||||
elif pipeline_name == "PaintByExamplePipeline":
|
||||
model_type = "PaintByExample"
|
||||
elif pipeline_name == "LDMTextToImagePipeline":
|
||||
model_type = "LDMTextToImage"
|
||||
else:
|
||||
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
||||
|
||||
# remove huggingface url
|
||||
has_valid_url_prefix = False
|
||||
valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
|
||||
for prefix in valid_url_prefixes:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
has_valid_url_prefix = True
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
if not has_valid_url_prefix:
|
||||
raise ValueError(
|
||||
f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
|
||||
)
|
||||
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
pipeline_class=cls,
|
||||
model_type=model_type,
|
||||
stable_unclip=stable_unclip,
|
||||
controlnet=controlnet,
|
||||
adapter=adapter,
|
||||
from_safetensors=from_safetensors,
|
||||
extract_ema=extract_ema,
|
||||
image_size=image_size,
|
||||
scheduler_type=scheduler_type,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
load_safety_checker=load_safety_checker,
|
||||
prediction_type=prediction_type,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
config_files=config_files,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config=None,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
model_type = kwargs.pop("model_type", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
load_safety_checker = (kwargs.pop("load_safety_checker", False)) or (
|
||||
passed_class_obj.get("safety_checker", None) is not None
|
||||
)
|
||||
|
||||
init_kwargs = {}
|
||||
for name in expected_modules:
|
||||
if name in passed_class_obj:
|
||||
init_kwargs[name] = passed_class_obj[name]
|
||||
else:
|
||||
components = build_sub_model_components(
|
||||
init_kwargs,
|
||||
class_name,
|
||||
name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
model_type=model_type,
|
||||
image_size=image_size,
|
||||
load_safety_checker=load_safety_checker,
|
||||
local_files_only=local_files_only,
|
||||
**kwargs,
|
||||
)
|
||||
if not components:
|
||||
continue
|
||||
init_kwargs.update(components)
|
||||
|
||||
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
|
||||
if additional_components:
|
||||
init_kwargs.update(additional_components)
|
||||
|
||||
init_kwargs.update(passed_pipe_kwargs)
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into an [`AutoencoderKL`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
|
||||
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
|
||||
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
|
||||
a VAE from SDXL or a Stable Diffusion v2 model or higher.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
from ..models import AutoencoderKL
|
||||
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_vae_diffusers_config,
|
||||
)
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scaling_factor = kwargs.pop("scaling_factor", None)
|
||||
kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if from_safetensors:
|
||||
from safetensors import safe_open
|
||||
|
||||
checkpoint = {}
|
||||
with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
checkpoint[key] = f.get_tensor(key)
|
||||
else:
|
||||
checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")
|
||||
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = yaml.safe_load(config_file)
|
||||
|
||||
# default to sd-v1-5
|
||||
image_size = image_size or 512
|
||||
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
if scaling_factor is None:
|
||||
if (
|
||||
"model" in original_config
|
||||
and "params" in original_config["model"]
|
||||
and "scale_factor" in original_config["model"]["params"]
|
||||
):
|
||||
vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
|
||||
else:
|
||||
vae_scaling_factor = 0.18215 # default SD scaling factor
|
||||
|
||||
vae_config["scaling_factor"] = vae_scaling_factor
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
|
||||
else:
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
if torch_dtype is not None:
|
||||
vae.to(dtype=torch_dtype)
|
||||
|
||||
return vae
|
||||
|
||||
|
||||
class FromOriginalControlnetMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
|
||||
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
use_linear_projection = kwargs.pop("use_linear_projection", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is False:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
if config_file is None:
|
||||
config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml"
|
||||
config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
image_size = image_size or 512
|
||||
|
||||
controlnet = download_controlnet_from_original_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
original_config_file=config_file,
|
||||
image_size=image_size,
|
||||
extract_ema=extract_ema,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
from_safetensors=from_safetensors,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
controlnet.to(dtype=torch_dtype)
|
||||
|
||||
return controlnet
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -39,19 +39,19 @@ if is_torch_available():
|
||||
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
|
||||
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
||||
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["uvit_2d"] = ["UVit2DModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
_import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
|
||||
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
|
||||
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
|
||||
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
|
||||
_import_structure["vq_model"] = ["VQModel"]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
_import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
|
||||
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
|
||||
|
||||
|
||||
@@ -73,19 +73,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_condition import UNet3DConditionModel
|
||||
from .unet_kandinsky3 import Kandinsky3UNet
|
||||
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
||||
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
||||
from .uvit_2d import UVit2DModel
|
||||
from .unets import (
|
||||
Kandinsky3UNet,
|
||||
MotionAdapter,
|
||||
UNet1DModel,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
UNet3DConditionModel,
|
||||
UNetMotionModel,
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
)
|
||||
from .vq_model import VQModel
|
||||
|
||||
if is_flax_available():
|
||||
from .controlnet_flax import FlaxControlNetModel
|
||||
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
from .unets import FlaxUNet2DConditionModel
|
||||
from .vae_flax import FlaxAutoencoderKL
|
||||
|
||||
else:
|
||||
|
||||
@@ -157,7 +157,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
self.use_slicing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -181,7 +181,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -216,7 +216,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -448,7 +448,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||
@@ -472,7 +472,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
|
||||
@@ -17,13 +17,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalVAEMixin
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
|
||||
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@@ -162,7 +161,7 @@ class TemporalDecoder(nn.Module):
|
||||
return sample
|
||||
|
||||
|
||||
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
@@ -242,7 +241,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -266,7 +265,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -31,7 +31,7 @@ from ..attention_processor import (
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unet_2d import UNet2DModel
|
||||
from ..unets.unet_2d import UNet2DModel
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
self.use_slicing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -246,7 +246,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...utils import BaseOutput, is_torch_version
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import SpatialNorm
|
||||
from ..unet_2d_blocks import (
|
||||
from ..unets.unet_2d_blocks import (
|
||||
AutoencoderTinyBlock,
|
||||
UNetMidBlock2D,
|
||||
get_down_block,
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalControlnetMixin
|
||||
from ..loaders import FromOriginalControlNetMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -30,8 +30,14 @@ from .attention_processor import (
|
||||
)
|
||||
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
get_down_block,
|
||||
)
|
||||
from .unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -102,7 +108,7 @@ class ControlNetConditioningEmbedding(nn.Module):
|
||||
return embedding
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
@@ -509,7 +515,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
return controlnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -533,7 +539,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -568,7 +574,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -584,7 +590,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from .modeling_flax_utils import FlaxModelMixin
|
||||
from .unet_2d_blocks_flax import (
|
||||
from .unets.unet_2d_blocks_flax import (
|
||||
FlaxCrossAttnDownBlock2D,
|
||||
FlaxDownBlock2D,
|
||||
FlaxUNetMidBlock2DCrossAttn,
|
||||
@@ -329,14 +329,14 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
|
||||
conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
plain tuple.
|
||||
train (`bool`, *optional*, defaults to `False`):
|
||||
Use deterministic functions and disable dropout when not training.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
channel_order = self.controlnet_conditioning_channel_order
|
||||
|
||||
@@ -120,7 +120,7 @@ class DualTransformer2DModel(nn.Module):
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
||||
|
||||
@@ -32,6 +32,7 @@ from .. import __version__
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
SAFETENSORS_FILE_EXTENSION,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
_add_variant,
|
||||
@@ -102,10 +103,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
try:
|
||||
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
else:
|
||||
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
||||
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
else:
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
except Exception as e:
|
||||
try:
|
||||
with open(checkpoint_file) as f:
|
||||
|
||||
@@ -167,7 +167,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -191,7 +191,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -226,7 +226,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
|
||||
@@ -286,7 +286,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -149,7 +149,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -12,244 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
||||
from ..utils import deprecate
|
||||
from .unets.unet_1d import UNet1DModel, UNet1DOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet1DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet1DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
class UNet1DOutput(UNet1DOutput):
|
||||
deprecation_message = "Importing `UNet1DOutput` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DOutput`, instead."
|
||||
deprecate("UNet1DOutput", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
|
||||
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
|
||||
extra_in_channels (`int`, *optional*, defaults to 0):
|
||||
Number of additional channels to be added to the input of the first down block. Useful for cases where the
|
||||
input data has more channels than what the model was initially designed for.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
|
||||
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
|
||||
Tuple of block output channels.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
|
||||
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
|
||||
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
|
||||
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
|
||||
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
|
||||
downsample_each_block (`int`, *optional*, defaults to `False`):
|
||||
Experimental feature for using a UNet without upsampling.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 65536,
|
||||
sample_rate: Optional[int] = None,
|
||||
in_channels: int = 2,
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
layers_per_block: int = 1,
|
||||
downsample_each_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.sample_size = sample_size
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=timestep_input_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
out_dim=block_out_channels[0],
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
self.out_block = None
|
||||
|
||||
# down
|
||||
output_channel = in_channels
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
if i == 0:
|
||||
input_channel += extra_in_channels
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_downsample=not is_final_block or downsample_each_block,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = get_mid_block(
|
||||
mid_block_type,
|
||||
in_channels=block_out_channels[-1],
|
||||
mid_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
embed_dim=block_out_channels[0],
|
||||
num_layers=layers_per_block,
|
||||
add_downsample=downsample_each_block,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
if out_block_type is None:
|
||||
final_upsample_channels = out_channels
|
||||
else:
|
||||
final_upsample_channels = block_out_channels[0]
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = (
|
||||
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
|
||||
)
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_upsample=not is_final_block,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.out_block = get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=block_out_channels[0],
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=block_out_channels[-1] // 4,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet1DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet1DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
timestep_embed = self.time_proj(timesteps)
|
||||
if self.config.use_timestep_embedding:
|
||||
timestep_embed = self.time_mlp(timestep_embed)
|
||||
else:
|
||||
timestep_embed = timestep_embed[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
|
||||
|
||||
# 2. down
|
||||
down_block_res_samples = ()
|
||||
for downsample_block in self.down_blocks:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 3. mid
|
||||
if self.mid_block:
|
||||
sample = self.mid_block(sample, timestep_embed)
|
||||
|
||||
# 4. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
res_samples = down_block_res_samples[-1:]
|
||||
down_block_res_samples = down_block_res_samples[:-1]
|
||||
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
|
||||
|
||||
# 5. post-process
|
||||
if self.out_block:
|
||||
sample = self.out_block(sample, timestep_embed)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet1DOutput(sample=sample)
|
||||
class UNet1DModel(UNet1DModel):
|
||||
deprecation_message = "Importing `UNet1DModel` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DModel`, instead."
|
||||
deprecate("UNet1DModel", "0.29", deprecation_message)
|
||||
|
||||
@@ -11,616 +11,112 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .activations import get_activation
|
||||
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
|
||||
|
||||
|
||||
class DownResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
conv_shortcut: bool = False,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_downsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_downsample = add_downsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
output_states = ()
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.downsample is not None:
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UpResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_upsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_upsample = add_upsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if res_hidden_states_tuple is not None:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.upsample is not None:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ValueFunctionMidBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
|
||||
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
|
||||
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
|
||||
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
x = self.res1(x, temb)
|
||||
x = self.down1(x)
|
||||
x = self.res2(x, temb)
|
||||
x = self.down2(x)
|
||||
return x
|
||||
|
||||
|
||||
class MidResTemporalBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
num_layers: int = 1,
|
||||
add_downsample: bool = False,
|
||||
add_upsample: bool = False,
|
||||
non_linearity: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.add_downsample = add_downsample
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
if self.upsample and self.downsample:
|
||||
raise ValueError("Block cannot downsample and upsample")
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
if self.downsample:
|
||||
self.downsample = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutConv1DBlock(nn.Module):
|
||||
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
|
||||
super().__init__()
|
||||
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
|
||||
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
|
||||
self.final_conv1d_act = get_activation(act_fn)
|
||||
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.final_conv1d_1(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_gn(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_act(hidden_states)
|
||||
hidden_states = self.final_conv1d_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutValueFunctionBlock(nn.Module):
|
||||
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
|
||||
super().__init__()
|
||||
self.final_block = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
||||
get_activation(act_fn),
|
||||
nn.Linear(fc_dim // 2, 1),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
|
||||
hidden_states = torch.cat((hidden_states, temb), dim=-1)
|
||||
for layer in self.final_block:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
_kernels = {
|
||||
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
||||
"lanczos3": [
|
||||
0.003689131001010537,
|
||||
0.015056144446134567,
|
||||
-0.03399861603975296,
|
||||
-0.066637322306633,
|
||||
0.13550527393817902,
|
||||
0.44638532400131226,
|
||||
0.44638532400131226,
|
||||
0.13550527393817902,
|
||||
-0.066637322306633,
|
||||
-0.03399861603975296,
|
||||
0.015056144446134567,
|
||||
0.003689131001010537,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel])
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv1d(hidden_states, weight, stride=2)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
class SelfAttention1d(nn.Module):
|
||||
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
|
||||
super().__init__()
|
||||
self.channels = in_channels
|
||||
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
|
||||
self.num_heads = n_head
|
||||
|
||||
self.query = nn.Linear(self.channels, self.channels)
|
||||
self.key = nn.Linear(self.channels, self.channels)
|
||||
self.value = nn.Linear(self.channels, self.channels)
|
||||
|
||||
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
|
||||
|
||||
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
||||
|
||||
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
||||
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
||||
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
||||
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
||||
return new_projection
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
batch, channel_dim, seq = hidden_states.shape
|
||||
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
query_states = self.transpose_for_scores(query_proj)
|
||||
key_states = self.transpose_for_scores(key_proj)
|
||||
value_states = self.transpose_for_scores(value_proj)
|
||||
|
||||
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
|
||||
|
||||
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
|
||||
attention_probs = torch.softmax(attention_scores, dim=-1)
|
||||
|
||||
# compute attention output
|
||||
hidden_states = torch.matmul(attention_probs, value_states)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResConvBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
|
||||
super().__init__()
|
||||
self.is_last = is_last
|
||||
self.has_conv_skip = in_channels != out_channels
|
||||
|
||||
if self.has_conv_skip:
|
||||
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
|
||||
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
|
||||
self.gelu_1 = nn.GELU()
|
||||
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
|
||||
|
||||
if not self.is_last:
|
||||
self.group_norm_2 = nn.GroupNorm(1, out_channels)
|
||||
self.gelu_2 = nn.GELU()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
|
||||
|
||||
hidden_states = self.conv_1(hidden_states)
|
||||
hidden_states = self.group_norm_1(hidden_states)
|
||||
hidden_states = self.gelu_1(hidden_states)
|
||||
hidden_states = self.conv_2(hidden_states)
|
||||
from ..utils import deprecate
|
||||
from .unets.unet_1d_blocks import (
|
||||
AttnDownBlock1D,
|
||||
AttnUpBlock1D,
|
||||
DownBlock1D,
|
||||
DownBlock1DNoSkip,
|
||||
DownResnetBlock1D,
|
||||
Downsample1d,
|
||||
MidResTemporalBlock1D,
|
||||
OutConv1DBlock,
|
||||
OutValueFunctionBlock,
|
||||
ResConvBlock,
|
||||
SelfAttention1d,
|
||||
UNetMidBlock1D,
|
||||
UpBlock1D,
|
||||
UpBlock1DNoSkip,
|
||||
UpResnetBlock1D,
|
||||
Upsample1d,
|
||||
ValueFunctionMidBlock1D,
|
||||
)
|
||||
|
||||
if not self.is_last:
|
||||
hidden_states = self.group_norm_2(hidden_states)
|
||||
hidden_states = self.gelu_2(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
return output
|
||||
class DownResnetBlock1D(DownResnetBlock1D):
|
||||
deprecation_message = "Importing `DownResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownResnetBlock1D`, instead."
|
||||
deprecate("DownResnetBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UNetMidBlock1D(nn.Module):
|
||||
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
class UpResnetBlock1D(UpResnetBlock1D):
|
||||
deprecation_message = "Importing `UpResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpResnetBlock1D`, instead."
|
||||
deprecate("UpResnetBlock1D", "0.29", deprecation_message)
|
||||
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
|
||||
# there is always at least one resnet
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
class ValueFunctionMidBlock1D(ValueFunctionMidBlock1D):
|
||||
deprecation_message = "Importing `ValueFunctionMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ValueFunctionMidBlock1D`, instead."
|
||||
deprecate("ValueFunctionMidBlock1D", "0.29", deprecation_message)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
class OutConv1DBlock(OutConv1DBlock):
|
||||
deprecation_message = "Importing `OutConv1DBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutConv1DBlock`, instead."
|
||||
deprecate("OutConv1DBlock", "0.29", deprecation_message)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
class OutValueFunctionBlock(OutValueFunctionBlock):
|
||||
deprecation_message = "Importing `OutValueFunctionBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutValueFunctionBlock`, instead."
|
||||
deprecate("OutValueFunctionBlock", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class AttnDownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
class Downsample1d(Downsample1d):
|
||||
deprecation_message = "Importing `Downsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Downsample1d`, instead."
|
||||
deprecate("Downsample1d", "0.29", deprecation_message)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
class Upsample1d(Upsample1d):
|
||||
deprecation_message = "Importing `Upsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Upsample1d`, instead."
|
||||
deprecate("Upsample1d", "0.29", deprecation_message)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = torch.cat([hidden_states, temb], dim=1)
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class AttnUpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
class SelfAttention1d(SelfAttention1d):
|
||||
deprecation_message = "Importing `SelfAttention1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import SelfAttention1d`, instead."
|
||||
deprecate("SelfAttention1d", "0.29", deprecation_message)
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
class ResConvBlock(ResConvBlock):
|
||||
deprecation_message = "Importing `ResConvBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ResConvBlock`, instead."
|
||||
deprecate("ResConvBlock", "0.29", deprecation_message)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
class UNetMidBlock1D(UNetMidBlock1D):
|
||||
deprecation_message = "Importing `UNetMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UNetMidBlock1D`, instead."
|
||||
deprecate("UNetMidBlock1D", "0.29", deprecation_message)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
class AttnDownBlock1D(AttnDownBlock1D):
|
||||
deprecation_message = "Importing `AttnDownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnDownBlock1D`, instead."
|
||||
deprecate("AttnDownBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
class DownBlock1D(DownBlock1D):
|
||||
deprecation_message = "Importing `DownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1D`, instead."
|
||||
deprecate("DownBlock1D", "0.29", deprecation_message)
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
class DownBlock1DNoSkip(DownBlock1DNoSkip):
|
||||
deprecation_message = "Importing `DownBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1DNoSkip`, instead."
|
||||
deprecate("DownBlock1DNoSkip", "0.29", deprecation_message)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
class AttnUpBlock1D(AttnUpBlock1D):
|
||||
deprecation_message = "Importing `AttnUpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnUpBlock1D`, instead."
|
||||
deprecate("AttnUpBlock1D", "0.29", deprecation_message)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
class UpBlock1D(UpBlock1D):
|
||||
deprecation_message = "Importing `UpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1D`, instead."
|
||||
deprecate("UpBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
class UpBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
class UpBlock1DNoSkip(UpBlock1DNoSkip):
|
||||
deprecation_message = "Importing `UpBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1DNoSkip`, instead."
|
||||
deprecate("UpBlock1DNoSkip", "0.29", deprecation_message)
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
|
||||
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
|
||||
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
|
||||
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
|
||||
class MidResTemporalBlock1D(MidResTemporalBlock1D):
|
||||
deprecation_message = "Importing `MidResTemporalBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import MidResTemporalBlock1D`, instead."
|
||||
deprecate("MidResTemporalBlock1D", "0.29", deprecation_message)
|
||||
|
||||
|
||||
def get_down_block(
|
||||
@@ -630,42 +126,38 @@ def get_down_block(
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
add_downsample: bool,
|
||||
) -> DownBlockType:
|
||||
if down_block_type == "DownResnetBlock1D":
|
||||
return DownResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif down_block_type == "DownBlock1D":
|
||||
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "AttnDownBlock1D":
|
||||
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "DownBlock1DNoSkip":
|
||||
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
):
|
||||
deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_down_block`, instead."
|
||||
deprecate("get_down_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_down_block
|
||||
|
||||
return get_down_block(
|
||||
down_block_type=down_block_type,
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
|
||||
|
||||
def get_up_block(
|
||||
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
|
||||
) -> UpBlockType:
|
||||
if up_block_type == "UpResnetBlock1D":
|
||||
return UpResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
elif up_block_type == "UpBlock1D":
|
||||
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "AttnUpBlock1D":
|
||||
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "UpBlock1DNoSkip":
|
||||
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
):
|
||||
deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_up_block`, instead."
|
||||
deprecate("get_up_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_up_block
|
||||
|
||||
return get_up_block(
|
||||
up_block_type=up_block_type,
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
|
||||
|
||||
def get_mid_block(
|
||||
@@ -676,27 +168,36 @@ def get_mid_block(
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
add_downsample: bool,
|
||||
) -> MidBlockType:
|
||||
if mid_block_type == "MidResTemporalBlock1D":
|
||||
return MidResTemporalBlock1D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif mid_block_type == "ValueFunctionMidBlock1D":
|
||||
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
|
||||
elif mid_block_type == "UNetMidBlock1D":
|
||||
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{mid_block_type} does not exist.")
|
||||
):
|
||||
deprecation_message = "Importing `get_mid_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_mid_block`, instead."
|
||||
deprecate("get_mid_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_mid_block
|
||||
|
||||
return get_mid_block(
|
||||
mid_block_type=mid_block_type,
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
mid_channels=mid_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
|
||||
|
||||
def get_out_block(
|
||||
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
|
||||
) -> Optional[OutBlockType]:
|
||||
if out_block_type == "OutConv1DBlock":
|
||||
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
||||
elif out_block_type == "ValueFunction":
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
|
||||
return None
|
||||
):
|
||||
deprecation_message = "Importing `get_out_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_out_block`, instead."
|
||||
deprecate("get_out_block", "0.29", deprecation_message)
|
||||
|
||||
from .unets.unet_1d_blocks import get_out_block
|
||||
|
||||
return get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=embed_dim,
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=fc_dim,
|
||||
)
|
||||
|
||||
@@ -11,336 +11,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet2DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
from ..utils import deprecate
|
||||
from .unets.unet_2d import UNet2DModel, UNet2DOutput
|
||||
|
||||
|
||||
class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
class UNet2DOutput(UNet2DOutput):
|
||||
deprecation_message = "Importing `UNet2DOutput` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DOutput`, instead."
|
||||
deprecate("UNet2DOutput", "0.29", deprecation_message)
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
||||
1)`.
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
||||
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
|
||||
Tuple of downsample block types.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
||||
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
|
||||
Tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
|
||||
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
|
||||
downsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The downsample type for downsampling layers. Choose between "conv" and "resnet"
|
||||
upsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The upsample type for upsampling layers. Choose between "conv" and "resnet"
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
|
||||
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
|
||||
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
|
||||
given number of groups. If left as `None`, the group norm layer will only be created if
|
||||
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
|
||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to `None`):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
|
||||
conditioning with `class_embed_type` equal to `None`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
center_input_sample: bool = False,
|
||||
time_embedding_type: str = "positional",
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
||||
layers_per_block: int = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
downsample_padding: int = 1,
|
||||
downsample_type: str = "conv",
|
||||
upsample_type: str = "conv",
|
||||
dropout: float = 0.0,
|
||||
act_fn: str = "silu",
|
||||
attention_head_dim: Optional[int] = 8,
|
||||
norm_num_groups: int = 32,
|
||||
attn_norm_num_groups: Optional[int] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
add_attention: bool = True,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
num_train_timesteps: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
elif time_embedding_type == "learned":
|
||||
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
downsample_type=downsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_groups=attn_norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
upsample_type=upsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when doing class conditioning")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
elif self.class_embedding is None and class_labels is not None:
|
||||
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "skip_conv"):
|
||||
sample, res_samples, skip_sample = downsample_block(
|
||||
hidden_states=sample, temb=emb, skip_sample=skip_sample
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. up
|
||||
skip_sample = None
|
||||
for upsample_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
if hasattr(upsample_block, "skip_conv"):
|
||||
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
|
||||
else:
|
||||
sample = upsample_block(sample, res_samples, emb)
|
||||
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if skip_sample is not None:
|
||||
sample += skip_sample
|
||||
|
||||
if self.config.time_embedding_type == "fourier":
|
||||
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
|
||||
sample = sample / timesteps
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet2DOutput(sample=sample)
|
||||
class UNet2DModel(UNet2DModel):
|
||||
deprecation_message = "Importing `UNet2DModel` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DModel`, instead."
|
||||
deprecate("UNet2DModel", "0.29", deprecation_message)
|
||||
|
||||
+213
-3429
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,16 @@
|
||||
from ...utils import is_flax_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_condition import UNet3DConditionModel
|
||||
from .unet_kandinsky3 import Kandinsky3UNet
|
||||
from .unet_motion_model import MotionAdapter, UNetMotionModel
|
||||
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
||||
from .uvit_2d import UVit2DModel
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
|
||||
@@ -0,0 +1,255 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet1DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet1DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
|
||||
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
|
||||
extra_in_channels (`int`, *optional*, defaults to 0):
|
||||
Number of additional channels to be added to the input of the first down block. Useful for cases where the
|
||||
input data has more channels than what the model was initially designed for.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
|
||||
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
|
||||
Tuple of block output channels.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
|
||||
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
|
||||
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
|
||||
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
|
||||
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
|
||||
downsample_each_block (`int`, *optional*, defaults to `False`):
|
||||
Experimental feature for using a UNet without upsampling.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 65536,
|
||||
sample_rate: Optional[int] = None,
|
||||
in_channels: int = 2,
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
layers_per_block: int = 1,
|
||||
downsample_each_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.sample_size = sample_size
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(
|
||||
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
||||
)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=timestep_input_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
out_dim=block_out_channels[0],
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
self.out_block = None
|
||||
|
||||
# down
|
||||
output_channel = in_channels
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
if i == 0:
|
||||
input_channel += extra_in_channels
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_downsample=not is_final_block or downsample_each_block,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = get_mid_block(
|
||||
mid_block_type,
|
||||
in_channels=block_out_channels[-1],
|
||||
mid_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
embed_dim=block_out_channels[0],
|
||||
num_layers=layers_per_block,
|
||||
add_downsample=downsample_each_block,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
if out_block_type is None:
|
||||
final_upsample_channels = out_channels
|
||||
else:
|
||||
final_upsample_channels = block_out_channels[0]
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = (
|
||||
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
|
||||
)
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_upsample=not is_final_block,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.out_block = get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=block_out_channels[0],
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=block_out_channels[-1] // 4,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet1DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet1DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
timestep_embed = self.time_proj(timesteps)
|
||||
if self.config.use_timestep_embedding:
|
||||
timestep_embed = self.time_mlp(timestep_embed)
|
||||
else:
|
||||
timestep_embed = timestep_embed[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
|
||||
|
||||
# 2. down
|
||||
down_block_res_samples = ()
|
||||
for downsample_block in self.down_blocks:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 3. mid
|
||||
if self.mid_block:
|
||||
sample = self.mid_block(sample, timestep_embed)
|
||||
|
||||
# 4. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
res_samples = down_block_res_samples[-1:]
|
||||
down_block_res_samples = down_block_res_samples[:-1]
|
||||
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
|
||||
|
||||
# 5. post-process
|
||||
if self.out_block:
|
||||
sample = self.out_block(sample, timestep_embed)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet1DOutput(sample=sample)
|
||||
@@ -0,0 +1,702 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..activations import get_activation
|
||||
from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
|
||||
|
||||
|
||||
class DownResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
conv_shortcut: bool = False,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_downsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_downsample = add_downsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
output_states = ()
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.downsample is not None:
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UpResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
temb_channels: int = 32,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
non_linearity: Optional[str] = None,
|
||||
time_embedding_norm: str = "default",
|
||||
output_scale_factor: float = 1.0,
|
||||
add_upsample: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_upsample = add_upsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if res_hidden_states_tuple is not None:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.upsample is not None:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ValueFunctionMidBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
|
||||
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
|
||||
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
|
||||
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
x = self.res1(x, temb)
|
||||
x = self.down1(x)
|
||||
x = self.res2(x, temb)
|
||||
x = self.down2(x)
|
||||
return x
|
||||
|
||||
|
||||
class MidResTemporalBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
num_layers: int = 1,
|
||||
add_downsample: bool = False,
|
||||
add_upsample: bool = False,
|
||||
non_linearity: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.add_downsample = add_downsample
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity is None:
|
||||
self.nonlinearity = None
|
||||
else:
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
if self.upsample and self.downsample:
|
||||
raise ValueError("Block cannot downsample and upsample")
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
if self.downsample:
|
||||
self.downsample = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutConv1DBlock(nn.Module):
|
||||
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
|
||||
super().__init__()
|
||||
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
|
||||
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
|
||||
self.final_conv1d_act = get_activation(act_fn)
|
||||
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.final_conv1d_1(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_gn(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_act(hidden_states)
|
||||
hidden_states = self.final_conv1d_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutValueFunctionBlock(nn.Module):
|
||||
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
|
||||
super().__init__()
|
||||
self.final_block = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
||||
get_activation(act_fn),
|
||||
nn.Linear(fc_dim // 2, 1),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
|
||||
hidden_states = torch.cat((hidden_states, temb), dim=-1)
|
||||
for layer in self.final_block:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
_kernels = {
|
||||
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
|
||||
"lanczos3": [
|
||||
0.003689131001010537,
|
||||
0.015056144446134567,
|
||||
-0.03399861603975296,
|
||||
-0.066637322306633,
|
||||
0.13550527393817902,
|
||||
0.44638532400131226,
|
||||
0.44638532400131226,
|
||||
0.13550527393817902,
|
||||
-0.066637322306633,
|
||||
-0.03399861603975296,
|
||||
0.015056144446134567,
|
||||
0.003689131001010537,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel])
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv1d(hidden_states, weight, stride=2)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor(_kernels[kernel]) * 2
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
class SelfAttention1d(nn.Module):
|
||||
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
|
||||
super().__init__()
|
||||
self.channels = in_channels
|
||||
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
|
||||
self.num_heads = n_head
|
||||
|
||||
self.query = nn.Linear(self.channels, self.channels)
|
||||
self.key = nn.Linear(self.channels, self.channels)
|
||||
self.value = nn.Linear(self.channels, self.channels)
|
||||
|
||||
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
|
||||
|
||||
self.dropout = nn.Dropout(dropout_rate, inplace=True)
|
||||
|
||||
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
|
||||
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
|
||||
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
|
||||
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
|
||||
return new_projection
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
batch, channel_dim, seq = hidden_states.shape
|
||||
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
query_states = self.transpose_for_scores(query_proj)
|
||||
key_states = self.transpose_for_scores(key_proj)
|
||||
value_states = self.transpose_for_scores(value_proj)
|
||||
|
||||
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
|
||||
|
||||
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
|
||||
attention_probs = torch.softmax(attention_scores, dim=-1)
|
||||
|
||||
# compute attention output
|
||||
hidden_states = torch.matmul(attention_probs, value_states)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ResConvBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
|
||||
super().__init__()
|
||||
self.is_last = is_last
|
||||
self.has_conv_skip = in_channels != out_channels
|
||||
|
||||
if self.has_conv_skip:
|
||||
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
|
||||
|
||||
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
|
||||
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
|
||||
self.gelu_1 = nn.GELU()
|
||||
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
|
||||
|
||||
if not self.is_last:
|
||||
self.group_norm_2 = nn.GroupNorm(1, out_channels)
|
||||
self.gelu_2 = nn.GELU()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
|
||||
|
||||
hidden_states = self.conv_1(hidden_states)
|
||||
hidden_states = self.group_norm_1(hidden_states)
|
||||
hidden_states = self.gelu_1(hidden_states)
|
||||
hidden_states = self.conv_2(hidden_states)
|
||||
|
||||
if not self.is_last:
|
||||
hidden_states = self.group_norm_2(hidden_states)
|
||||
hidden_states = self.gelu_2(hidden_states)
|
||||
|
||||
output = hidden_states + residual
|
||||
return output
|
||||
|
||||
|
||||
class UNetMidBlock1D(nn.Module):
|
||||
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
|
||||
# there is always at least one resnet
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnDownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1D(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
self.down = Downsample1d("cubic")
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = self.down(hidden_states)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class DownBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
||||
hidden_states = torch.cat([hidden_states, temb], dim=1)
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states, (hidden_states,)
|
||||
|
||||
|
||||
class AttnUpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = out_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
attentions = [
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(mid_channels, mid_channels // 32),
|
||||
SelfAttention1d(out_channels, out_channels // 32),
|
||||
]
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states)
|
||||
hidden_states = attn(hidden_states)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpBlock1D(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = Upsample1d(kernel="cubic")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
hidden_states = self.up(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpBlock1DNoSkip(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
|
||||
super().__init__()
|
||||
mid_channels = in_channels if mid_channels is None else mid_channels
|
||||
|
||||
resnets = [
|
||||
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, mid_channels),
|
||||
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
|
||||
]
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
|
||||
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
|
||||
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
|
||||
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
|
||||
|
||||
|
||||
def get_down_block(
|
||||
down_block_type: str,
|
||||
num_layers: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
add_downsample: bool,
|
||||
) -> DownBlockType:
|
||||
if down_block_type == "DownResnetBlock1D":
|
||||
return DownResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif down_block_type == "DownBlock1D":
|
||||
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "AttnDownBlock1D":
|
||||
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "DownBlock1DNoSkip":
|
||||
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(
|
||||
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
|
||||
) -> UpBlockType:
|
||||
if up_block_type == "UpResnetBlock1D":
|
||||
return UpResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
elif up_block_type == "UpBlock1D":
|
||||
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "AttnUpBlock1D":
|
||||
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "UpBlock1DNoSkip":
|
||||
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_mid_block(
|
||||
mid_block_type: str,
|
||||
num_layers: int,
|
||||
in_channels: int,
|
||||
mid_channels: int,
|
||||
out_channels: int,
|
||||
embed_dim: int,
|
||||
add_downsample: bool,
|
||||
) -> MidBlockType:
|
||||
if mid_block_type == "MidResTemporalBlock1D":
|
||||
return MidResTemporalBlock1D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif mid_block_type == "ValueFunctionMidBlock1D":
|
||||
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
|
||||
elif mid_block_type == "UNetMidBlock1D":
|
||||
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{mid_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_out_block(
|
||||
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
|
||||
) -> Optional[OutBlockType]:
|
||||
if out_block_type == "OutConv1DBlock":
|
||||
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
||||
elif out_block_type == "ValueFunction":
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
|
||||
return None
|
||||
@@ -0,0 +1,346 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet2DOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`UNet2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The hidden states output from the last layer of the model.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
|
||||
1)`.
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
||||
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip sin to cos for Fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
|
||||
Tuple of downsample block types.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
||||
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
|
||||
Tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
|
||||
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
|
||||
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
|
||||
downsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The downsample type for downsampling layers. Choose between "conv" and "resnet"
|
||||
upsample_type (`str`, *optional*, defaults to `conv`):
|
||||
The upsample type for upsampling layers. Choose between "conv" and "resnet"
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
|
||||
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
|
||||
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
|
||||
given number of groups. If left as `None`, the group norm layer will only be created if
|
||||
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
|
||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to `None`):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
|
||||
conditioning with `class_embed_type` equal to `None`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
center_input_sample: bool = False,
|
||||
time_embedding_type: str = "positional",
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
||||
layers_per_block: int = 2,
|
||||
mid_block_scale_factor: float = 1,
|
||||
downsample_padding: int = 1,
|
||||
downsample_type: str = "conv",
|
||||
upsample_type: str = "conv",
|
||||
dropout: float = 0.0,
|
||||
act_fn: str = "silu",
|
||||
attention_head_dim: Optional[int] = 8,
|
||||
norm_num_groups: int = 32,
|
||||
attn_norm_num_groups: Optional[int] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
add_attention: bool = True,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
num_train_timesteps: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
# input
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
||||
|
||||
# time
|
||||
if time_embedding_type == "fourier":
|
||||
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
elif time_embedding_type == "learned":
|
||||
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
downsample_type=downsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=dropout,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_groups=attn_norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
upsample_type=upsample_type,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when doing class conditioning")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
elif self.class_embedding is None and class_labels is not None:
|
||||
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
|
||||
|
||||
# 2. pre-process
|
||||
skip_sample = sample
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "skip_conv"):
|
||||
sample, res_samples, skip_sample = downsample_block(
|
||||
hidden_states=sample, temb=emb, skip_sample=skip_sample
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 5. up
|
||||
skip_sample = None
|
||||
for upsample_block in self.up_blocks:
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
if hasattr(upsample_block, "skip_conv"):
|
||||
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
|
||||
else:
|
||||
sample = upsample_block(sample, res_samples, emb)
|
||||
|
||||
# 6. post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if skip_sample is not None:
|
||||
sample += skip_sample
|
||||
|
||||
if self.config.time_embedding_type == "fourier":
|
||||
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
|
||||
sample = sample / timesteps
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet2DOutput(sample=sample)
|
||||
File diff suppressed because it is too large
Load Diff
+2
-2
@@ -15,8 +15,8 @@
|
||||
import flax.linen as nn
|
||||
import jax.numpy as jnp
|
||||
|
||||
from .attention_flax import FlaxTransformer2DModel
|
||||
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
||||
from ..attention_flax import FlaxTransformer2DModel
|
||||
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
||||
|
||||
|
||||
class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
File diff suppressed because it is too large
Load Diff
+7
-7
@@ -19,10 +19,10 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
|
||||
from ..configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from .modeling_flax_utils import FlaxModelMixin
|
||||
from ...configuration_utils import ConfigMixin, flax_register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
|
||||
from ..modeling_flax_utils import FlaxModelMixin
|
||||
from .unet_2d_blocks_flax import (
|
||||
FlaxCrossAttnDownBlock2D,
|
||||
FlaxCrossAttnUpBlock2D,
|
||||
@@ -342,14 +342,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||
A tensor that if specified is added to the residual of the middle unet block.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
plain tuple.
|
||||
train (`bool`, *optional*, defaults to `False`):
|
||||
Use deterministic functions and disable dropout when not training.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# 1. time
|
||||
+7
-7
@@ -17,19 +17,19 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..utils import is_torch_version
|
||||
from ..utils.torch_utils import apply_freeu
|
||||
from .attention import Attention
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .resnet import (
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention import Attention
|
||||
from ..dual_transformer_2d import DualTransformer2DModel
|
||||
from ..resnet import (
|
||||
Downsample2D,
|
||||
ResnetBlock2D,
|
||||
SpatioTemporalResBlock,
|
||||
TemporalConvLayer,
|
||||
Upsample2D,
|
||||
)
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import (
|
||||
from ..transformer_2d import Transformer2DModel
|
||||
from ..transformer_temporal import (
|
||||
TransformerSpatioTemporalModel,
|
||||
TransformerTemporalModel,
|
||||
)
|
||||
+15
-15
@@ -20,20 +20,20 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput, deprecate, logging
|
||||
from .activations import get_activation
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput, deprecate, logging
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformer_temporal import TransformerTemporalModel
|
||||
from .unet_3d_blocks import (
|
||||
CrossAttnDownBlock3D,
|
||||
CrossAttnUpBlock3D,
|
||||
@@ -284,7 +284,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -308,7 +308,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -374,7 +374,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -449,7 +449,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, None, 0)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -469,7 +469,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
@@ -494,7 +494,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
@@ -503,7 +503,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
|
||||
def unload_lora(self):
|
||||
"""Unloads LoRA weights."""
|
||||
deprecate(
|
||||
+5
-5
@@ -19,11 +19,11 @@ import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import Attention, AttentionProcessor, AttnProcessor
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
+14
-14
@@ -17,19 +17,19 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import logging
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import logging
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformer_temporal import TransformerTemporalModel
|
||||
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_blocks import (
|
||||
@@ -524,7 +524,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -548,7 +548,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -583,7 +583,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
@@ -613,7 +613,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
|
||||
def disable_forward_chunking(self) -> None:
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
@@ -625,7 +625,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, None, 0)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self) -> None:
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -645,7 +645,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
@@ -670,7 +670,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
def disable_freeu(self) -> None:
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
+7
-7
@@ -4,12 +4,12 @@ from typing import Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the attention processor to use [feed forward
|
||||
@@ -20,20 +20,20 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import PeftAdapterMixin
|
||||
from .attention import BasicTransformerBlock, SkipFFTransformerBlock
|
||||
from .attention_processor import (
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from .modeling_utils import ModelMixin
|
||||
from .normalization import GlobalResponseNorm, RMSNorm
|
||||
from .resnet import Downsample2D, Upsample2D
|
||||
from ..embeddings import TimestepEmbedding, get_timestep_embedding
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import GlobalResponseNorm, RMSNorm
|
||||
from ..resnet import Downsample2D, Upsample2D
|
||||
|
||||
|
||||
class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
@@ -213,7 +213,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
return logits
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -237,7 +237,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -272,7 +272,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -109,7 +109,10 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
|
||||
_import_structure["animatediff"] = ["AnimateDiffPipeline"]
|
||||
_import_structure["animatediff"] = [
|
||||
"AnimateDiffPipeline",
|
||||
"AnimateDiffVideoToVideoPipeline",
|
||||
]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
"AudioLDM2Pipeline",
|
||||
@@ -341,7 +344,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ..utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
|
||||
from .animatediff import AnimateDiffPipeline
|
||||
from .animatediff import AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import (
|
||||
AudioLDM2Pipeline,
|
||||
|
||||
@@ -11,7 +11,7 @@ from ...utils import (
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
_import_structure = {"pipeline_output": ["AnimateDiffPipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -21,7 +21,8 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline", "AnimateDiffPipelineOutput"]
|
||||
_import_structure["pipeline_animatediff"] = ["AnimateDiffPipeline"]
|
||||
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -31,7 +32,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
|
||||
else:
|
||||
from .pipeline_animatediff import AnimateDiffPipeline, AnimateDiffPipelineOutput
|
||||
from .pipeline_animatediff import AnimateDiffPipeline
|
||||
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -26,7 +25,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unet_motion_model import MotionAdapter
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
@@ -37,7 +36,6 @@ from ...schedulers import (
|
||||
)
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
@@ -46,6 +44,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -67,10 +66,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
@@ -79,6 +75,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -147,11 +152,6 @@ def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> tor
|
||||
return x_mixed
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnimateDiffPipelineOutput(BaseOutput):
|
||||
frames: Union[torch.Tensor, np.ndarray]
|
||||
|
||||
|
||||
class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
@@ -805,11 +805,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
|
||||
return AnimateDiffPipelineOutput(frames=latents)
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
@@ -0,0 +1,969 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import imageio
|
||||
>>> import requests
|
||||
>>> import torch
|
||||
>>> from diffusers import AnimateDiffVideoToVideoPipeline, DDIMScheduler, MotionAdapter
|
||||
>>> from diffusers.utils import export_to_gif
|
||||
>>> from io import BytesIO
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
|
||||
>>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter).to("cuda")
|
||||
>>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
|
||||
|
||||
>>> def load_video(file_path: str):
|
||||
... images = []
|
||||
...
|
||||
... if file_path.startswith(('http://', 'https://')):
|
||||
... # If the file_path is a URL
|
||||
... response = requests.get(file_path)
|
||||
... response.raise_for_status()
|
||||
... content = BytesIO(response.content)
|
||||
... vid = imageio.get_reader(content)
|
||||
... else:
|
||||
... # Assuming it's a local file path
|
||||
... vid = imageio.get_reader(file_path)
|
||||
...
|
||||
... for frame in vid:
|
||||
... pil_image = Image.fromarray(frame)
|
||||
... images.append(pil_image)
|
||||
...
|
||||
... return images
|
||||
|
||||
>>> video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
|
||||
>>> output = pipe(video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5)
|
||||
>>> frames = output.frames[0]
|
||||
>>> export_to_gif(frames, "animation.gif")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
||||
batch_output = processor.postprocess(batch_vid, output_type)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used,
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
|
||||
must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class AnimateDiffVideoToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for video-to-video generation.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
A [`~transformers.CLIPTokenizer`] to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
|
||||
motion_adapter ([`MotionAdapter`]):
|
||||
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
||||
_optional_components = ["feature_extractor", "image_encoder"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
motion_adapter: MotionAdapter,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
):
|
||||
super().__init__()
|
||||
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
motion_adapter=motion_adapter,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
# the tuple to access the hidden states from the desired layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
|
||||
# We also need to apply the final LayerNorm here to not mess with the
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
image = self.vae.decode(latents).sample
|
||||
video = (
|
||||
image[None, :]
|
||||
.reshape(
|
||||
(
|
||||
batch_size,
|
||||
num_frames,
|
||||
-1,
|
||||
)
|
||||
+ image.shape[2:]
|
||||
)
|
||||
.permute(0, 2, 1, 3, 4)
|
||||
)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
strength,
|
||||
height,
|
||||
width,
|
||||
video=None,
|
||||
latents=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if video is not None and latents is not None:
|
||||
raise ValueError("Only one of `video` or `latents` should be provided")
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video,
|
||||
height,
|
||||
width,
|
||||
num_channels_latents,
|
||||
batch_size,
|
||||
timestep,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# video must be a list of list of images
|
||||
# the outer list denotes having multiple videos as input, whereas inner list means the frames of the video
|
||||
# as a list of images
|
||||
if not isinstance(video[0], list):
|
||||
video = [video]
|
||||
if latents is None:
|
||||
video = torch.cat(
|
||||
[self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0
|
||||
)
|
||||
video = video.to(device=device, dtype=dtype)
|
||||
num_frames = video.shape[1]
|
||||
else:
|
||||
num_frames = latents.shape[2]
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
if self.vae.config.force_upcast:
|
||||
video = video.float()
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i]), generator=generator[i]).unsqueeze(0)
|
||||
for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(vid), generator=generator).unsqueeze(0) for vid in video
|
||||
]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
|
||||
# restore vae to original dtype
|
||||
if self.vae.config.force_upcast:
|
||||
self.vae.to(dtype)
|
||||
|
||||
init_latents = init_latents.to(dtype)
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
error_message = (
|
||||
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
|
||||
" images (`image`). Please make sure to update your script to pass as many initial images as text prompts"
|
||||
)
|
||||
raise ValueError(error_message)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self.scheduler.add_noise(init_latents, noise, timestep).permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
if shape != latents.shape:
|
||||
# [B, C, F, H, W]
|
||||
raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
video: List[List[PipelineImageInput]] = None,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
strength: float = 0.8,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
video (`List[PipelineImageInput]`):
|
||||
The input video to condition the generation on. Must be a list of images/frames of the video.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated video.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated video.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
|
||||
expense of slower inference.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Higher strength leads to more differences between original video and generated video.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
A higher guidance scale value encourages the model to generate images closely linked to the text
|
||||
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
|
||||
`(batch_size, num_channel, num_frames, height, width)`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*):
|
||||
Optional image input to work with IP Adapters.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
|
||||
`np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`AnimateDiffPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`AnimateDiffPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
|
||||
"""
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
num_videos_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
strength=strength,
|
||||
height=height,
|
||||
width=width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
video=video,
|
||||
latents=latents,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_videos_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
video=video,
|
||||
height=height,
|
||||
width=width,
|
||||
num_channels_latents=num_channels_latents,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
timestep=latent_timestep,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
return AnimateDiffPipelineOutput(frames=latents)
|
||||
|
||||
# 9. Post-processing
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
|
||||
|
||||
# 10. Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return AnimateDiffPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnimateDiffPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for AnimateDiff pipelines.
|
||||
|
||||
Args:
|
||||
frames (`List[List[PIL.Image.Image]]` or `torch.Tensor` or `np.ndarray`):
|
||||
List of PIL Images of length `batch_size` or torch.Tensor or np.ndarray of shape
|
||||
`(batch_size, num_frames, height, width, num_channels)`.
|
||||
"""
|
||||
|
||||
frames: Union[List[List[PIL.Image.Image]], torch.Tensor, np.ndarray]
|
||||
@@ -36,8 +36,8 @@ from ...models.embeddings import (
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from ...models.transformer_2d import Transformer2DModel
|
||||
from ...models.unet_2d_blocks import DownBlock2D, UpBlock2D
|
||||
from ...models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
|
||||
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import BaseOutput, is_torch_version, logging
|
||||
|
||||
|
||||
@@ -513,7 +513,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -537,7 +537,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -572,7 +572,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
@@ -588,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -654,7 +654,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
@@ -687,7 +687,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
@@ -700,8 +700,8 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
a `tuple` is returned where the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
|
||||
@@ -683,9 +683,11 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
mask_image,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
output_type,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
@@ -693,6 +695,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
padding_mask_crop=None,
|
||||
):
|
||||
if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
@@ -736,6 +739,19 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
# `prompt` needs more sophisticated handling when there are multiple
|
||||
# conditionings.
|
||||
if isinstance(self.controlnet, MultiControlNetModel):
|
||||
@@ -862,7 +878,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
|
||||
def prepare_control_image(
|
||||
self,
|
||||
image,
|
||||
@@ -872,10 +887,14 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
crops_coords,
|
||||
resize_mode,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
image = self.control_image_processor.preprocess(
|
||||
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
|
||||
).to(dtype=torch.float32)
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
@@ -1074,6 +1093,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
control_image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
padding_mask_crop: Optional[int] = None,
|
||||
strength: float = 1.0,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
@@ -1130,6 +1150,12 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated image.
|
||||
padding_mask_crop (`int`, *optional*, defaults to `None`):
|
||||
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
|
||||
`padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
|
||||
contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
|
||||
the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
|
||||
and contain information inreleant for inpainging, such as background.
|
||||
strength (`float`, *optional*, defaults to 1.0):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
@@ -1240,9 +1266,11 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
control_image,
|
||||
mask_image,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
output_type,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
@@ -1250,6 +1278,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
padding_mask_crop,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
@@ -1264,6 +1293,14 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
height, width = self.image_processor.get_default_height_width(image, height, width)
|
||||
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
|
||||
resize_mode = "fill"
|
||||
else:
|
||||
crops_coords = None
|
||||
resize_mode = "default"
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
@@ -1315,6 +1352,8 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=crops_coords,
|
||||
resize_mode=resize_mode,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
@@ -1330,6 +1369,8 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=crops_coords,
|
||||
resize_mode=resize_mode,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
@@ -1341,10 +1382,15 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
assert False
|
||||
|
||||
# 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
|
||||
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
original_image = image
|
||||
init_image = self.image_processor.preprocess(
|
||||
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
|
||||
)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
mask = self.mask_processor.preprocess(
|
||||
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
|
||||
)
|
||||
|
||||
masked_image = init_image * (mask < 0.5)
|
||||
_, _, height, width = init_image.shape
|
||||
@@ -1534,6 +1580,9 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
|
||||
@@ -557,9 +557,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
prompt,
|
||||
prompt_2,
|
||||
image,
|
||||
mask_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
callback_steps,
|
||||
output_type,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
prompt_embeds=None,
|
||||
@@ -570,6 +572,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
padding_mask_crop=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
@@ -632,6 +635,19 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
@@ -745,10 +761,14 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
crops_coords,
|
||||
resize_mode,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
image = self.control_image_processor.preprocess(
|
||||
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
|
||||
).to(dtype=torch.float32)
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
@@ -1066,6 +1086,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
padding_mask_crop: Optional[int] = None,
|
||||
strength: float = 0.9999,
|
||||
num_inference_steps: int = 50,
|
||||
denoising_start: Optional[float] = None,
|
||||
@@ -1121,6 +1142,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
padding_mask_crop (`int`, *optional*, defaults to `None`):
|
||||
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
|
||||
`padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
|
||||
contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
|
||||
the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
|
||||
and contain information inreleant for inpainging, such as background.
|
||||
strength (`float`, *optional*, defaults to 0.9999):
|
||||
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
|
||||
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
|
||||
@@ -1290,9 +1317,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
prompt,
|
||||
prompt_2,
|
||||
control_image,
|
||||
mask_image,
|
||||
strength,
|
||||
num_inference_steps,
|
||||
callback_steps,
|
||||
output_type,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
@@ -1303,6 +1332,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
padding_mask_crop,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
@@ -1370,7 +1400,18 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
# 5. Preprocess mask and image - resizes image and mask w.r.t height and width
|
||||
# 5.1 Prepare init image
|
||||
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
if padding_mask_crop is not None:
|
||||
height, width = self.image_processor.get_default_height_width(image, height, width)
|
||||
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
|
||||
resize_mode = "fill"
|
||||
else:
|
||||
crops_coords = None
|
||||
resize_mode = "default"
|
||||
|
||||
original_image = image
|
||||
init_image = self.image_processor.preprocess(
|
||||
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
|
||||
)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
# 5.2 Prepare control images
|
||||
@@ -1383,6 +1424,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=crops_coords,
|
||||
resize_mode=resize_mode,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
@@ -1398,6 +1441,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
crops_coords=crops_coords,
|
||||
resize_mode=resize_mode,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
@@ -1409,7 +1454,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
raise ValueError(f"{controlnet.__class__} is not supported.")
|
||||
|
||||
# 5.3 Prepare mask
|
||||
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
mask = self.mask_processor.preprocess(
|
||||
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
|
||||
)
|
||||
|
||||
masked_image = init_image * (mask < 0.5)
|
||||
_, _, height, width = init_image.shape
|
||||
@@ -1684,6 +1731,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
|
||||
@@ -1404,11 +1404,6 @@ class StableDiffusionXLControlNetPipeline(
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# manually for max memory savings
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
|
||||
@@ -33,7 +33,7 @@ from ....models.embeddings import (
|
||||
)
|
||||
from ....models.resnet import ResnetBlockCondNorm2D
|
||||
from ....models.transformer_2d import Transformer2DModel
|
||||
from ....models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
|
||||
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ....utils.torch_utils import apply_freeu
|
||||
|
||||
@@ -268,7 +268,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
return objs
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
|
||||
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
||||
@@ -1096,7 +1096,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
@@ -1112,8 +1112,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
a `tuple` is returned where the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
@@ -1786,7 +1786,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
|
||||
class UpBlockFlat(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1897,7 +1897,7 @@ class UpBlockFlat(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
|
||||
class CrossAttnUpBlockFlat(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2071,7 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlat(nn.Module):
|
||||
"""
|
||||
A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
|
||||
@@ -2227,7 +2227,7 @@ class UNetMidBlockFlat(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2374,7 +2374,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -351,7 +351,7 @@ def get_class_obj_and_candidates(
|
||||
|
||||
def _get_pipeline_class(
|
||||
class_obj,
|
||||
config,
|
||||
config=None,
|
||||
load_connected_pipeline=False,
|
||||
custom_pipeline=None,
|
||||
repo_id=None,
|
||||
@@ -389,7 +389,12 @@ def _get_pipeline_class(
|
||||
return class_obj
|
||||
|
||||
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
|
||||
class_name = config["_class_name"]
|
||||
class_name = class_name or config["_class_name"]
|
||||
if not class_name:
|
||||
raise ValueError(
|
||||
"The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
|
||||
)
|
||||
|
||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||
|
||||
pipeline_cls = getattr(diffusers_module, class_name)
|
||||
|
||||
@@ -1317,6 +1317,9 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
|
||||
if config_url is not None:
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
else:
|
||||
with open(original_config_file, "r") as f:
|
||||
original_config_file = f.read()
|
||||
|
||||
original_config = yaml.safe_load(original_config_file)
|
||||
|
||||
|
||||
@@ -642,6 +642,7 @@ class StableDiffusionInpaintPipeline(
|
||||
width,
|
||||
strength,
|
||||
callback_steps,
|
||||
output_type,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
@@ -693,11 +694,6 @@ class StableDiffusionInpaintPipeline(
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if padding_mask_crop is not None:
|
||||
if self.unet.config.in_channels != 4:
|
||||
raise ValueError(
|
||||
f"The UNet should have 4 input channels for inpainting mask crop, but has"
|
||||
f" {self.unet.config.in_channels} input channels."
|
||||
)
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
@@ -707,6 +703,8 @@ class StableDiffusionInpaintPipeline(
|
||||
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
@@ -1166,6 +1164,7 @@ class StableDiffusionInpaintPipeline(
|
||||
width,
|
||||
strength,
|
||||
callback_steps,
|
||||
output_type,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
|
||||
+44
-2
@@ -744,15 +744,19 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
image,
|
||||
mask_image,
|
||||
height,
|
||||
width,
|
||||
strength,
|
||||
callback_steps,
|
||||
output_type,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
padding_mask_crop=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
@@ -810,6 +814,18 @@ class StableDiffusionXLInpaintPipeline(
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
@@ -1225,6 +1241,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
masked_image_latents: torch.FloatTensor = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
padding_mask_crop: Optional[int] = None,
|
||||
strength: float = 0.9999,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
@@ -1287,6 +1304,12 @@ class StableDiffusionXLInpaintPipeline(
|
||||
Anything below 512 pixels won't work well for
|
||||
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
and checkpoints that are not specifically fine-tuned on low resolutions.
|
||||
padding_mask_crop (`int`, *optional*, defaults to `None`):
|
||||
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
|
||||
`padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
|
||||
contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
|
||||
the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
|
||||
and contain information inreleant for inpainging, such as background.
|
||||
strength (`float`, *optional*, defaults to 0.9999):
|
||||
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
|
||||
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
|
||||
@@ -1449,15 +1472,19 @@ class StableDiffusionXLInpaintPipeline(
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
image,
|
||||
mask_image,
|
||||
height,
|
||||
width,
|
||||
strength,
|
||||
callback_steps,
|
||||
output_type,
|
||||
negative_prompt,
|
||||
negative_prompt_2,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
padding_mask_crop,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
@@ -1527,10 +1554,22 @@ class StableDiffusionXLInpaintPipeline(
|
||||
is_strength_max = strength == 1.0
|
||||
|
||||
# 5. Preprocess mask and image
|
||||
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
if padding_mask_crop is not None:
|
||||
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
|
||||
resize_mode = "fill"
|
||||
else:
|
||||
crops_coords = None
|
||||
resize_mode = "default"
|
||||
|
||||
original_image = image
|
||||
init_image = self.image_processor.preprocess(
|
||||
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
|
||||
)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
mask = self.mask_processor.preprocess(
|
||||
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
|
||||
)
|
||||
|
||||
if masked_image_latents is not None:
|
||||
masked_image = masked_image_latents
|
||||
@@ -1791,6 +1830,9 @@ class StableDiffusionXLInpaintPipeline(
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
|
||||
@@ -40,10 +40,8 @@ def _append_dims(x, target_dims):
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
# Based on:
|
||||
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
@@ -53,7 +51,13 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
return np.stack(outputs)
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet3DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
@@ -58,22 +59,26 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
|
||||
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
# reshape to ncfhw
|
||||
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
# unnormalize back to [0,1]
|
||||
video = video.mul_(std).add_(mean)
|
||||
video.clamp_(0, 1)
|
||||
# prepare the final outputs
|
||||
i, c, f, h, w = video.shape
|
||||
images = video.permute(2, 3, 0, 4, 1).reshape(
|
||||
f, h, i * w, c
|
||||
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
|
||||
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
|
||||
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
|
||||
return images
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
||||
batch_output = processor.postprocess(batch_vid, output_type)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
@@ -122,6 +127,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
@@ -717,11 +723,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
return TextToVideoSDPipelineOutput(frames=latents)
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
+26
-21
@@ -20,6 +20,7 @@ import PIL.Image
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet3DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
@@ -93,22 +94,26 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
|
||||
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
# reshape to ncfhw
|
||||
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
# unnormalize back to [0,1]
|
||||
video = video.mul_(std).add_(mean)
|
||||
video.clamp_(0, 1)
|
||||
# prepare the final outputs
|
||||
i, c, f, h, w = video.shape
|
||||
images = video.permute(2, 3, 0, 4, 1).reshape(
|
||||
f, h, i * w, c
|
||||
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
|
||||
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
|
||||
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
|
||||
return images
|
||||
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
|
||||
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
|
||||
batch_size, channels, num_frames, height, width = video.shape
|
||||
outputs = []
|
||||
for batch_idx in range(batch_size):
|
||||
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
|
||||
batch_output = processor.postprocess(batch_vid, output_type)
|
||||
|
||||
outputs.append(batch_output)
|
||||
|
||||
if output_type == "np":
|
||||
outputs = np.stack(outputs)
|
||||
|
||||
elif output_type == "pt":
|
||||
outputs = torch.stack(outputs)
|
||||
|
||||
elif not output_type == "pil":
|
||||
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def preprocess_video(video):
|
||||
@@ -198,6 +203,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
@@ -812,12 +818,11 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
if output_type == "latent":
|
||||
return TextToVideoSDPipelineOutput(frames=latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor)
|
||||
video_tensor = self.decode_latents(latents)
|
||||
video = tensor2vid(video_tensor, self.image_processor, output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
@@ -752,7 +752,7 @@ class UTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
cross_attention_kwargs (*optional*):
|
||||
Keyword arguments to supply to the cross attention layers, if used.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
|
||||
ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
|
||||
|
||||
@@ -66,7 +66,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
self.set_default_attn_processor()
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
@@ -90,7 +90,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
@@ -125,7 +125,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
|
||||
@@ -61,6 +61,7 @@ else:
|
||||
_import_structure["scheduling_lcm"] = ["LCMScheduler"]
|
||||
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
|
||||
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
|
||||
_import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
|
||||
_import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
|
||||
_import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
|
||||
_import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
|
||||
@@ -152,6 +153,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .scheduling_lcm import LCMScheduler
|
||||
from .scheduling_pndm import PNDMScheduler
|
||||
from .scheduling_repaint import RePaintScheduler
|
||||
from .scheduling_sasolver import SASolverScheduler
|
||||
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
||||
from .scheduling_unclip import UnCLIPScheduler
|
||||
from .scheduling_unipc_multistep import UniPCMultistepScheduler
|
||||
|
||||
@@ -98,7 +98,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.custom_timesteps = False
|
||||
self.is_scale_input_called = False
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
@@ -231,7 +231,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Modified _convert_to_karras implementation that takes in ramp as argument
|
||||
def _convert_to_karras(self, ramp):
|
||||
|
||||
@@ -187,7 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -255,7 +255,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -128,6 +128,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
|
||||
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
|
||||
`lambda(t)`.
|
||||
final_sigmas_type (`str`, defaults to `"zero"`):
|
||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
|
||||
is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||
@@ -165,11 +168,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
euler_at_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_lu_lambdas: Optional[bool] = False,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
||||
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -207,6 +215,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
|
||||
|
||||
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
|
||||
)
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
@@ -214,7 +227,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -267,17 +280,24 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
elif self.config.use_lu_lambdas:
|
||||
lambdas = np.flip(log_sigmas.copy())
|
||||
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
|
||||
sigmas = np.exp(lambdas)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
elif self.config.final_sigmas_type == "zero":
|
||||
sigma_last = 0
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
||||
)
|
||||
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
|
||||
@@ -291,7 +311,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
@@ -831,7 +851,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Improve numerical stability for small number of steps
|
||||
lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
|
||||
self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
|
||||
self.config.euler_at_final
|
||||
or (self.config.lower_order_final and len(self.timesteps) < 15)
|
||||
or self.config.final_sigmas_type == "zero"
|
||||
)
|
||||
lower_order_second = (
|
||||
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||
|
||||
@@ -165,6 +165,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
||||
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -209,7 +213,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
@property
|
||||
@@ -290,7 +294,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
@@ -783,7 +787,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = step_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
|
||||
@@ -198,7 +198,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.noise_sampler = None
|
||||
self.noise_sampler_seed = noise_sampler_seed
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
@@ -348,7 +348,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.mid_point_sigma = None
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.noise_sampler = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
|
||||
@@ -108,7 +108,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
|
||||
`algorithm_type="dpmsolver++"`.
|
||||
algorithm_type (`str`, defaults to `dpmsolver++`):
|
||||
Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
|
||||
Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
|
||||
`dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
|
||||
paper, and the `dpmsolver++` type implements the algorithms in the
|
||||
[DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
|
||||
@@ -122,6 +122,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
|
||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
|
||||
is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||
@@ -150,9 +153,14 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
):
|
||||
if algorithm_type == "dpmsolver":
|
||||
deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -189,6 +197,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}")
|
||||
|
||||
if algorithm_type != "dpmsolver++" and final_sigmas_type == "zero":
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
|
||||
)
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
@@ -197,7 +210,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sample = None
|
||||
self.order_list = self.get_order_list(num_train_timesteps)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
||||
"""
|
||||
@@ -267,11 +280,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
|
||||
else:
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
elif self.config.final_sigmas_type == "zero":
|
||||
sigma_last = 0
|
||||
else:
|
||||
raise ValueError(
|
||||
f" `final_sigmas_type` must be one of `sigma_min` or `zero`, but got {self.config.final_sigmas_type}"
|
||||
)
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
|
||||
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
@@ -285,11 +305,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
self.register_to_config(lower_order_final=True)
|
||||
|
||||
if not self.config.lower_order_final and self.config.final_sigmas_type == "zero":
|
||||
logger.warn(
|
||||
" `last_sigmas_type='zero'` is not supported for `lower_order_final=False`. Changing scheduler {self.config} to have `lower_order_final` set to True."
|
||||
)
|
||||
self.register_to_config(lower_order_final=True)
|
||||
|
||||
self.order_list = self.get_order_list(num_inference_steps)
|
||||
|
||||
# add an index counter for schedulers that allow duplicated timesteps
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
@@ -216,7 +216,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
@@ -300,7 +300,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device=device)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
|
||||
@@ -237,7 +237,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
@@ -342,7 +342,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
|
||||
@@ -148,7 +148,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
@@ -270,7 +270,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.dt = None
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep)
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
|
||||
@@ -140,7 +140,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
# set all values
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
@@ -300,7 +300,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
|
||||
@@ -140,7 +140,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
@@ -285,7 +285,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._index_counter = defaultdict(int)
|
||||
|
||||
self._step_index = None
|
||||
self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user