Compare commits

..

16 Commits

Author SHA1 Message Date
thomasw21 c11d4b42a7 Context has its own sequence length 2022-11-22 12:26:59 +01:00
thomasw21 d5af4fd153 Woops 2022-11-22 12:21:59 +01:00
thomasw21 1f135ac219 Woops 2022-11-22 12:20:26 +01:00
thomasw21 0c49f4cf30 Woops 2022-11-22 12:19:15 +01:00
thomasw21 5d4145cfa2 Remove transpose for baddbmm 2022-11-22 12:17:11 +01:00
thomasw21 31d26872c1 Revert "Making hidden_state contiguous before applying multiple linear layers"
This reverts commit 1cd09cccf3.
2022-11-22 12:07:54 +01:00
thomasw21 1cd09cccf3 Making hidden_state contiguous before applying multiple linear layers 2022-11-22 11:55:03 +01:00
thomasw21 fa4d738cbb Revert "Save one more copy" as it's much slower on A100
This reverts commit 136f84283c.
2022-11-22 11:53:58 +01:00
thomasw21 136f84283c Save one more copy 2022-11-22 11:49:15 +01:00
thomasw21 42ba85998f scatter_ argument is not called src, but rather value 2022-11-22 01:11:18 +01:00
thomasw21 e1623e2081 Woops 2022-11-22 01:02:24 +01:00
thomasw21 fdef40ba03 Woops 2022-11-22 00:57:19 +01:00
thomasw21 fe691feb5a Remove unused import 2022-11-22 00:52:53 +01:00
thomasw21 f2ed5d8b44 black 2022-11-22 00:48:50 +01:00
thomasw21 e43244f33a Fix transpose issue 2022-11-22 00:47:22 +01:00
thomasw21 3c45926a0e WIP: some optimizations 2022-11-22 00:37:17 +01:00
140 changed files with 889 additions and 13898 deletions
-2
View File
@@ -60,7 +60,6 @@ jobs:
run: |
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
@@ -128,7 +127,6 @@ jobs:
${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
shell: arch -arch arm64 bash {0}
-2
View File
@@ -62,7 +62,6 @@ jobs:
run: |
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
@@ -132,7 +131,6 @@ jobs:
run: |
python -m pip install -e .[quality,test,training]
python -m pip install git+https://github.com/huggingface/accelerate
python -m pip install -U git+https://github.com/huggingface/transformers
- name: Environment
run: |
-6
View File
@@ -106,16 +106,10 @@
title: "Score SDE VE"
- local: api/pipelines/stable_diffusion
title: "Stable Diffusion"
- local: api/pipelines/stable_diffusion_2
title: "Stable Diffusion 2"
- local: api/pipelines/stable_diffusion_safe
title: "Safe Stable Diffusion"
- local: api/pipelines/stochastic_karras_ve
title: "Stochastic Karras VE"
- local: api/pipelines/dance_diffusion
title: "Dance Diffusion"
- local: api/pipelines/versatile_diffusion
title: "Versatile Diffusion"
- local: api/pipelines/vq_diffusion
title: "VQ Diffusion"
- local: api/pipelines/repaint
+1 -1
View File
@@ -51,7 +51,7 @@ To use a different scheduler, you can either change it via the [`ConfigMixin.fro
```
- *How to convert all use cases with multiple or single pipeline*
- *How to conver all use cases with multiple or single pipeline*
If you want to use all possible use cases in a single `DiffusionPipeline` we recommend using the `components` functionality to instantiate all components in the most memory-efficient way:
+1 -8
View File
@@ -58,14 +58,7 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
+1 -26
View File
@@ -48,7 +48,7 @@ To use a different scheduler, you can either change it via the [`ConfigMixin.fro
```
### How to convert all use cases with multiple or single pipeline
### How to conver all use cases with multiple or single pipeline
If you want to use all possible use cases in a single `DiffusionPipeline` you can either:
- Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or
@@ -76,40 +76,15 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_vae_slicing
- disable_vae_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
## StableDiffusionImg2ImgPipeline
[[autodoc]] StableDiffusionImg2ImgPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
## StableDiffusionInpaintPipeline
[[autodoc]] StableDiffusionInpaintPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
## StableDiffusionImageVariationPipeline
[[autodoc]] StableDiffusionImageVariationPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
## StableDiffusionUpscalePipeline
[[autodoc]] StableDiffusionUpscalePipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
- enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
@@ -1,142 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable diffusion 2
Stable Diffusion 2 is a text-to-image _latent diffusion_ model built upon the work of [Stable Diffusion 1](https://stability.ai/blog/stable-diffusion-public-release).
The project to train Stable Diffusion 2 was led by Robin Rombach and Katherine Crowson from [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/).
*The Stable Diffusion 2.0 release includes robust text-to-image models trained using a brand new text encoder (OpenCLIP), developed by LAION with support from Stability AI, which greatly improves the quality of the generated images compared to earlier V1 releases. The text-to-image models in this release can generate images with default resolutions of both 512x512 pixels and 768x768 pixels.
These models are trained on an aesthetic subset of the [LAION-5B dataset](https://laion.ai/blog/laion-5b/) created by the DeepFloyd team at Stability AI, which is then further filtered to remove adult content using [LAIONs NSFW filter](https://openreview.net/forum?id=M3Y74vmsMcY).*
For more details about how Stable Diffusion 2 works and how it differs from Stable Diffusion 1, please refer to the official [launch announcement post](https://stability.ai/blog/stable-diffusion-v2-release).
## Tips
### Available checkpoints:
Note that the architecture is more or less identical to [Stable Diffusion 1](./api/pipelines/stable_diffusion) so please refer to [this page](./api/pipelines/stable_diffusion) for API documentation.
- *Text-to-Image (512x512 resolution)*: [stabilityai/stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) with [`StableDiffusionPipeline`]
- *Text-to-Image (768x768 resolution)*: [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) with [`StableDiffusionPipeline`]
- *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`]
- *Image Upscaling (x4 resolution resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) [`StableDiffusionUpscalePipeline`]
We recommend using the [`DPMSolverMultistepScheduler`] as it's currently the fastest scheduler there is.
- *Text-to-Image (512x512 resolution)*:
```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
repo_id = "stabilityai/stable-diffusion-2-base"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "High quality photo of an astronaut riding a horse in space"
image = pipe(prompt, num_inference_steps=25).images[0]
image.save("astronaut.png")
```
- *Text-to-Image (768x768 resolution)*:
```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
repo_id = "stabilityai/stable-diffusion-2"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "High quality photo of an astronaut riding a horse in space"
image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0]
image.save("astronaut.png")
```
- *Image Inpainting (512x512 resolution)*:
```python
import PIL
import requests
import torch
from io import BytesIO
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
def download_image(url):
response = requests.get(url)
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
repo_id = "stabilityai/stable-diffusion-2-inpainting"
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=25).images[0]
image.save("yellow_cat.png")
```
- *Image Upscaling (x4 resolution resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) [`StableDiffusionUpscalePipeline`]
```python
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionUpscalePipeline
import torch
# load model and scheduler
model_id = "stabilityai/stable-diffusion-x4-upscaler"
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")
# let's download an image
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
response = requests.get(url)
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
low_res_img = low_res_img.resize((128, 128))
prompt = "a white cat"
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
upscaled_image.save("upsampled_cat.png")
```
### How to load and use different schedulers.
The stable diffusion pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=euler_scheduler)
```
@@ -1,90 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Safe Stable Diffusion
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://arxiv.org/abs/2211.05105) and mitigates the well known issue that models like Stable Diffusion that are trained on unfiltered, web-crawled datasets tend to suffer from inappropriate degeneration. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, or otherwise offensive content.
Safe Stable Diffusion is an extension to the Stable Diffusion that drastically reduces content like this.
The abstract of the paper is the following:
*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.*
*Overview*:
| Pipeline | Tasks | Colab | Demo
|---|---|:---:|:---:|
| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | -
## Tips
- Safe Stable Diffusion may also be used with weights of [Stable Diffusion](./api/pipelines/stable_diffusion).
### Run Safe Stable Diffusion
Safe Stable Diffusion can be tested very easily with the [`StableDiffusionPipelineSafe`], and the `"AIML-TUDA/stable-diffusion-safe"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation).
### Interacting with the Safety Concept
To check and edit the currently used safety concept, use the `safety_concept` property of [`StableDiffusionPipelineSafe`]
```python
>>> from diffusers import StableDiffusionPipelineSafe
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> pipeline.safety_concept
```
For each image generation the active concept is also contained in [`StableDiffusionSafePipelineOutput`].
### Using pre-defined safety configurations
You may use the 4 configurations defined in the [Safe Latent Diffusion paper](https://arxiv.org/abs/2211.05105) as follows:
```python
>>> from diffusers import StableDiffusionPipelineSafe
>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
```
The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONg`, and `SafetyConfig.MAX`.
### How to load and use different schedulers.
The safe stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import StableDiffusionPipelineSafe, EulerDiscreteScheduler
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("AIML-TUDA/stable-diffusion-safe", subfolder="scheduler")
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained(
... "AIML-TUDA/stable-diffusion-safe", scheduler=euler_scheduler
... )
```
## StableDiffusionSafePipelineOutput
[[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput
## StableDiffusionPipelineSafe
[[autodoc]] StableDiffusionPipelineSafe
- __call__
- enable_attention_slicing
- disable_attention_slicing
@@ -1,73 +0,0 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# VersatileDiffusion
VersatileDiffusion was proposed in [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) by Xingqian Xu, Zhangyang Wang, Eric Zhang, Kai Wang, Humphrey Shi .
The abstract of the paper is the following:
*The recent advances in diffusion models have set an impressive milestone in many generation tasks. Trending works such as DALL-E2, Imagen, and Stable Diffusion have attracted great interest in academia and industry. Despite the rapid landscape changes, recent new approaches focus on extensions and performance rather than capacity, thus requiring separate models for separate tasks. In this work, we expand the existing single-flow diffusion pipeline into a multi-flow network, dubbed Versatile Diffusion (VD), that handles text-to-image, image-to-text, image-variation, and text-variation in one unified model. Moreover, we generalize VD to a unified multi-flow multimodal diffusion framework with grouped layers, swappable streams, and other propositions that can process modalities beyond images and text. Through our experiments, we demonstrate that VD and its underlying framework have the following merits: a) VD handles all subtasks with competitive quality; b) VD initiates novel extensions and applications such as disentanglement of style and semantic, image-text dual-guided generation, etc.; c) Through these experiments and applications, VD provides more semantic insights of the generated outputs.*
## Tips
- VersatileDiffusion is conceptually very similar as [Stable Diffusion](./api/pipelines/stable_diffusion), but instead of providing just a image data stream conditioned on text, VersatileDiffusion provides both a image and text data stream and can be conditioned on both text and image.
### *Run VersatileDiffusion*
You can both load the memory intensive "all-in-one" [`VersatileDiffusionPipeline`] that can run all tasks
with the same class as shown in [`VersatileDiffusionPipeline.text_to_image`], [`VersatileDiffusionPipeline.image_variation`], and [`VersatileDiffusionPipeline.dual_guided`]
**or**
You can run the individual pipelines which are much more memory efficient:
- *Text-to-Image*: [`VersatileDiffusionTextToImagePipeline.__call__`]
- *Image Variation*: [`VersatileDiffusionImageVariationPipeline.__call__`]
- *Dual Text and Image Guided Generation*: [`VersatileDiffusionDualGuidedPipeline.__call__`]
### *How to load and use different schedulers.*
The versatile diffusion pipelines uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the alt diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
```python
>>> from diffusers import VersatileDiffusionPipeline, EulerDiscreteScheduler
>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("shi-labs/versatile-diffusion", subfolder="scheduler")
>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", scheduler=euler_scheduler)
```
## VersatileDiffusionPipeline
[[autodoc]] VersatileDiffusionPipeline
## VersatileDiffusionTextToImagePipeline
[[autodoc]] VersatileDiffusionTextToImagePipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
## VersatileDiffusionImageVariationPipeline
[[autodoc]] VersatileDiffusionImageVariationPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
## VersatileDiffusionDualGuidedPipeline
[[autodoc]] VersatileDiffusionDualGuidedPipeline
- __call__
- enable_attention_slicing
- disable_attention_slicing
+1 -27
View File
@@ -76,33 +76,6 @@ Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [im
[[autodoc]] DPMSolverMultistepScheduler
#### Heun scheduler inspired by Karras et. al paper
Algorithm 1 of [Karras et. al](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
[[autodoc]] HeunDiscreteScheduler
#### DPM Discrete Scheduler inspired by Karras et. al paper
Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
[[autodoc]] KDPM2DiscreteScheduler
#### DPM Discrete Scheduler with ancestral sampling inspired by Karras et. al paper
Inspired by [Karras et. al](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
[[autodoc]] KDPM2AncestralDiscreteScheduler
#### Variance exploding, stochastic sampling from Karras et. al
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
@@ -113,6 +86,7 @@ Original paper can be found [here](https://arxiv.org/abs/2006.11239).
Original implementation can be found [here](https://arxiv.org/abs/2206.00364).
[[autodoc]] LMSDiscreteScheduler
#### Pseudo numerical methods for diffusion models (PNDM)
-7
View File
@@ -48,14 +48,7 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
-28
View File
@@ -117,34 +117,6 @@ image = pipe(prompt).images[0]
There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM!
## Sliced VAE decode for larger batches
To decode large batches of images with limited VRAM, or to enable batches with 32 images or more, you can use sliced VAE decode that decodes the batch latents one image at a time.
You likely want to couple this with [`~StableDiffusionPipeline.enable_attention_slicing`] or [`~StableDiffusionPipeline.enable_xformers_memory_efficient_attention`] to further minimize memory use.
To perform the VAE decode one image at a time, invoke [`~StableDiffusionPipeline.enable_vae_slicing`] in your pipeline before inference. For example:
```Python
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
images = pipe([prompt] * 32).images
```
You may see a small performance boost in VAE decode on multi-image batches. There should be no performance impact on single-image batches.
## Offloading to CPU with accelerate for memory savings
For additional memory savings, you can offload the weights to CPU and load them to GPU when performing the forward pass.
+18
View File
@@ -378,3 +378,21 @@ dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler"
# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc`
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
```
## API
[[autodoc]] modeling_utils.ModelMixin
- from_pretrained
- save_pretrained
[[autodoc]] pipeline_utils.DiffusionPipeline
- from_pretrained
- save_pretrained
[[autodoc]] modeling_flax_utils.FlaxModelMixin
- from_pretrained
- save_pretrained
[[autodoc]] pipeline_flax_utils.FlaxDiffusionPipeline
- from_pretrained
- save_pretrained
+3 -6
View File
@@ -602,7 +602,7 @@ For example, this could be used to place a logo on a shirt and make it blend sea
import PIL
import torch
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionInpaintPipeline
image_path = "./path-to-image.png"
inner_image_path = "./path-to-inner-image.png"
@@ -612,11 +612,10 @@ init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
pipe = DiffusionPipeline.from_pretrained(
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
custom_pipeline="img2img_inpainting",
revision="fp16",
torch_dtype=torch.float16
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
@@ -624,8 +623,6 @@ prompt = "Your prompt here!"
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
```
![2 by 2 grid demonstrating image to image inpainting.](https://user-images.githubusercontent.com/44398246/203506577-ec303be4-887e-4ebd-a773-c83fcb3dd01a.png)
### Text Based Inpainting Stable Diffusion
Use a text prompt to generate the mask for the area to be inpainted.
+5 -7
View File
@@ -138,7 +138,7 @@ def ddpm_bit_scheduler_step(
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
prediction_type="epsilon",
predict_epsilon=True,
generator=None,
return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
@@ -150,8 +150,8 @@ def ddpm_bit_scheduler_step(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
predict_epsilon (`bool`):
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
Returns:
@@ -174,12 +174,10 @@ def ddpm_bit_scheduler_step(
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if prediction_type == "epsilon":
if predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(f"Unsupported prediction_type {prediction_type}.")
pred_original_sample = model_output
# 3. Clip "predicted x_0"
scale = self.bit_scale
@@ -78,12 +78,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
)
self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
cut_out_size = (
feature_extractor.size
if isinstance(feature_extractor.size, int)
else feature_extractor.size["shortest_edge"]
)
self.make_cutouts = MakeCutouts(cut_out_size)
self.make_cutouts = MakeCutouts(feature_extractor.size)
set_requires_grad(self.text_encoder, False)
set_requires_grad(self.clip_model, False)
+1 -1
View File
@@ -110,7 +110,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -101,7 +101,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
+1 -1
View File
@@ -469,7 +469,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -113,7 +113,7 @@ class MultilingualStableDiffusion(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -77,7 +77,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
super().__init__()
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -42,7 +42,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
super().__init__()
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
+1 -1
View File
@@ -99,7 +99,7 @@ class TextInpainting(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -135,7 +135,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
+3 -15
View File
@@ -39,8 +39,6 @@ Now let's get our dataset. Download images from [here](https://drive.google.com/
And launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
@@ -143,7 +141,7 @@ export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
accelerate launch --mixed_precision="fp16" train_dreambooth.py \
accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
@@ -159,7 +157,8 @@ accelerate launch --mixed_precision="fp16" train_dreambooth.py \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
--max_train_steps=800 \
--mixed_precision=fp16
```
### Fine-tune text encoder with the UNet.
@@ -195,17 +194,6 @@ accelerate launch train_dreambooth.py \
--max_train_steps=800
```
### Using DreamBooth for other pipelines than Stable Diffusion
Altdiffusion also support dreambooth now, the runing comman is basically the same as abouve, all you need to do is replace the `MODEL_NAME` like this:
One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion).
```
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion-m9"
or
export MODEL_NAME="CompVis/stable-diffusion-v1-4" --> export MODEL_NAME="BAAI/AltDiffusion"
```
### Inference
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
+20 -65
View File
@@ -14,38 +14,18 @@ from torch.utils.data import Dataset
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from huggingface_hub import HfFolder, Repository, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from transformers import CLIPTextModel, CLIPTokenizer
logger = get_logger(__name__)
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -144,7 +124,6 @@ def parse_args(input_args=None):
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
@@ -208,12 +187,12 @@ def parse_args(input_args=None):
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
@@ -377,7 +356,7 @@ def main(args):
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
pipeline = DiffusionPipeline.from_pretrained(
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None,
@@ -427,24 +406,19 @@ def main(args):
# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer = CLIPTokenizer.from_pretrained(
args.tokenizer_name,
revision=args.revision,
use_fast=False,
)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
# Load models and create wrapper for stable diffusion
text_encoder = text_encoder_cls.from_pretrained(
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
@@ -564,9 +538,9 @@ def main(args):
)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encode and vae to gpu.
@@ -629,31 +603,23 @@ def main(args):
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -672,17 +638,6 @@ def main(args):
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
if accelerator.is_main_process:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
revision=args.revision,
)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
pipeline.save_pretrained(save_path)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
@@ -694,7 +649,7 @@ def main(args):
# Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process:
pipeline = DiffusionPipeline.from_pretrained(
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
+4 -4
View File
@@ -42,13 +42,11 @@ If you have already cloned the repo, then you won't need to go through these ste
#### Hardware
With `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with >30GB memory.
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
@@ -56,6 +54,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
@@ -71,7 +70,7 @@ If you wish to use custom loading logic, you should modify the script, we have l
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--use_ema \
@@ -79,6 +78,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
+20 -46
View File
@@ -15,12 +15,13 @@ from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
logger = get_logger(__name__)
@@ -35,13 +36,6 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -192,12 +186,12 @@ def parse_args():
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument(
@@ -341,24 +335,10 @@ def main():
os.makedirs(args.output_dir, exist_ok=True)
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
# Freeze vae and text_encoder
vae.requires_grad_(False)
@@ -516,9 +496,9 @@ def main():
)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move text_encode and vae to gpu.
@@ -582,17 +562,9 @@ def main():
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
@@ -628,12 +600,14 @@ def main():
if args.use_ema:
ema_unet.copy_to(unet.parameters())
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
revision=args.revision,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(args.output_dir)
-2
View File
@@ -47,8 +47,6 @@ Now let's get our dataset.Download 3-4 images from [here](https://drive.google.c
And launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATA_DIR="path-to-dir-containing-images"
+17 -44
View File
@@ -16,8 +16,9 @@ import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from huggingface_hub import HfFolder, Repository, whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
@@ -25,7 +26,7 @@ from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
@@ -50,11 +51,11 @@ else:
logger = get_logger(__name__)
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
def save_progress(text_encoder, placeholder_token_id, accelerator, args):
logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)
torch.save(learned_embeds_dict, os.path.join(args.output_dir, "learned_embeds.bin"))
def parse_args():
@@ -72,13 +73,6 @@ def parse_args():
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -411,21 +405,9 @@ def main():
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
# Load models and create wrapper for stable diffusion
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="unet",
revision=args.revision,
)
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
@@ -550,17 +532,9 @@ def main():
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
# Zero out the gradients for all token embeddings except the newly added
@@ -582,8 +556,7 @@ def main():
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
save_progress(text_encoder, placeholder_token_id, accelerator, args)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -596,18 +569,18 @@ def main():
# Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
pipeline = StableDiffusionPipeline(
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
vae=vae,
unet=unet,
revision=args.revision,
tokenizer=tokenizer,
scheduler=PNDMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler"),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(args.output_dir)
# Also save the newly trained embeddings
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
save_progress(text_encoder, placeholder_token_id, accelerator, args)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
@@ -194,10 +194,9 @@ def parse_args():
)
parser.add_argument(
"--prediction_type",
type=str,
default="epsilon",
choices=["epsilon", "sample"],
"--predict_epsilon",
action="store_true",
default=True,
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
)
@@ -257,13 +256,13 @@ def main(args):
"UpBlock2D",
),
)
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_prediction_type:
if accepts_predict_epsilon:
noise_scheduler = DDPMScheduler(
num_train_timesteps=args.ddpm_num_steps,
beta_schedule=args.ddpm_beta_schedule,
prediction_type=args.prediction_type,
predict_epsilon=args.predict_epsilon,
)
else:
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
@@ -320,12 +319,7 @@ def main(args):
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
ema_model = EMAModel(
accelerator.unwrap_model(model),
inv_gamma=args.ema_inv_gamma,
power=args.ema_power,
max_value=args.ema_max_decay,
)
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
# Handle the repository creation
if accelerator.is_main_process:
@@ -371,9 +365,9 @@ def main(args):
# Predict the noise residual
model_output = model(noisy_images, timesteps).sample
if args.prediction_type == "epsilon":
if args.predict_epsilon:
loss = F.mse_loss(model_output, noise) # this could have different weights!
elif args.prediction_type == "sample":
else:
alpha_t = _extract_into_tensor(
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
)
@@ -382,8 +376,6 @@ def main(args):
model_output, clean_images, reduction="none"
) # use SNR weighting from distillation paper
loss = loss.mean()
else:
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
accelerator.backward(loss)
@@ -211,7 +211,6 @@ def create_unet_diffusers_config(original_config):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
model_params = original_config.model.params
unet_params = original_config.model.params.unet_config.params
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
@@ -231,7 +230,7 @@ def create_unet_diffusers_config(original_config):
resolution //= 2
config = dict(
sample_size=model_params.image_size,
sample_size=unet_params.image_size,
in_channels=unet_params.in_channels,
out_channels=unet_params.out_channels,
down_block_types=tuple(down_block_types),
@@ -666,29 +665,17 @@ if __name__ == "__main__":
args = parser.parse_args()
checkpoint = torch.load(args.checkpoint_path)
checkpoint = checkpoint["state_dict"]
prediction_type = "epsilon"
if args.original_config_file is None:
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
# model_type = "v2"
os.system(
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
)
args.original_config_file = "./v2-inference-v.yaml"
prediction_type
else:
# model_type = "v2"
os.system(
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
args.original_config_file = "./v1-inference.yaml"
os.system(
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
args.original_config_file = "./v1-inference.yaml"
original_config = OmegaConf.load(args.original_config_file)
checkpoint = torch.load(args.checkpoint_path)
checkpoint = checkpoint["state_dict"]
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
@@ -215,10 +215,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
)
del pipeline.safety_checker
safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
feature_extractor = pipeline.feature_extractor
else:
safety_checker = None
feature_extractor = None
onnx_pipeline = OnnxStableDiffusionPipeline(
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@@ -228,8 +226,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
scheduler=pipeline.scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=safety_checker is not None,
feature_extractor=pipeline.feature_extractor,
)
onnx_pipeline.save_pretrained(output_path)
@@ -1,791 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the Versatile Stable Diffusion checkpoints. """
import argparse
from argparse import Namespace
import torch
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
VersatileDiffusionPipeline,
)
from diffusers.pipelines.versatile_diffusion.modeling_text_unet import UNetFlatConditionModel
from transformers import (
CLIPFeatureExtractor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
SCHEDULER_CONFIG = Namespace(
**{
"beta_linear_start": 0.00085,
"beta_linear_end": 0.012,
"timesteps": 1000,
"scale_factor": 0.18215,
}
)
IMAGE_UNET_CONFIG = Namespace(
**{
"input_channels": 4,
"model_channels": 320,
"output_channels": 4,
"num_noattn_blocks": [2, 2, 2, 2],
"channel_mult": [1, 2, 4, 4],
"with_attn": [True, True, True, False],
"num_heads": 8,
"context_dim": 768,
"use_checkpoint": True,
}
)
TEXT_UNET_CONFIG = Namespace(
**{
"input_channels": 768,
"model_channels": 320,
"output_channels": 768,
"num_noattn_blocks": [2, 2, 2, 2],
"channel_mult": [1, 2, 4, 4],
"second_dim": [4, 4, 4, 4],
"with_attn": [True, True, True, False],
"num_heads": 8,
"context_dim": 768,
"use_checkpoint": True,
}
)
AUTOENCODER_CONFIG = Namespace(
**{
"double_z": True,
"z_channels": 4,
"resolution": 256,
"in_channels": 3,
"out_ch": 3,
"ch": 128,
"ch_mult": [1, 2, 4, 4],
"num_res_blocks": 2,
"attn_resolutions": [],
"dropout": 0.0,
}
)
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
to them. It splits attention layers, and takes into account additional replacements
that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
elif path["old"] in old_checkpoint:
checkpoint[new_path] = old_checkpoint[path["old"]]
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def create_image_unet_diffusers_config(unet_params):
"""
Creates a config for the diffusers based on the config of the VD model.
"""
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if unet_params.with_attn[i] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if unet_params.with_attn[-i - 1] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks):
raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.")
config = dict(
sample_size=None,
in_channels=unet_params.input_channels,
out_channels=unet_params.output_channels,
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
layers_per_block=unet_params.num_noattn_blocks[0],
cross_attention_dim=unet_params.context_dim,
attention_head_dim=unet_params.num_heads,
)
return config
def create_text_unet_diffusers_config(unet_params):
"""
Creates a config for the diffusers based on the config of the VD model.
"""
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlockFlat" if unet_params.with_attn[i] else "DownBlockFlat"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlockFlat" if unet_params.with_attn[-i - 1] else "UpBlockFlat"
up_block_types.append(block_type)
resolution //= 2
if not all(n == unet_params.num_noattn_blocks[0] for n in unet_params.num_noattn_blocks):
raise ValueError("Not all num_res_blocks are equal, which is not supported in this script.")
config = dict(
sample_size=None,
in_channels=(unet_params.input_channels, 1, 1),
out_channels=(unet_params.output_channels, 1, 1),
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
layers_per_block=unet_params.num_noattn_blocks[0],
cross_attention_dim=unet_params.context_dim,
attention_head_dim=unet_params.num_heads,
)
return config
def create_vae_diffusers_config(vae_params):
"""
Creates a config for the diffusers based on the config of the VD model.
"""
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = dict(
sample_size=vae_params.resolution,
in_channels=vae_params.in_channels,
out_channels=vae_params.out_ch,
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
latent_channels=vae_params.z_channels,
layers_per_block=vae_params.num_res_blocks,
)
return config
def create_diffusers_scheduler(original_config):
schedular = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps,
beta_start=original_config.model.params.linear_start,
beta_end=original_config.model.params.linear_end,
beta_schedule="scaled_linear",
)
return schedular
def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
print("Checkpoint has both EMA and non-EMA weights.")
if extract_ema:
print(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
print(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["model.diffusion_model.time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["model.diffusion_model.time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["model.diffusion_model.time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["model.diffusion_model.time_embed.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
elif f"input_blocks.{i}.0.weight" in unet_state_dict:
# text_unet uses linear layers in place of downsamplers
shape = unet_state_dict[f"input_blocks.{i}.0.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.bias"
)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_paths = renew_resnet_paths(resnet_0)
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
resnet_1_paths = renew_resnet_paths(resnet_1)
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if ["conv.weight", "conv.bias"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
elif f"output_blocks.{i}.1.weight" in unet_state_dict:
# text_unet uses linear layers in place of upsamplers
shape = unet_state_dict[f"output_blocks.{i}.1.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop(
f"output_blocks.{i}.1.weight"
)
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop(
f"output_blocks.{i}.1.bias"
)
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
elif f"output_blocks.{i}.2.weight" in unet_state_dict:
# text_unet uses linear layers in place of upsamplers
shape = unet_state_dict[f"output_blocks.{i}.2.weight"].shape
if shape[0] != shape[1]:
continue
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.weight"] = unet_state_dict.pop(
f"output_blocks.{i}.2.weight"
)
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.bias"] = unet_state_dict.pop(
f"output_blocks.{i}.2.bias"
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
new_checkpoint[new_path] = unet_state_dict[old_path]
return new_checkpoint
def convert_vd_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
keys = list(checkpoint.keys())
for key in keys:
vae_state_dict[key] = checkpoint.get(key)
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--unet_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--vae_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--optimus_checkpoint_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--scheduler_type",
default="pndm",
type=str,
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
)
parser.add_argument(
"--extract_ema",
action="store_true",
help=(
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
),
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
scheduler_config = SCHEDULER_CONFIG
num_train_timesteps = scheduler_config.timesteps
beta_start = scheduler_config.beta_linear_start
beta_end = scheduler_config.beta_linear_end
if args.scheduler_type == "pndm":
scheduler = PNDMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
skip_prk_steps=True,
steps_offset=1,
)
elif args.scheduler_type == "lms":
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
elif args.scheduler_type == "euler":
scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
elif args.scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
)
elif args.scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
)
elif args.scheduler_type == "ddim":
scheduler = DDIMScheduler(
beta_start=beta_start,
beta_end=beta_end,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
else:
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
# Convert the UNet2DConditionModel models.
if args.unet_checkpoint_path is not None:
# image UNet
image_unet_config = create_image_unet_diffusers_config(IMAGE_UNET_CONFIG)
checkpoint = torch.load(args.unet_checkpoint_path)
converted_image_unet_checkpoint = convert_vd_unet_checkpoint(
checkpoint, image_unet_config, unet_key="model.diffusion_model.unet_image.", extract_ema=args.extract_ema
)
image_unet = UNet2DConditionModel(**image_unet_config)
image_unet.load_state_dict(converted_image_unet_checkpoint)
# text UNet
text_unet_config = create_text_unet_diffusers_config(TEXT_UNET_CONFIG)
converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
)
text_unet = UNetFlatConditionModel(**text_unet_config)
text_unet.load_state_dict(converted_text_unet_checkpoint)
# Convert the VAE model.
if args.vae_checkpoint_path is not None:
vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG)
checkpoint = torch.load(args.vae_checkpoint_path)
converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
image_feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
pipe = VersatileDiffusionPipeline(
scheduler=scheduler,
tokenizer=tokenizer,
image_feature_extractor=image_feature_extractor,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_unet=image_unet,
text_unet=text_unet,
vae=vae,
)
pipe.save_pretrained(args.dump_path)
-70
View File
@@ -1,70 +0,0 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+2 -4
View File
@@ -97,7 +97,6 @@ _deps = [
"pytest",
"pytest-timeout",
"pytest-xdist",
"safetensors",
"sentencepiece>=0.1.91,!=0.1.92",
"scipy",
"regex!=2019.12.17",
@@ -185,11 +184,10 @@ extras["test"] = deps_list(
"pytest",
"pytest-timeout",
"pytest-xdist",
"safetensors",
"sentencepiece",
"scipy",
"torchvision",
"transformers",
"transformers"
)
extras["torch"] = deps_list("torch", "accelerate")
@@ -214,7 +212,7 @@ install_requires = [
setup(
name="diffusers",
version="0.9.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.8.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="Diffusers",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+1 -11
View File
@@ -9,7 +9,7 @@ from .utils import (
)
__version__ = "0.9.0"
__version__ = "0.8.0.dev0"
from .configuration_utils import ConfigMixin
from .onnx_utils import OnnxRuntimeModel
@@ -46,11 +46,8 @@ if is_torch_available():
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
PNDMScheduler,
RePaintScheduler,
SchedulerMixin,
@@ -72,17 +69,10 @@ if is_torch_available() and is_transformers_available():
AltDiffusionPipeline,
CycleDiffusionPipeline,
LDMTextToImagePipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
StableDiffusionUpscalePipeline,
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
else:
+9 -28
View File
@@ -24,8 +24,6 @@ import re
from collections import OrderedDict
from typing import Any, Dict, Tuple, Union
import numpy as np
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
@@ -82,21 +80,20 @@ class ConfigMixin:
- **config_name** (`str`) -- A filename under which the config should stored when calling
[`~ConfigMixin.save_config`] (should be overridden by parent class).
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
overridden by subclass).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
subclass).
overridden by parent class).
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
class).
"""
config_name = None
ignore_for_config = []
has_compatibles = False
_deprecated_kwargs = []
def register_to_config(self, **kwargs):
if self.config_name is None:
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
kwargs["_class_name"] = self.__class__.__name__
kwargs["_diffusers_version"] = __version__
# Special case for `kwargs` used in deprecation warning added to schedulers
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
# or solve in a more general way.
@@ -201,11 +198,6 @@ class ConfigMixin:
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")
# add possible deprecated kwargs
for deprecated_kwarg in cls._deprecated_kwargs:
if deprecated_kwarg in unused_kwargs:
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
# Return model and optionally state and/or unused_kwargs
model = cls(**init_dict)
@@ -470,7 +462,7 @@ class ConfigMixin:
unused_kwargs = {**config_dict, **kwargs}
# 7. Define "hidden" config parameters that were saved for compatible classes
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")}
return init_dict, unused_kwargs, hidden_config_dict
@@ -501,15 +493,6 @@ class ConfigMixin:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
config_dict["_class_name"] = self.__class__.__name__
config_dict["_diffusers_version"] = __version__
def to_json_saveable(value):
if isinstance(value, np.ndarray):
value = value.tolist()
return value
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
@@ -537,7 +520,7 @@ def register_to_config(init):
def inner_init(self, *args, **kwargs):
# Ignore private kwargs in the init.
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
init(self, *args, **init_kwargs)
if not isinstance(self, ConfigMixin):
raise RuntimeError(
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
@@ -562,9 +545,7 @@ def register_to_config(init):
if k not in ignore and k not in new_kwargs
}
)
new_kwargs = {**config_init_kwargs, **new_kwargs}
getattr(self, "register_to_config")(**new_kwargs)
init(self, *args, **init_kwargs)
return inner_init
@@ -581,7 +562,7 @@ def flax_register_to_config(cls):
)
# Ignore private kwargs in the init. Retrieve all passed attributes
init_kwargs = {k: v for k, v in kwargs.items()}
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
# Retrieve default values
fields = dataclasses.fields(self)
@@ -21,7 +21,6 @@ deps = {
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"safetensors": "safetensors",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"scipy": "scipy",
"regex": "regex!=2019.12.17",
@@ -89,7 +89,6 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
x = x + scale * grad
x = self.reset_x0(x, conditions, self.action_dim)
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
# TODO: set prediction_type when instantiating the model
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
# apply conditions to the trajectory
+1 -1
View File
@@ -332,7 +332,7 @@ class FlaxModelMixin:
elif os.path.isfile(os.path.join(pretrained_path_with_subfolder, WEIGHTS_NAME)):
raise EnvironmentError(
f"{WEIGHTS_NAME} file found in directory {pretrained_path_with_subfolder}. Please load the model"
" using `from_pt=True`."
" using `from_pt=True`."
)
else:
raise EnvironmentError(
+66 -140
View File
@@ -30,10 +30,8 @@ from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
logging,
)
@@ -53,9 +51,6 @@ if is_accelerate_available():
from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.versions import is_torch_version
if is_safetensors_available():
import safetensors
def get_parameter_device(parameter: torch.nn.Module):
try:
@@ -89,13 +84,10 @@ def get_parameter_dtype(parameter: torch.nn.Module):
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
"""
try:
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
return torch.load(checkpoint_file, map_location="cpu")
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
return torch.load(checkpoint_file, map_location="cpu")
except Exception as e:
try:
with open(checkpoint_file) as f:
@@ -112,7 +104,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
f"at '{checkpoint_file}'. "
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
)
@@ -340,7 +332,7 @@ class ModelMixin(torch.nn.Module):
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
@@ -383,44 +375,80 @@ class ModelMixin(torch.nn.Module):
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
model_file = None
if is_safetensors_available():
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
else:
raise EnvironmentError(
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
)
else:
try:
model_file = _get_model_file(
# Load from URL or cache if already cached
model_file = hf_hub_download(
pretrained_model_name_or_path,
weights_name=SAFETENSORS_WEIGHTS_NAME,
filename=WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
except:
pass
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
)
except HTTPError as err:
raise EnvironmentError(
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {WEIGHTS_NAME} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {WEIGHTS_NAME}"
)
# restore default dtype
if low_cpu_mem_usage:
# Instantiate model with empty weights
with accelerate.init_empty_weights():
config, unused_kwargs = cls.load_config(
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
@@ -434,7 +462,6 @@ class ModelMixin(torch.nn.Module):
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs)
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
if device_map is None:
@@ -455,7 +482,7 @@ class ModelMixin(torch.nn.Module):
"error_msgs": [],
}
else:
config, unused_kwargs = cls.load_config(
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
@@ -469,24 +496,8 @@ class ModelMixin(torch.nn.Module):
device_map=device_map,
**kwargs,
)
model = cls.from_config(config, **unused_kwargs)
state_dict = load_state_dict(model_file)
dtype = set(v.dtype for v in state_dict.values())
if len(dtype) > 1 and torch.float32 not in dtype:
raise ValueError(
f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
f" make sure that {model_file} weights have only one dtype."
)
elif len(dtype) > 1 and torch.float32 in dtype:
dtype = torch.float32
else:
dtype = dtype.pop()
# move model to correct dtype
model = model.to(dtype)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
@@ -678,88 +689,3 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
return unwrap_model(model.module)
else:
return model
def _get_model_file(
pretrained_model_name_or_path,
*,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
return model_file
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
return model_file
else:
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
else:
try:
# Load from URL or cache if already cached
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
return model_file
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
)
+69 -224
View File
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import Optional
@@ -99,11 +98,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
@@ -129,10 +125,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -158,7 +151,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
)
for d in range(num_layers)
]
@@ -166,10 +158,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 4. Define output layers
if self.is_input_continuous:
if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
@@ -201,16 +190,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
@@ -220,17 +203,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
@@ -239,7 +213,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
output = F.log_softmax(logits, dim=1, dtype=torch.double).float()
if not return_dict:
return (output,)
@@ -314,55 +288,60 @@ class AttentionBlock(nn.Module):
# get scores
if self.num_heads > 1:
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
# TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
# or reformulate this into a 3D problem?
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
else:
query_states, key_states, value_states = query_proj, key_proj, value_proj
attention_scores = torch.baddbmm(
torch.empty(
query_states.shape[0],
query_states.shape[1],
key_states.shape[1],
dtype=query_states.dtype,
device=query_states.device,
),
query_states,
key_states.transpose(-1, -2),
beta=0,
alpha=scale,
query_states = (
self.transpose_for_scores(query_proj)
.contiguous()
.view(batch * self.num_heads, height * width, self.num_head_size)
)
key_states = (
self.transpose_for_scores(key_proj)
.transpose(3, 2)
.contiguous()
.view(batch * self.num_heads, self.num_head_size, height * width)
)
value_states = (
self.transpose_for_scores(value_proj)
.contiguous()
.view(batch * self.num_heads, height * width, self.num_head_size)
)
else:
query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
attention_scores = torch.baddbmm(
torch.empty(
query_states.shape[0],
query_states.shape[1],
key_states.shape[2],
dtype=query_states.dtype,
device=query_states.device,
),
query_states,
key_states,
beta=0,
alpha=scale,
)
attention_probs = torch.softmax(attention_scores, dim=-1, dtype=torch.float).type(attention_scores.dtype)
# compute attention output
hidden_states = torch.bmm(attention_probs, value_states)
if self.num_heads > 1:
# TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
# or reformulate this into a 3D problem?
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
hidden_states = (
hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
.permute(0, 2, 1, 3)
.contiguous()
)
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
else:
hidden_states = torch.bmm(attention_probs, value_states)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
hidden_states = hidden_states + residual
if self.rescale_output_factor != 1.0:
hidden_states = hidden_states / self.rescale_output_factor
return hidden_states
@@ -393,17 +372,14 @@ class BasicTransformerBlock(nn.Module):
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention(
@@ -425,16 +401,6 @@ class BasicTransformerBlock(nn.Module):
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
# if xformers is installed try to use memory_efficient_attention by default
if is_xformers_available():
try:
self._set_use_memory_efficient_attention_xformers(True)
except Exception as e:
warnings.warn(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)
def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
@@ -470,11 +436,7 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
if self.only_cross_attention:
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states
hidden_states = self.attn1(norm_hidden_states) + hidden_states
# 2. Cross-Attention
norm_hidden_states = (
@@ -535,14 +497,13 @@ class CrossAttention(nn.Module):
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
tensor = tensor.view(batch_size, seq_len, head_size, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.view(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
@@ -551,23 +512,33 @@ class CrossAttention(nn.Module):
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
context_sequence_length = context.shape[1]
key = self.to_k(context)
value = self.to_v(context)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
query = (
self.reshape_heads_to_batch_dim(query)
.permute(0, 2, 1, 3)
.reshape(batch_size * self.heads, sequence_length, dim // self.heads)
)
value = (
self.reshape_heads_to_batch_dim(value)
.permute(0, 2, 1, 3)
.reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
)
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, context_sequence_length)
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
@@ -581,9 +552,9 @@ class CrossAttention(nn.Module):
def _attention(self, query, key, value):
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
key,
beta=0,
alpha=self.scale,
)
@@ -606,9 +577,9 @@ class CrossAttention(nn.Module):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
torch.empty(slice_size, query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
query[start_idx:end_idx],
key[start_idx:end_idx].transpose(-1, -2),
key[start_idx:end_idx],
beta=0,
alpha=self.scale,
)
@@ -731,129 +702,3 @@ class AdaLayerNorm(nn.Module):
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
class DualTransformer2DModel(nn.Module):
"""
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
):
super().__init__()
self.transformers = nn.ModuleList(
[
Transformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=num_layers,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
sample_size=sample_size,
num_vector_embeds=num_vector_embeds,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
)
for _ in range(2)
]
)
# Variables that can be set by a pipeline:
# The ratio of transformer1 to transformer2's output states to be combined during inference
self.mix_ratio = 0.5
# The shape of `encoder_hidden_states` is expected to be
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
self.condition_lengths = [77, 257]
# Which transformer to use to encode which condition.
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
self.transformer_index_for_condition = [1, 0]
def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
tensor.
"""
input_states = hidden_states
encoded_states = []
tokens_start = 0
for i in range(2):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
transformer_index = self.transformer_index_for_condition[i]
encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
0
]
encoded_states.append(encoded_state - input_states)
tokens_start += self.condition_lengths[i]
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
output_states = output_states + input_states
if not return_dict:
return (output_states,)
return Transformer2DModelOutput(sample=output_states)
def _set_attention_slice(self, slice_size):
for transformer in self.transformers:
transformer._set_attention_slice(slice_size)
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for transformer in self.transformers:
transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
+23 -52
View File
@@ -104,8 +104,6 @@ class FlaxBasicTransformerBlock(nn.Module):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
only_cross_attention (`bool`, defaults to `False`):
Whether to only apply cross attention.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -113,11 +111,10 @@ class FlaxBasicTransformerBlock(nn.Module):
n_heads: int
d_head: int
dropout: float = 0.0
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
# self attention
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
@@ -129,10 +126,7 @@ class FlaxBasicTransformerBlock(nn.Module):
def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
if self.only_cross_attention:
hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
else:
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual
# cross attention
@@ -165,8 +159,6 @@ class FlaxTransformer2DModel(nn.Module):
Number of transformers block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_linear_projection (`bool`, defaults to `False`): tbd
only_cross_attention (`bool`, defaults to `False`): tbd
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
@@ -175,70 +167,49 @@ class FlaxTransformer2DModel(nn.Module):
d_head: int
depth: int = 1
dropout: float = 0.0
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
if self.use_linear_projection:
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
else:
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.transformer_blocks = [
FlaxBasicTransformerBlock(
inner_dim,
self.n_heads,
self.d_head,
dropout=self.dropout,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
for _ in range(self.depth)
]
if self.use_linear_projection:
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
else:
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height * width, channels)
hidden_states = self.proj_in(hidden_states)
else:
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)
for transformer_block in self.transformer_blocks:
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
if self.use_linear_projection:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channels)
else:
hidden_states = hidden_states.reshape(batch, height, width, channels)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channels)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
+6 -3
View File
@@ -49,11 +49,14 @@ def get_timestep_embedding(
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
sin = torch.sin(emb)
cos = torch.cos(emb)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
emb = torch.cat([cos, sin], dim=-1)
else:
emb = torch.cat([sin, cos], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
@@ -126,7 +129,7 @@ class GaussianFourierProjection(nn.Module):
if self.log:
x = torch.log(x)
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
x_proj = x[:, None] * self.weight[None, :] * (2 * np.pi)
if self.flip_sin_to_cos:
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+3 -1
View File
@@ -476,7 +476,9 @@ class ResnetBlock2D(nn.Module):
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
output_tensor = input_tensor + hidden_states
if self.output_scale_factor != 1.0:
output_tensor = output_tensor / self.output_scale_factor
return output_tensor
+7 -10
View File
@@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
implements for all the model (such as downloading or saving, etc.)
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
Input sample size.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
@@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
sample_size: Optional[int] = None,
in_channels: int = 3,
out_channels: int = 3,
center_input_sample: bool = False,
@@ -175,7 +175,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(
self,
@@ -209,11 +209,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
# 2. pre-process
@@ -247,7 +242,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
sample = self.conv_norm_out(sample)
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
+49 -117
View File
@@ -15,7 +15,7 @@ import numpy as np
import torch
from torch import nn
from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
from .attention import AttentionBlock, Transformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
@@ -32,9 +32,6 @@ def get_down_block(
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -77,9 +74,6 @@ def get_down_block(
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -143,9 +137,6 @@ def get_up_block(
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -175,9 +166,6 @@ def get_up_block(
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -254,6 +242,7 @@ class UNetMidBlock2D(nn.Module):
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
**kwargs,
):
super().__init__()
@@ -333,8 +322,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attention_type="default",
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=False,
**kwargs,
):
super().__init__()
@@ -360,29 +348,16 @@ class UNetMidBlock2DCrossAttn(nn.Module):
attentions = []
for _ in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
)
)
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
attentions.append(
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
@@ -402,17 +377,15 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.resnets = nn.ModuleList(resnets)
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > min(head_dims):
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
for attn in self.attentions:
@@ -532,9 +505,6 @@ class CrossAttnDownBlock2D(nn.Module):
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -559,30 +529,16 @@ class CrossAttnDownBlock2D(nn.Module):
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
@@ -600,17 +556,15 @@ class CrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > min(head_dims):
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
for attn in self.attentions:
@@ -1100,10 +1054,7 @@ class AttnUpBlock2D(nn.Module):
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
for resnet, attn, res_hidden_states in zip(self.resnets, self.attentions, reversed(res_hidden_states_tuple)):
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
@@ -1135,9 +1086,6 @@ class CrossAttnUpBlock2D(nn.Module):
attention_type="default",
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
):
super().__init__()
resnets = []
@@ -1164,30 +1112,16 @@ class CrossAttnUpBlock2D(nn.Module):
pre_norm=resnet_pre_norm,
)
)
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
)
)
else:
attentions.append(
DualTransformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
attentions.append(
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
@@ -1199,17 +1133,15 @@ class CrossAttnUpBlock2D(nn.Module):
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
if slice_size is not None and slice_size > min(head_dims):
if slice_size is not None and slice_size > self.attn_num_head_channels:
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
)
for attn in self.attentions:
@@ -46,8 +46,6 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
num_layers: int = 1
attn_num_head_channels: int = 1
add_downsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -70,8 +68,6 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -182,8 +178,6 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
num_layers: int = 1
attn_num_head_channels: int = 1
add_upsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -207,8 +201,6 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
@@ -318,7 +310,6 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
use_linear_projection: bool = False
dtype: jnp.dtype = jnp.float32
def setup(self):
@@ -340,7 +331,6 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
dtype=self.dtype,
)
attentions.append(attn_block)
+15 -48
View File
@@ -56,12 +56,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
implements for all the models (such as downloading or saving, etc.)
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
sample_size (`int`, *optional*): The size of the input sample.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
@@ -98,7 +97,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
"DownBlock2D",
),
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
@@ -107,10 +105,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
num_class_embeds: Optional[int] = None,
attention_head_dim: int = 8,
):
super().__init__()
@@ -126,20 +121,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
@@ -158,11 +143,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
)
self.down_blocks.append(down_block)
@@ -175,10 +157,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
)
# count how many layers upsample the images
@@ -186,8 +166,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
@@ -215,10 +193,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -226,20 +201,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def set_attention_slice(self, slice_size):
head_dims = self.config.attention_head_dim
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
f"Make sure slice_size {slice_size} is a divisor of "
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
)
if slice_size is not None and slice_size > min(head_dims):
if slice_size is not None and slice_size > self.config.attention_head_dim:
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
f"Chunk_size {slice_size} has to be smaller or equal to "
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
)
for block in self.down_blocks:
@@ -272,14 +245,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
encoder_hidden_states (`torch.FloatTensor`):
(batch_size, sequence_length, hidden_size) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
@@ -325,12 +298,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
+5 -22
View File
@@ -79,7 +79,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
attention_head_dim (`int`, *optional*, defaults to 8):
The dimension of the attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
@@ -97,13 +97,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"DownBlock2D",
)
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
only_cross_attention: Union[bool, Tuple[bool]] = False
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: Union[int, Tuple[int]] = 8
attention_head_dim: int = 8
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
dtype: jnp.dtype = jnp.float32
freq_shift: int = 0
@@ -136,14 +134,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
only_cross_attention = self.only_cross_attention
if isinstance(only_cross_attention, bool):
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
attention_head_dim = self.attention_head_dim
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
# down
down_blocks = []
output_channel = block_out_channels[0]
@@ -158,10 +148,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
attn_num_head_channels=attention_head_dim[i],
attn_num_head_channels=self.attention_head_dim,
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
dtype=self.dtype,
)
else:
@@ -181,16 +169,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
attn_num_head_channels=attention_head_dim[-1],
use_linear_projection=self.use_linear_projection,
attn_num_head_channels=self.attention_head_dim,
dtype=self.dtype,
)
# up
up_blocks = []
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
@@ -205,11 +190,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
attn_num_head_channels=reversed_attention_head_dim[i],
attn_num_head_channels=self.attention_head_dim,
add_upsample=not is_final_block,
dropout=self.dropout,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
dtype=self.dtype,
)
else:
+1 -30
View File
@@ -565,7 +565,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.encoder(x)
@@ -577,7 +576,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
@@ -586,34 +585,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
return DecoderOutput(sample=dec)
def enable_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
+4 -4
View File
@@ -317,8 +317,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
# make sure we don't download PyTorch weights, unless when using from_pt
ignore_patterns = "*.bin" if not from_pt else []
# make sure we don't download PyTorch weights
ignore_patterns = "*.bin"
if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__
@@ -411,13 +411,13 @@ class FlaxDiffusionPipeline(ConfigMixin):
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warning(
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warning(
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
+34 -91
View File
@@ -26,7 +26,7 @@ import torch
import diffusers
import PIL
from huggingface_hub import model_info, snapshot_download
from huggingface_hub import snapshot_download
from packaging import version
from PIL import Image
from tqdm.auto import tqdm
@@ -44,7 +44,6 @@ from .utils import (
BaseOutput,
deprecate,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
is_transformers_available,
logging,
@@ -118,23 +117,6 @@ class AudioPipelineOutput(BaseOutput):
audios: np.ndarray
def is_safetensors_compatible(info) -> bool:
filenames = set(sibling.rfilename for sibling in info.siblings)
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
for pt_filename in pt_filenames:
prefix, raw = os.path.split(pt_filename)
if raw == "pytorch_model.bin":
# transformers specific
sf_filename = os.path.join(prefix, "model.safetensors")
else:
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
if is_safetensors_compatible and sf_filename not in filenames:
logger.warning(f"{sf_filename} not found")
is_safetensors_compatible = False
return is_safetensors_compatible
class DiffusionPipeline(ConfigMixin):
r"""
Base class for all models.
@@ -147,13 +129,10 @@ class DiffusionPipeline(ConfigMixin):
Class attributes:
- **config_name** (`str`) -- name of the config file that will store the class and module names of all
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
components of the diffusion pipeline.
- **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
passed for the pipeline to function (should be overridden by subclasses).
"""
config_name = "model_index.json"
_optional_components = []
def register_modules(self, **kwargs):
# import it here to avoid circular import
@@ -205,19 +184,12 @@ class DiffusionPipeline(ConfigMixin):
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module", None)
expected_modules, optional_kwargs = self._get_signature_keys(self)
def is_saveable_module(name, value):
if name not in expected_modules:
return False
if name in self._optional_components and value[0] is None:
return False
return True
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue
model_cls = sub_model.__class__
save_method_name = None
@@ -433,7 +405,7 @@ class DiffusionPipeline(ConfigMixin):
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
logger.warn(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
@@ -477,7 +449,7 @@ class DiffusionPipeline(ConfigMixin):
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
# make sure we don't download flax weights
ignore_patterns = ["*.msgpack"]
ignore_patterns = "*.msgpack"
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
@@ -491,15 +463,6 @@ class DiffusionPipeline(ConfigMixin):
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)
if is_safetensors_available():
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
if is_safetensors_compatible(info):
ignore_patterns.append("*.bin")
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
@@ -560,47 +523,38 @@ class DiffusionPipeline(ConfigMixin):
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
# define init kwargs
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
# remove `null` components
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
if len(unused_kwargs) > 0:
logger.warning(
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
)
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
init_kwargs = {}
# import it here to avoid circular import
from diffusers import pipelines
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module:
if not is_pipeline_module and passed_class_obj[name] is not None:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
@@ -616,8 +570,14 @@ class DiffusionPipeline(ConfigMixin):
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else:
logger.warning(
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
@@ -637,7 +597,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
if loaded_sub_model is None:
if loaded_sub_model is None and sub_model_should_be_defined:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
@@ -691,13 +651,11 @@ class DiffusionPipeline(ConfigMixin):
# 4. Potentially add passed objects if expected
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
for module in missing_modules:
init_kwargs[module] = passed_class_obj.get(module, None)
init_kwargs[module] = passed_class_obj[module]
elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
@@ -706,14 +664,6 @@ class DiffusionPipeline(ConfigMixin):
model = pipeline_class(**init_kwargs)
return model
@staticmethod
def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - set(["self"])
return expected_modules, optional_parameters
@property
def components(self) -> Dict[str, Any]:
r"""
@@ -738,10 +688,8 @@ class DiffusionPipeline(ConfigMixin):
Returns:
A dictionaly containing all the modules needed to initialize the pipeline.
"""
expected_modules, optional_parameters = self._get_signature_keys(self)
components = {
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}
components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
if set(components.keys()) != expected_modules:
raise ValueError(
@@ -767,7 +715,7 @@ class DiffusionPipeline(ConfigMixin):
return pil_images
def progress_bar(self, iterable=None, total=None):
def progress_bar(self, iterable):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
@@ -775,12 +723,7 @@ class DiffusionPipeline(ConfigMixin):
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")
return tqdm(iterable, **self._progress_bar_config)
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
-9
View File
@@ -19,19 +19,10 @@ if is_torch_available() and is_transformers_available():
from .latent_diffusion import LDMTextToImagePipeline
from .stable_diffusion import (
CycleDiffusionPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
import torch
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
@@ -68,7 +67,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -86,7 +84,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -117,8 +114,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -127,33 +124,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -163,8 +133,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_xformers_memory_efficient_attention(self):
r"""
@@ -198,14 +166,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
@@ -216,22 +179,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -245,15 +192,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
@property
def _execution_device(self):
r"""
@@ -428,7 +370,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height // 8, width // 8)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -448,8 +390,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -461,6 +403,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -468,9 +411,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -516,9 +459,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
@@ -557,29 +497,25 @@ class AltDiffusionPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
@@ -20,7 +20,6 @@ import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
@@ -81,7 +80,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -99,7 +97,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -130,8 +127,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -140,33 +137,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -176,8 +146,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
@@ -193,14 +161,9 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
@@ -224,15 +187,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
@property
def _execution_device(self):
r"""
@@ -433,7 +391,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
return timesteps
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
init_image = init_image.to(device=device, dtype=dtype)
@@ -484,6 +442,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -562,7 +521,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
timesteps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
@@ -574,29 +533,25 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
image = self.decode_latents(latents)
@@ -89,11 +89,7 @@ class DDIMPipeline(DiffusionPipeline):
generator = None
# Sample gaussian noise to begin loop
if isinstance(self.unet.sample_size, int):
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
else:
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
@@ -70,14 +70,14 @@ class DDPMPipeline(DiffusionPipeline):
generated images.
"""
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
new_config = dict(self.scheduler.config)
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
new_config["predict_epsilon"] = predict_epsilon
self.scheduler._internal_dict = FrozenDict(new_config)
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
@@ -94,11 +94,7 @@ class DDPMPipeline(DiffusionPipeline):
generator = None
# Sample gaussian noise to begin loop
if isinstance(self.unet.sample_size, int):
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
else:
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
@@ -114,7 +110,9 @@ class DDPMPipeline(DiffusionPipeline):
model_output = self.unet(image, t).sample
# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
image = self.scheduler.step(
model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
@@ -60,14 +60,13 @@ class LDMTextToImagePipeline(DiffusionPipeline):
):
super().__init__()
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = 256,
width: Optional[int] = 256,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 1.0,
eta: Optional[float] = 0.0,
@@ -80,9 +79,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
height (`int`, *optional*, defaults to 256):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
width (`int`, *optional*, defaults to 256):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -107,9 +106,6 @@ class LDMTextToImagePipeline(DiffusionPipeline):
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(prompt, str):
batch_size = 1
@@ -6,14 +6,7 @@ import numpy as np
import PIL
from PIL import Image
from ...utils import (
BaseOutput,
is_flax_available,
is_onnx_available,
is_torch_available,
is_transformers_available,
is_transformers_version,
)
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
@dataclass
@@ -40,14 +33,8 @@ if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .safety_checker import StableDiffusionSafetyChecker
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
else:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
if is_transformers_available() and is_onnx_available():
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
@@ -63,14 +50,15 @@ if is_transformers_available() and is_flax_available():
Output class for Stable Diffusion pipelines.
Args:
images (`np.ndarray`)
Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""
images: np.ndarray
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
@@ -20,7 +20,6 @@ import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -133,7 +132,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -144,7 +142,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -162,8 +159,8 @@ class CycleDiffusionPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -172,32 +169,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -207,7 +178,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
@@ -224,14 +194,9 @@ class CycleDiffusionPipeline(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -257,15 +222,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -475,7 +435,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
return timesteps
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
init_image = init_image.to(device=device, dtype=dtype)
@@ -528,6 +488,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -607,7 +568,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
timesteps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
@@ -621,70 +582,66 @@ class CycleDiffusionPipeline(DiffusionPipeline):
generator = extra_step_kwargs.pop("generator", None)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
source_latent_model_input = torch.cat([source_latents] * 2)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
source_latent_model_input = torch.cat([source_latents] * 2)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
# predict the noise residual
concat_latent_model_input = torch.stack(
[
source_latent_model_input[0],
latent_model_input[0],
source_latent_model_input[1],
latent_model_input[1],
],
dim=0,
)
concat_text_embeddings = torch.stack(
[
source_text_embeddings[0],
text_embeddings[0],
source_text_embeddings[1],
text_embeddings[1],
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
).sample
# predict the noise residual
concat_latent_model_input = torch.stack(
[
source_latent_model_input[0],
latent_model_input[0],
source_latent_model_input[1],
latent_model_input[1],
],
dim=0,
)
concat_text_embeddings = torch.stack(
[
source_text_embeddings[0],
text_embeddings[0],
source_text_embeddings[1],
text_embeddings[1],
],
dim=0,
)
concat_noise_pred = self.unet(
concat_latent_model_input, t, encoder_hidden_states=concat_text_embeddings
).sample
# perform guidance
(
source_noise_pred_uncond,
noise_pred_uncond,
source_noise_pred_text,
noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
source_noise_pred_text - source_noise_pred_uncond
)
# perform guidance
(
source_noise_pred_uncond,
noise_pred_uncond,
source_noise_pred_text,
noise_pred_text,
) = concat_noise_pred.chunk(4, dim=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
source_noise_pred_text - source_noise_pred_uncond
)
# Sample source_latents from the posterior distribution.
prev_source_latents = posterior_sample(
self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
)
# Compute noise.
noise = compute_noise(
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
)
source_latents = prev_source_latents
# Sample source_latents from the posterior distribution.
prev_source_latents = posterior_sample(
self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
)
# Compute noise.
noise = compute_noise(
self.scheduler, prev_source_latents, source_latents, t, source_noise_pred, **extra_step_kwargs
)
source_latents = prev_source_latents
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, variance_noise=noise, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
image = self.decode_latents(latents)
@@ -23,7 +23,6 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from packaging import version
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
@@ -35,7 +34,7 @@ from ...schedulers import (
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ...utils import deprecate, logging
from ...utils import logging
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -89,7 +88,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self.dtype = dtype
if safety_checker is None:
logger.warning(
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -98,27 +97,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -128,7 +106,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
@@ -183,17 +160,13 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
latents: Optional[jnp.array] = None,
debug: bool = False,
neg_prompt_ids: jnp.array = None,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -215,12 +188,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = (
batch_size,
self.unet.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
@@ -281,14 +249,15 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
latents: jnp.array = None,
return_dict: bool = True,
jit: bool = False,
debug: bool = False,
neg_prompt_ids: jnp.array = None,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -296,9 +265,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -316,6 +285,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
@@ -330,10 +302,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if jit:
images = _p_generate(
self,
@@ -379,7 +347,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
images = images.reshape(num_devices, batch_size, height, width, 3)
else:
images = np.asarray(images)
has_nsfw_concept = False
if not return_dict:
@@ -41,8 +41,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
@@ -53,7 +51,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -84,22 +81,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
@@ -110,7 +91,6 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
@@ -205,6 +185,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
if isinstance(prompt, str):
batch_size = 1
@@ -77,8 +77,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
@@ -89,7 +87,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -120,7 +117,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
if safety_checker is None:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
@@ -130,12 +127,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
@@ -146,7 +137,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@@ -241,6 +231,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -90,8 +90,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker: OnnxRuntimeModel
feature_extractor: CLIPFeatureExtractor
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae_encoder: OnnxRuntimeModel,
@@ -102,7 +100,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
@@ -134,7 +131,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
if safety_checker is None:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
@@ -144,12 +141,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
@@ -160,7 +151,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@@ -246,8 +236,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
prompt: Union[str, List[str]],
image: PIL.Image.Image,
mask_image: PIL.Image.Image,
height: Optional[int] = 512,
width: Optional[int] = 512,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -259,6 +249,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -321,7 +312,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
@@ -27,11 +27,11 @@ def preprocess(image):
return 2.0 * image - 1.0
def preprocess_mask(mask, scale_factor=8):
def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -67,8 +67,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
@@ -88,7 +86,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: OnnxRuntimeModel,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -119,7 +116,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
if safety_checker is None:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
@@ -129,12 +126,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
@@ -145,7 +136,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
@@ -241,6 +231,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -350,7 +341,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# preprocess mask
if not isinstance(mask_image, np.ndarray):
mask_image = preprocess_mask(mask_image, 8)
mask_image = preprocess_mask(mask_image)
mask_image = mask_image.astype(latents_dtype)
mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
import torch
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -67,7 +66,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -85,7 +83,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -116,8 +113,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -126,33 +123,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -162,8 +132,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_xformers_memory_efficient_attention(self):
r"""
@@ -197,14 +165,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
@@ -215,22 +178,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -244,15 +191,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
@property
def _execution_device(self):
r"""
@@ -427,7 +369,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height // 8, width // 8)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -447,8 +389,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -460,6 +402,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -467,9 +410,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -515,9 +458,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
@@ -556,29 +496,25 @@ class StableDiffusionPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
@@ -1,481 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionImageVariationPipeline(DiffusionPipeline):
r"""
Pipeline to generate variations from an input image using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
image_encoder: CLIPVisionModelWithProjection,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
if safety_checker is None and requires_safety_checker:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
image_encoder=image_encoder,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeddings = self.image_encoder(image).image_embeds
image_embeddings = image_embeddings.unsqueeze(1)
# duplicate image embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = image_embeddings.shape
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance:
uncond_embeddings = torch.zeros_like(image_embeddings)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
return image_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, image, height, width, callback_steps):
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
)
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
configuration of
[this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
`CLIPFeatureExtractor`
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width, callback_steps)
# 2. Define call parameters
if isinstance(image, PIL.Image.Image):
batch_size = 1
elif isinstance(image, list):
batch_size = len(image)
else:
batch_size = image.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input image
image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
image_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@@ -20,7 +20,6 @@ import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -79,7 +78,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
@@ -98,7 +96,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -129,8 +126,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -139,33 +136,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -175,8 +145,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
@@ -193,14 +161,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -226,15 +189,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -442,7 +400,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
return timesteps
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
init_image = init_image.to(device=device, dtype=dtype)
@@ -493,6 +451,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -571,7 +530,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
timesteps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
@@ -583,29 +542,25 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
image = self.decode_latents(latents)
@@ -20,7 +20,6 @@ import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -151,7 +150,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
@@ -162,7 +160,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -194,8 +191,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["skip_prk_steps"] = True
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -204,33 +201,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -240,8 +210,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
@@ -258,14 +226,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -291,15 +254,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
@@ -501,7 +459,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height // 8, width // 8)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -523,9 +481,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
)
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
@@ -553,8 +509,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image],
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
height: Optional[int] = None,
width: Optional[int] = None,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -566,6 +522,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -581,9 +538,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -629,9 +586,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs
self.check_inputs(prompt, height, width, callback_steps)
@@ -655,7 +609,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
timesteps_tensor = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
@@ -699,32 +653,28 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 10. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 11. Post-processing
image = self.decode_latents(latents)
@@ -20,7 +20,6 @@ import torch
import PIL
from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
@@ -52,11 +51,11 @@ def preprocess_image(image):
return 2.0 * image - 1.0
def preprocess_mask(mask, scale_factor=8):
def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -92,7 +91,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
def __init__(
@@ -111,7 +109,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -142,8 +139,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
if safety_checker is None:
logger.warn(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -152,33 +149,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -188,8 +158,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
@@ -206,14 +174,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -239,15 +202,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
@@ -457,7 +415,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
return timesteps
def prepare_latents(self, init_image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
init_image = init_image.to(device=self.device, dtype=dtype)
@@ -492,6 +450,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -573,11 +532,11 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
init_image = preprocess_image(init_image)
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
mask_image = preprocess_mask(mask_image)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
timesteps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
@@ -594,33 +553,29 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
latents = (init_latents_proper * mask) + (latents * (1 - mask))
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Post-processing
image = self.decode_latents(latents)
@@ -1,555 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import PIL
from diffusers.utils import is_accelerate_available
from transformers import CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def preprocess(image):
# resize to multiple of 64
width, height = image.size
width = width - width % 64
height = height - height % 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
return image
class StableDiffusionUpscalePipeline(DiffusionPipeline):
r"""
Pipeline for text-guided image super-resolution using Stable Diffusion 2.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
low_res_scheduler ([`SchedulerMixin`]):
A scheduler used to add initial noise to the low res conditioning image. It must be an instance of
[`DDPMScheduler`].
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
low_res_scheduler: DDPMScheduler,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
max_noise_level: int = 350,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
low_res_scheduler=low_res_scheduler,
scheduler=scheduler,
)
self.register_to_config(max_noise_level=max_noise_level)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
def decode_latents(self, latents):
latents = 1 / 0.08333 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def check_inputs(self, prompt, image, noise_level, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
)
# verify batch size of prompt and image are same if image is a list or tensor
if isinstance(image, list) or isinstance(image, torch.Tensor):
if isinstance(prompt, str):
batch_size = 1
else:
batch_size = len(prompt)
if isinstance(image, list):
image_batch_size = len(image)
else:
image_batch_size = image.shape[0]
if batch_size != image_batch_size:
raise ValueError(
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
" Please make sure that passed `prompt` matches the batch size of `image`."
)
# check noise level
if noise_level > self.config.max_noise_level:
raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
`Image`, or tensor representing an image batch which will be upscaled. *
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 1. Check inputs
self.check_inputs(prompt, image, noise_level, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# 4. Preprocess image
image = [image] if isinstance(image, PIL.Image.Image) else image
if isinstance(image, list):
image = [preprocess(img) for img in image]
image = torch.cat(image, dim=0)
image = image.to(dtype=text_embeddings.dtype, device=device)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Add noise to image
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(image.shape, generator=generator, device="cpu", dtype=text_embeddings.dtype).to(device)
else:
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
image = torch.cat([image] * 2) if do_classifier_free_guidance else image
noise_level = torch.cat([noise_level] * 2) if do_classifier_free_guidance else noise_level
# 6. Prepare latent variables
height, width = image.shape[2:]
num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 7. Check that sizes of image and latents match
num_channels_image = image.shape[1]
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latent_model_input, image], dim=1)
# predict the noise residual
noise_pred = self.unet(
latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Post-processing
# make sure the VAE is in float32 mode, as it overflows in float16
self.vae.to(dtype=torch.float32)
image = self.decode_latents(latents.float())
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
@@ -1,72 +0,0 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, is_torch_available, is_transformers_available
@dataclass
class SafetyConfig(object):
WEAK = {
"sld_warmup_steps": 15,
"sld_guidance_scale": 20,
"sld_threshold": 0.0,
"sld_momentum_scale": 0.0,
"sld_mom_beta": 0.0,
}
MEDIUM = {
"sld_warmup_steps": 10,
"sld_guidance_scale": 1000,
"sld_threshold": 0.01,
"sld_momentum_scale": 0.3,
"sld_mom_beta": 0.4,
}
STRONG = {
"sld_warmup_steps": 7,
"sld_guidance_scale": 2000,
"sld_threshold": 0.025,
"sld_momentum_scale": 0.5,
"sld_mom_beta": 0.7,
}
MAX = {
"sld_warmup_steps": 0,
"sld_guidance_scale": 5000,
"sld_threshold": 1.0,
"sld_momentum_scale": 0.5,
"sld_mom_beta": 0.7,
}
@dataclass
class StableDiffusionSafePipelineOutput(BaseOutput):
"""
Output class for Safe Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, or `None` if safety checking could not be performed.
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work"
(nsfw) content, or `None` if no safety check was performed or no images were flagged.
applied_safety_concept (`str`)
The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
applied_safety_concept: Optional[str]
if is_transformers_available() and is_torch_available():
from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
from .safety_checker import SafeStableDiffusionSafetyChecker
@@ -1,764 +0,0 @@
import inspect
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, is_accelerate_available, logging
from . import StableDiffusionSafePipelineOutput
from .safety_checker import SafeStableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionPipelineSafe(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Safe Latent Diffusion.
The implementation is based on the [`StableDiffusionPipeline`]
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
],
safety_checker: SafeStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__()
safety_concept: Optional[str] = (
"an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity,"
" bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child"
" abuse, brutality, cruelty"
)
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self._safety_text_concept = safety_concept
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@property
def safety_concept(self):
r"""
Getter method for the safety concept used with SLD
Returns:
`str`: The text describing the safety concept
"""
return self._safety_text_concept
@safety_concept.setter
def safety_concept(self, concept):
r"""
Setter method for the safety concept used with SLD
Args:
concept (`str`):
The text of the new safety concept
"""
self._safety_text_concept = concept
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device("cuda")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
enable_safety_guidance,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = text_embeddings[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = uncond_embeddings[0]
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# Encode the safety concept text
if enable_safety_guidance:
safety_concept_input = self.tokenizer(
[self._safety_text_concept],
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0]
# duplicate safety embeddings for each generation per prompt, using mps friendly method
seq_len = safety_embeddings.shape[1]
safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1)
safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance + sld, we need to do three forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing three forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings])
else:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
flagged_images = None
if any(has_nsfw_concept):
logger.warning(
"Potential NSFW content was detected in one or more images. A black image will be returned"
" instead."
f" {'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} "
)
flagged_images = np.zeros((2, *image.shape[1:]))
for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
if has_nsfw_concept:
flagged_images[idx] = image[idx]
image[idx] = np.zeros(image[idx].shape) # black image
else:
has_nsfw_concept = None
flagged_images = None
return image, has_nsfw_concept, flagged_images
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def perform_safety_guidance(
self,
enable_safety_guidance,
safety_momentum,
noise_guidance,
noise_pred_out,
i,
sld_guidance_scale,
sld_warmup_steps,
sld_threshold,
sld_momentum_scale,
sld_mom_beta,
):
# Perform SLD guidance
if enable_safety_guidance:
if safety_momentum is None:
safety_momentum = torch.zeros_like(noise_guidance)
noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1]
noise_pred_safety_concept = noise_pred_out[2]
# Equation 6
scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0)
# Equation 6
safety_concept_scale = torch.where(
(noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
)
# Equation 4
noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale)
# Equation 7
noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
# Equation 8
safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
if i >= sld_warmup_steps: # Warmup
# Equation 3
noise_guidance = noise_guidance - noise_guidance_safety
return noise_guidance, safety_momentum
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
sld_guidance_scale: Optional[float] = 1000,
sld_warmup_steps: Optional[int] = 10,
sld_threshold: Optional[float] = 0.01,
sld_momentum_scale: Optional[float] = 0.3,
sld_mom_beta: Optional[float] = 0.4,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
sld_guidance_scale (`float`, *optional*, defaults to 1000):
Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
`sld_guidance_scale` is defined as sS of Eq. 6. If set to be less than 1, safety guidance will be
disabled.
sld_warmup_steps (`int`, *optional*, defaults to 10):
Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than
`sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent
Diffusion](https://arxiv.org/abs/2211.05105).
sld_threshold (`float`, *optional*, defaults to 0.01):
Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold`
is defined as `lamda` of Eq. 5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
sld_momentum_scale (`float`, *optional*, defaults to 0.3):
Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0
momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller
than `sld_warmup_steps`. `sld_momentum_scale` is defined as `sm` of Eq. 7 in [Safe Latent
Diffusion](https://arxiv.org/abs/2211.05105).
sld_mom_beta (`float`, *optional*, defaults to 0.4):
Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous
momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller
than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta m` of Eq. 8 in [Safe Latent
Diffusion](https://arxiv.org/abs/2211.05105).
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance
if not enable_safety_guidance:
warnings.warn("Safety checker disabled!")
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
safety_momentum = None
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * (3 if enable_safety_guidance else 2))
if do_classifier_free_guidance
else latents
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
# default classifier free guidance
noise_guidance = noise_pred_text - noise_pred_uncond
# Perform SLD guidance
if enable_safety_guidance:
if safety_momentum is None:
safety_momentum = torch.zeros_like(noise_guidance)
noise_pred_safety_concept = noise_pred_out[2]
# Equation 6
scale = torch.clamp(
torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
)
# Equation 6
safety_concept_scale = torch.where(
(noise_pred_text - noise_pred_safety_concept) >= sld_threshold,
torch.zeros_like(scale),
scale,
)
# Equation 4
noise_guidance_safety = torch.mul(
(noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
)
# Equation 7
noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
# Equation 8
safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
if i >= sld_warmup_steps: # Warmup
# Equation 3
noise_guidance = noise_guidance - noise_guidance_safety
noise_pred = noise_pred_uncond + guidance_scale * noise_guidance
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept, flagged_images = self.run_safety_checker(
image, device, text_embeddings.dtype, enable_safety_guidance
)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if flagged_images is not None:
flagged_images = self.numpy_to_pil(flagged_images)
if not return_dict:
return (
image,
has_nsfw_concept,
self._safety_text_concept if enable_safety_guidance else None,
flagged_images,
)
return StableDiffusionSafePipelineOutput(
images=image,
nsfw_content_detected=has_nsfw_concept,
applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None,
unsafe_images=flagged_images,
)
@@ -1,110 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
from ...utils import logging
logger = logging.get_logger(__name__)
def cosine_distance(image_embeds, text_embeds):
normalized_image_embeds = nn.functional.normalize(image_embeds)
normalized_text_embeds = nn.functional.normalize(text_embeds)
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
class SafeStableDiffusionSafetyChecker(PreTrainedModel):
config_class = CLIPConfig
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPConfig):
super().__init__(config)
self.vision_model = CLIPVisionModel(config.vision_config)
self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
@torch.no_grad()
def forward(self, clip_input, images):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
result = []
batch_size = image_embeds.shape[0]
for i in range(batch_size):
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
# increase this value to create a stronger `nfsw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
for concept_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concept_idx]
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concept_idx] > 0:
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
adjustment = 0.01
for concept_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concept_idx]
concept_threshold = self.concept_embeds_weights[concept_idx].item()
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concept_idx] > 0:
result_img["bad_concepts"].append(concept_idx)
result.append(result_img)
has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
return images, has_nsfw_concepts
@torch.no_grad()
def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
cos_dist = cosine_distance(image_embeds, self.concept_embeds)
# increase this value to create a stronger `nsfw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
# special_scores = special_scores.round(decimals=3)
special_care = torch.any(special_scores > 0, dim=1)
special_adjustment = special_care * 0.01
special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
# concept_scores = concept_scores.round(decimals=3)
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
return images, has_nsfw_concepts
@@ -1,16 +0,0 @@
from ...utils import is_torch_available, is_transformers_available, is_transformers_version
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
from .modeling_text_unet import UNetFlatConditionModel
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
else:
from ...utils.dummy_torch_and_transformers_objects import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
File diff suppressed because it is too large Load Diff
@@ -1,463 +0,0 @@
import inspect
from typing import Callable, List, Optional, Union
import torch
import PIL.Image
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class VersatileDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionMegaSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
tokenizer: CLIPTokenizer
image_feature_extractor: CLIPFeatureExtractor
text_encoder: CLIPTextModel
image_encoder: CLIPVisionModel
image_unet: UNet2DConditionModel
text_unet: UNet2DConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
def __init__(
self,
tokenizer: CLIPTokenizer,
image_feature_extractor: CLIPFeatureExtractor,
text_encoder: CLIPTextModel,
image_encoder: CLIPVisionModel,
image_unet: UNet2DConditionModel,
text_unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
image_feature_extractor=image_feature_extractor,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_unet=image_unet,
text_unet=text_unet,
vae=vae,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.image_unet.config.attention_head_dim // 2
self.image_unet.set_attention_slice(slice_size)
self.text_unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
@torch.no_grad()
def image_variation(
self,
image: Union[torch.FloatTensor, PIL.Image.Image],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
The image prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
```py
>>> from diffusers import VersatileDiffusionPipeline
>>> import torch
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image
>>> # let's download an initial image
>>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
>>> response = requests.get(url)
>>> image = Image.open(BytesIO(response.content)).convert("RGB")
>>> pipe = VersatileDiffusionPipeline.from_pretrained(
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> generator = torch.Generator(device="cuda").manual_seed(0)
>>> image = pipe.image_variation(image, generator=generator).images[0]
>>> image.save("./car_variation.png")
```
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys()
components = {name: component for name, component in self.components.items() if name in expected_components}
return VersatileDiffusionImageVariationPipeline(**components)(
image=image,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
)
@torch.no_grad()
def text_to_image(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
```py
>>> from diffusers import VersatileDiffusionPipeline
>>> import torch
>>> pipe = VersatileDiffusionPipeline.from_pretrained(
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> generator = torch.Generator(device="cuda").manual_seed(0)
>>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0]
>>> image.save("./astronaut.png")
```
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys()
components = {name: component for name, component in self.components.items() if name in expected_components}
temp_pipeline = VersatileDiffusionTextToImagePipeline(**components)
output = temp_pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
)
# swap the attention blocks back to the original state
temp_pipeline._swap_unet_attention_blocks()
return output
@torch.no_grad()
def dual_guided(
self,
prompt: Union[PIL.Image.Image, List[PIL.Image.Image]],
image: Union[str, List[str]],
text_to_image_strength: float = 0.5,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
```py
>>> from diffusers import VersatileDiffusionPipeline
>>> import torch
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image
>>> # let's download an initial image
>>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
>>> response = requests.get(url)
>>> image = Image.open(BytesIO(response.content)).convert("RGB")
>>> text = "a red car in the sun"
>>> pipe = VersatileDiffusionPipeline.from_pretrained(
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> generator = torch.Generator(device="cuda").manual_seed(0)
>>> text_to_image_strength = 0.75
>>> image = pipe.dual_guided(
... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
... ).images[0]
>>> image.save("./car_variation.png")
```
Returns:
[`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
returning a tuple, the first element is a list with the generated images.
"""
expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys()
components = {name: component for name, component in self.components.items() if name in expected_components}
temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components)
output = temp_pipeline(
prompt=prompt,
image=image,
text_to_image_strength=text_to_image_strength,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
)
temp_pipeline._revert_dual_attention()
return output
@@ -1,641 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
import PIL
from transformers import (
CLIPFeatureExtractor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention import DualTransformer2DModel, Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
from .modeling_text_unet import UNetFlatConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
vqvae ([`VQModel`]):
Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
bert ([`LDMBertModel`]):
Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
tokenizer (`transformers.BertTokenizer`):
Tokenizer of class
[BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
tokenizer: CLIPTokenizer
image_feature_extractor: CLIPFeatureExtractor
text_encoder: CLIPTextModelWithProjection
image_encoder: CLIPVisionModelWithProjection
image_unet: UNet2DConditionModel
text_unet: UNetFlatConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
_optional_components = ["text_unet"]
def __init__(
self,
tokenizer: CLIPTokenizer,
image_feature_extractor: CLIPFeatureExtractor,
text_encoder: CLIPTextModelWithProjection,
image_encoder: CLIPVisionModelWithProjection,
image_unet: UNet2DConditionModel,
text_unet: UNetFlatConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
image_feature_extractor=image_feature_extractor,
text_encoder=text_encoder,
image_encoder=image_encoder,
image_unet=image_unet,
text_unet=text_unet,
vae=vae,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
if self.text_unet is not None and (
"dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention
):
# if loading from a universal checkpoint rather than a saved dual-guided pipeline
self._convert_to_dual_attention()
def remove_unused_weights(self):
self.register_modules(text_unet=None)
def _convert_to_dual_attention(self):
"""
Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks
from both `image_unet` and `text_unet`
"""
for name, module in self.image_unet.named_modules():
if isinstance(module, Transformer2DModel):
parent_name, index = name.rsplit(".", 1)
index = int(index)
image_transformer = self.image_unet.get_submodule(parent_name)[index]
text_transformer = self.text_unet.get_submodule(parent_name)[index]
config = image_transformer.config
dual_transformer = DualTransformer2DModel(
num_attention_heads=config.num_attention_heads,
attention_head_dim=config.attention_head_dim,
in_channels=config.in_channels,
num_layers=config.num_layers,
dropout=config.dropout,
norm_num_groups=config.norm_num_groups,
cross_attention_dim=config.cross_attention_dim,
attention_bias=config.attention_bias,
sample_size=config.sample_size,
num_vector_embeds=config.num_vector_embeds,
activation_fn=config.activation_fn,
num_embeds_ada_norm=config.num_embeds_ada_norm,
)
dual_transformer.transformers[0] = image_transformer
dual_transformer.transformers[1] = text_transformer
self.image_unet.get_submodule(parent_name)[index] = dual_transformer
self.image_unet.register_to_config(dual_cross_attention=True)
def _revert_dual_attention(self):
"""
Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call
this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline`
"""
for name, module in self.image_unet.named_modules():
if isinstance(module, DualTransformer2DModel):
parent_name, index = name.rsplit(".", 1)
index = int(index)
self.image_unet.get_submodule(parent_name)[index] = module.transformers[0]
self.image_unet.register_to_config(dual_cross_attention=False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.image_unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.image_unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.image_unet.config.attention_head_dim)
self.image_unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
return self.device
for module in self.image_unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_text_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
"""
def normalize_embeddings(encoder_output):
embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
embeds_pooled = encoder_output.text_embeds
embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
return embeds
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = normalize_embeddings(text_embeddings)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens = [""] * batch_size
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
def _encode_image_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
"""
def normalize_embeddings(encoder_output):
embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
embeds = self.image_encoder.visual_projection(embeds)
embeds_pooled = embeds[:, 0:1]
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
return embeds
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
image_input = self.image_feature_extractor(images=prompt, return_tensors="pt")
pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype)
image_embeddings = self.image_encoder(pixel_values)
image_embeddings = normalize_embeddings(image_embeddings)
# duplicate image embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = image_embeddings.shape
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
uncond_embeddings = self.image_encoder(pixel_values)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and conditional embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
return image_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, prompt, image, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}")
if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list):
raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")):
for name, module in self.image_unet.named_modules():
if isinstance(module, DualTransformer2DModel):
module.mix_ratio = mix_ratio
for i, type in enumerate(condition_types):
if type == "text":
module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings
module.transformer_index_for_condition[i] = 1 # use the second (text) transformer
else:
module.condition_lengths[i] = 257
module.transformer_index_for_condition[i] = 0 # use the first (image) transformer
@torch.no_grad()
def __call__(
self,
prompt: Union[PIL.Image.Image, List[PIL.Image.Image]],
image: Union[str, List[str]],
text_to_image_strength: float = 0.5,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
```py
>>> from diffusers import VersatileDiffusionDualGuidedPipeline
>>> import torch
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image
>>> # let's download an initial image
>>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
>>> response = requests.get(url)
>>> image = Image.open(BytesIO(response.content)).convert("RGB")
>>> text = "a red car in the sun"
>>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
... )
>>> pipe.remove_unused_weights()
>>> pipe = pipe.to("cuda")
>>> generator = torch.Generator(device="cuda").manual_seed(0)
>>> text_to_image_strength = 0.75
>>> image = pipe(
... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
... ).images[0]
>>> image.save("./car_variation.png")
```
Returns:
[`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
returning a tuple, the first element is a list with the generated images.
"""
# 0. Default height and width to unet
height = height or self.image_unet.config.sample_size * self.vae_scale_factor
width = width or self.image_unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, image, height, width, callback_steps)
# 2. Define call parameters
prompt = [prompt] if not isinstance(prompt, list) else prompt
image = [image] if not isinstance(image, list) else image
batch_size = len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompts
text_embeddings = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance)
dual_prompt_embeddings = torch.cat([text_embeddings, image_embeddings], dim=1)
prompt_types = ("text", "image")
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.image_unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
dual_prompt_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Combine the attention blocks of the image and text UNets
self.set_transformer_params(text_to_image_strength, prompt_types)
# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=dual_prompt_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
image = self.decode_latents(latents)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
@@ -1,471 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import torch.utils.checkpoint
import PIL
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
vqvae ([`VQModel`]):
Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
bert ([`LDMBertModel`]):
Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
tokenizer (`transformers.BertTokenizer`):
Tokenizer of class
[BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
image_feature_extractor: CLIPFeatureExtractor
image_encoder: CLIPVisionModelWithProjection
image_unet: UNet2DConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
def __init__(
self,
image_feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection,
image_unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
self.register_modules(
image_feature_extractor=image_feature_extractor,
image_encoder=image_encoder,
image_unet=image_unet,
vae=vae,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.image_unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.image_unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.image_unet.config.attention_head_dim)
self.image_unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
return self.device
for module in self.image_unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
def normalize_embeddings(encoder_output):
embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
embeds = self.image_encoder.visual_projection(embeds)
embeds_pooled = embeds[:, 0:1]
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
return embeds
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
image_input = self.image_feature_extractor(images=prompt, return_tensors="pt")
pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype)
image_embeddings = self.image_encoder(pixel_values)
image_embeddings = normalize_embeddings(image_embeddings)
# duplicate image embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = image_embeddings.shape
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_images: List[str]
if negative_prompt is None:
uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, PIL.Image.Image):
uncond_images = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_images = negative_prompt
uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
uncond_embeddings = self.image_encoder(pixel_values)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and conditional embeddings into a single batch
# to avoid doing two forward passes
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
return image_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, image, height, width, callback_steps):
if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor):
raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
The image prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
```py
>>> from diffusers import VersatileDiffusionImageVariationPipeline
>>> import torch
>>> import requests
>>> from io import BytesIO
>>> from PIL import Image
>>> # let's download an initial image
>>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
>>> response = requests.get(url)
>>> image = Image.open(BytesIO(response.content)).convert("RGB")
>>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained(
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> generator = torch.Generator(device="cuda").manual_seed(0)
>>> image = pipe(image, generator=generator).images[0]
>>> image.save("./car_variation.png")
```
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.image_unet.config.sample_size * self.vae_scale_factor
width = width or self.image_unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(image, height, width, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
image_embeddings = self._encode_prompt(
image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.image_unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
image_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
@@ -1,525 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, List, Optional, Union
import torch
import torch.utils.checkpoint
from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention import Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
from .modeling_text_unet import UNetFlatConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
vqvae ([`VQModel`]):
Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
bert ([`LDMBertModel`]):
Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
tokenizer (`transformers.BertTokenizer`):
Tokenizer of class
[BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
tokenizer: CLIPTokenizer
image_feature_extractor: CLIPFeatureExtractor
text_encoder: CLIPTextModelWithProjection
image_unet: UNet2DConditionModel
text_unet: UNetFlatConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
_optional_components = ["text_unet"]
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModelWithProjection,
image_unet: UNet2DConditionModel,
text_unet: UNetFlatConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
image_unet=image_unet,
text_unet=text_unet,
vae=vae,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
if self.text_unet is not None:
self._swap_unet_attention_blocks()
def _swap_unet_attention_blocks(self):
"""
Swap the `Transformer2DModel` blocks between the image and text UNets
"""
for name, module in self.image_unet.named_modules():
if isinstance(module, Transformer2DModel):
parent_name, index = name.rsplit(".", 1)
index = int(index)
self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = (
self.text_unet.get_submodule(parent_name)[index],
self.image_unet.get_submodule(parent_name)[index],
)
def remove_unused_weights(self):
self.register_modules(text_unet=None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.image_unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.image_unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.image_unet.config.attention_head_dim)
self.image_unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
return self.device
for module in self.image_unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
"""
def normalize_embeddings(encoder_output):
embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
embeds_pooled = encoder_output.text_embeds
embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
return embeds
batch_size = len(prompt) if isinstance(prompt, list) else 1
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
text_embeddings = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
text_embeddings = normalize_embeddings(text_embeddings)
# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
uncond_embeddings = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
uncond_embeddings = normalize_embeddings(uncond_embeddings)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
return text_embeddings
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
```py
>>> from diffusers import VersatileDiffusionTextToImagePipeline
>>> import torch
>>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
... )
>>> pipe.remove_unused_weights()
>>> pipe = pipe.to("cuda")
>>> generator = torch.Generator(device="cuda").manual_seed(0)
>>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0]
>>> image.save("./astronaut.png")
```
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height = height or self.image_unet.config.sample_size * self.vae_scale_factor
width = width or self.image_unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.image_unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
image = self.decode_latents(latents)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
-3
View File
@@ -22,10 +22,7 @@ if is_torch_available():
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from .scheduling_euler_discrete import EulerDiscreteScheduler
from .scheduling_heun_discrete import HeunDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_repaint import RePaintScheduler
+5 -52
View File
@@ -17,13 +17,13 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin
@@ -106,15 +106,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
order = 1
@register_to_config
def __init__(
@@ -123,23 +118,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
trained_betas: Optional[np.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
@@ -273,19 +258,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
if self.config.clip_sample:
@@ -356,25 +329,5 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
return self.config.num_train_timesteps
@@ -23,7 +23,6 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
@@ -109,14 +108,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
@property
def has_state(self):
@@ -131,17 +125,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if beta_schedule == "linear":
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
elif beta_schedule == "scaled_linear":
@@ -275,19 +259,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# 4. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
+13 -49
View File
@@ -16,7 +16,7 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -99,14 +99,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
predict_epsilon (`bool`):
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
order = 1
@register_to_config
def __init__(
@@ -115,22 +113,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
prediction_type: str = "epsilon",
**kwargs,
predict_epsilon: bool = True,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
@@ -252,13 +241,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"""
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
new_config = dict(self.config)
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
new_config["predict_epsilon"] = predict_epsilon
self._internal_dict = FrozenDict(new_config)
t = timestep
@@ -276,15 +265,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.prediction_type == "epsilon":
if self.config.predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the DDPMScheduler."
)
pred_original_sample = model_output
# 3. Clip "predicted x_0"
if self.config.clip_sample:
@@ -346,25 +330,5 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def get_velocity(
self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
timesteps = timesteps.to(sample.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(sample.shape):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
return self.config.num_train_timesteps
@@ -103,13 +103,12 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
`v-prediction` is not supported for this scheduler.
predict_epsilon (`bool`):
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
@property
def has_state(self):
@@ -125,17 +124,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
trained_betas: Optional[jnp.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
prediction_type: str = "epsilon",
**kwargs,
predict_epsilon: bool = True,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
@@ -214,6 +204,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
predict_epsilon: bool = True,
return_dict: bool = True,
**kwargs,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
@@ -236,13 +227,13 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
new_config = dict(self.config)
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
new_config["predict_epsilon"] = predict_epsilon
self._internal_dict = FrozenDict(new_config)
t = timestep
@@ -260,15 +251,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.prediction_type == "epsilon":
if self.config.predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
" for the FlaxDDPMScheduler."
)
pred_original_sample = model_output
# 3. Clip "predicted x_0"
if self.config.clip_sample:
@@ -21,7 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -87,9 +87,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
solver_order (`int`, default `2`):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
or `v-prediction`.
predict_epsilon (`bool`, default `True`):
we currently support both the noise prediction model and the data prediction model. If the model predicts
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
`predict_epsilon` to `False`.
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@@ -117,8 +118,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
order = 1
@register_to_config
def __init__(
@@ -127,27 +126,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
trained_betas: Optional[np.ndarray] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
predict_epsilon: bool = True,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
@@ -213,7 +203,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to
discretize an integral of the data prediction model. So we need to first convert the model output to the
corresponding type to match the algorithm.
@@ -231,25 +221,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.prediction_type == "epsilon":
if self.config.predict_epsilon:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverMultistepScheduler."
)
x0_pred = model_output
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile(
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
@@ -258,25 +236,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.prediction_type == "epsilon":
if self.config.predict_epsilon:
return model_output
elif self.config.prediction_type == "sample":
else:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the DPMSolverMultistepScheduler."
)
def dpm_solver_first_order_update(
self,
@@ -23,7 +23,6 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
FlaxSchedulerMixin,
@@ -119,9 +118,10 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
solver_order (`int`, default `2`):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
prediction_type (`str`, default `epsilon`):
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
or `v-prediction`.
predict_epsilon (`bool`, default `True`):
we currently support both the noise prediction model and the data prediction model. If the model predicts
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
`predict_epsilon` to `False`.
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@@ -149,7 +149,6 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_deprecated_kwargs = ["predict_epsilon"]
@property
def has_state(self):
@@ -164,23 +163,14 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
predict_epsilon: bool = True,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
**kwargs,
):
message = (
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None:
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
if trained_betas is not None:
self.betas = jnp.asarray(trained_betas)
elif beta_schedule == "linear":
@@ -252,7 +242,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to
discretize an integral of the data prediction model. So we need to first convert the model output to the
corresponding type to match the algorithm.
@@ -270,20 +260,11 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.prediction_type == "epsilon":
if self.config.predict_epsilon:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = alpha_t * sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
" or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
)
x0_pred = model_output
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = jnp.percentile(
@@ -296,21 +277,12 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":
if self.config.prediction_type == "epsilon":
if self.config.predict_epsilon:
return model_output
elif self.config.prediction_type == "sample":
else:
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
" or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
)
def dpm_solver_first_order_update(
self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray
@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -68,7 +68,6 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
order = 1
@register_to_config
def __init__(
@@ -77,10 +76,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
trained_betas: Optional[np.ndarray] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
@@ -190,7 +189,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
)
if not self.is_scale_input_called:
logger.warning(
logger.warn(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
"See `StableDiffusionPipeline` for a usage example."
)
@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -69,7 +69,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
order = 1
@register_to_config
def __init__(
@@ -78,11 +77,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
trained_betas: Optional[np.ndarray] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
@@ -200,7 +198,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
)
if not self.is_scale_input_called:
logger.warning(
logger.warn(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
"See `StableDiffusionPipeline` for a usage example."
)
@@ -231,15 +229,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
pred_original_sample = sample - sigma_hat * model_output
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
@@ -1,249 +0,0 @@
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original
k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
order = 2
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085, # sensible defaults
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
def index_for_timestep(self, timestep):
indices = (self.timesteps == timestep).nonzero()
if self.state_in_first_order:
pos = -1
else:
pos = 0
return indices[pos].item()
def scale_model_input(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
"""
Args:
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
step_index = self.index_for_timestep(timestep)
sigma = self.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
timesteps = torch.from_numpy(timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = timesteps.to(device, dtype=torch.float32)
else:
self.timesteps = timesteps.to(device=device)
# empty dt and derivative
self.prev_derivative = None
self.dt = None
@property
def state_in_first_order(self):
return self.dt is None
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: Union[float, torch.FloatTensor],
sample: Union[torch.FloatTensor, np.ndarray],
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Args:
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
(`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
step_index = self.index_for_timestep(timestep)
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
else:
# 2nd order / Heun's method
sigma = self.sigmas[step_index - 1]
sigma_next = self.sigmas[step_index]
# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
# passing it to the model which requires a change in API
gamma = 0
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma_hat * model_output
if self.state_in_first_order:
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
# 3. 1st order derivative
dt = sigma_next - sigma_hat
# store for 2nd order step
self.prev_derivative = derivative
self.dt = dt
self.sample = sample
else:
# 2. 2nd order / Heun's method
derivative = (sample - pred_original_sample) / sigma_hat
derivative = (self.prev_derivative + derivative) / 2
# 3. Retrieve 1st order derivative
dt = self.dt
sample = self.sample
# free dt and derivative
# Note, this puts the scheduler in "first order mode"
self.prev_derivative = None
self.dt = None
self.sample = None
prev_sample = sample + derivative * dt
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
self.timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [self.index_for_timestep(t) for t in timesteps]
sigma = self.sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
+3 -12
View File
@@ -13,9 +13,8 @@
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
from typing import Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
@@ -38,12 +37,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps (`int`): number of diffusion steps used to train the model.
"""
order = 1
@register_to_config
def __init__(
self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
):
def __init__(self, num_train_timesteps: int = 1000):
# set `betas`, `alphas`, `timesteps`
self.set_timesteps(num_train_timesteps)
@@ -70,11 +65,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
steps = torch.cat([steps, torch.tensor([0.0])])
if self.config.trained_betas is not None:
self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
else:
self.betas = torch.sin(steps * math.pi / 2) ** 2
self.betas = torch.sin(steps * math.pi / 2) ** 2
self.alphas = (1.0 - self.betas**2) ** 0.5
timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
@@ -1,268 +0,0 @@
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022).
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
order = 2
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085, # sensible defaults
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
def index_for_timestep(self, timestep):
indices = (self.timesteps == timestep).nonzero()
if self.state_in_first_order:
pos = -1
else:
pos = 0
return indices[pos].item()
def scale_model_input(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
"""
Args:
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
step_index = self.index_for_timestep(timestep)
sigma = self.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
# compute up and down sigmas
sigmas_next = sigmas.roll(-1)
sigmas_next[-1] = 0.0
sigmas_up = (sigmas_next**2 * (sigmas**2 - sigmas_next**2) / sigmas**2) ** 0.5
sigmas_down = (sigmas_next**2 - sigmas_up**2) ** 0.5
sigmas_down[-1] = 0.0
self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
self.sigmas_up = torch.cat([sigmas_up[:1], sigmas_up[1:].repeat_interleave(2), sigmas_up[-1:]])
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
timesteps = torch.from_numpy(timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = timesteps.to(device, dtype=torch.float32)
else:
self.timesteps = timesteps
self.sample = None
@property
def state_in_first_order(self):
return self.sample is None
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: Union[float, torch.FloatTensor],
sample: Union[torch.FloatTensor, np.ndarray],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Args:
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
(`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
step_index = self.index_for_timestep(timestep)
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
else:
# 2nd order / KPDM2's method
sigma = self.sigmas[step_index - 1]
sigma_next = self.sigmas[step_index]
sigma_up = self.sigmas_up[step_index - 1]
sigma_down = self.sigmas_down[step_index - 1]
# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
# passing it to the model which requires a change in API
gamma = 0
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
device
)
else:
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(
device
)
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma_hat * model_output
if self.state_in_first_order:
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
# 3. 1st order derivative
dt = sigma_next - sigma_hat
# store for 2nd order step
self.sample = sample
self.dt = dt
prev_sample = sample + derivative * dt
else:
# DPM-Solver-2
derivative = (sample - pred_original_sample) / sigma_hat
dt = sigma_down - sigma_hat
sample = self.sample
self.sample = None
prev_sample = sample + derivative * dt
prev_sample = prev_sample + noise * sigma_up
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
self.timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [self.index_for_timestep(t) for t in timesteps]
sigma = self.sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
@@ -1,283 +0,0 @@
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see:
https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188
Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022).
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
order = 2
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085, # sensible defaults
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
def index_for_timestep(self, timestep):
indices = (self.timesteps == timestep).nonzero()
if self.state_in_first_order:
pos = -1
else:
pos = 0
return indices[pos].item()
def scale_model_input(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
"""
Args:
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
step_index = self.index_for_timestep(timestep)
if self.state_in_first_order:
sigma = self.sigmas[step_index]
else:
sigma = self.sigmas_interpol[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
# interpolate sigmas
sigmas_interpol = sigmas.log().lerp(sigmas.roll(1).log(), 0.5).exp()
self.sigmas = torch.cat([sigmas[:1], sigmas[1:].repeat_interleave(2), sigmas[-1:]])
self.sigmas_interpol = torch.cat(
[sigmas_interpol[:1], sigmas_interpol[1:].repeat_interleave(2), sigmas_interpol[-1:]]
)
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
timesteps = torch.from_numpy(timesteps).to(device)
# interpolate timesteps
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = timesteps.to(torch.float32)
else:
self.timesteps = timesteps
self.sample = None
def sigma_to_t(self, sigma):
# get log sigma
log_sigma = sigma.log()
# get distribution
dists = log_sigma - self.log_sigmas[:, None]
# get sigmas range
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = self.log_sigmas[low_idx]
high = self.log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.view(sigma.shape)
return t
@property
def state_in_first_order(self):
return self.sample is None
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: Union[float, torch.FloatTensor],
sample: Union[torch.FloatTensor, np.ndarray],
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Args:
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
(`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
step_index = self.index_for_timestep(timestep)
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma_interpol = self.sigmas_interpol[step_index + 1]
sigma_next = self.sigmas[step_index + 1]
else:
# 2nd order / KDPM2's method
sigma = self.sigmas[step_index - 1]
sigma_interpol = self.sigmas_interpol[step_index]
sigma_next = self.sigmas[step_index]
# currently only gamma=0 is supported. This usually works best anyways.
# We can support gamma in the future but then need to scale the timestep before
# passing it to the model which requires a change in API
gamma = 0
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.state_in_first_order:
pred_original_sample = sample - sigma_hat * model_output
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
# 3. 1st order derivative
dt = sigma_interpol - sigma_hat
# store for 2nd order step
self.sample = sample
else:
# DPM-Solver-2
pred_original_sample = sample - sigma_interpol * model_output
derivative = (sample - pred_original_sample) / sigma_interpol
dt = sigma_next - sigma_hat
sample = self.sample
self.sample = None
prev_sample = sample + derivative * dt
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
self.timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [self.index_for_timestep(t) for t in timesteps]
sigma = self.sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
@@ -77,8 +77,6 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
order = 2
@register_to_config
def __init__(
self,
@@ -13,7 +13,7 @@
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -68,7 +68,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
order = 1
@register_to_config
def __init__(
@@ -77,10 +76,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
trained_betas: Optional[np.ndarray] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
+3 -4
View File
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
@@ -90,7 +90,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
order = 1
@register_to_config
def __init__(
@@ -99,13 +98,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
trained_betas: Optional[np.ndarray] = None,
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
@@ -102,8 +102,6 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
"""
order = 1
@register_to_config
def __init__(
self,
@@ -66,8 +66,6 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
correct_steps (`int`): number of correction steps performed on a produced sample.
"""
order = 1
@register_to_config
def __init__(
self,

Some files were not shown because too many files have changed in this diff Show More