Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6b02323a60 | |||
| 462a79d39a | |||
| 6883294d44 | |||
| b9e921feea | |||
| 7684518377 | |||
| 520bb082be | |||
| 9ec5084a9c | |||
| 02aa4ef12e | |||
| 8faa822ddc | |||
| 86aa747da9 | |||
| d52388f486 | |||
| babfb8a020 | |||
| 35099b207e | |||
| 2c6bc0f13b | |||
| 2902109061 | |||
| f26cde3dff | |||
| 9f10c545cb | |||
| 5c10e68a1f | |||
| d50e321745 | |||
| 8e2c4cd56c | |||
| bb2c64a08c | |||
| 05a36d5c1a | |||
| cbfed0c256 | |||
| e0e86b7470 | |||
| 81d8f4a9e1 | |||
| cecdd8bdd1 | |||
| 30f6f44104 | |||
| 9f476388fa | |||
| 9479052dde | |||
| 35d8186172 | |||
| 1524122532 | |||
| f07a16e09b | |||
| 16a32c9dab | |||
| 2625fb59dc | |||
| 0eb507f2af | |||
| 9e234d8048 | |||
| 8fd3a74322 | |||
| 44e56de9aa | |||
| 2d6d4edbbd | |||
| 8b84f85192 | |||
| e50c25d808 | |||
| 182eb959e5 | |||
| ad93593345 |
@@ -60,6 +60,7 @@ jobs:
|
||||
run: |
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -127,6 +128,7 @@ jobs:
|
||||
${CONDA_RUN} python -m pip install -e .[quality,test]
|
||||
${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu
|
||||
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
|
||||
${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
shell: arch -arch arm64 bash {0}
|
||||
|
||||
@@ -62,6 +62,7 @@ jobs:
|
||||
run: |
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -131,6 +132,7 @@ jobs:
|
||||
run: |
|
||||
python -m pip install -e .[quality,test,training]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
@@ -106,6 +106,10 @@
|
||||
title: "Score SDE VE"
|
||||
- local: api/pipelines/stable_diffusion
|
||||
title: "Stable Diffusion"
|
||||
- local: api/pipelines/stable_diffusion_2
|
||||
title: "Stable Diffusion 2"
|
||||
- local: api/pipelines/stable_diffusion_safe
|
||||
title: "Safe Stable Diffusion"
|
||||
- local: api/pipelines/stochastic_karras_ve
|
||||
title: "Stochastic Karras VE"
|
||||
- local: api/pipelines/dance_diffusion
|
||||
|
||||
@@ -51,7 +51,7 @@ To use a different scheduler, you can either change it via the [`ConfigMixin.fro
|
||||
```
|
||||
|
||||
|
||||
- *How to conver all use cases with multiple or single pipeline*
|
||||
- *How to convert all use cases with multiple or single pipeline*
|
||||
|
||||
If you want to use all possible use cases in a single `DiffusionPipeline` we recommend using the `components` functionality to instantiate all components in the most memory-efficient way:
|
||||
|
||||
|
||||
@@ -58,7 +58,14 @@ available a colab notebook to directly try them out.
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
|
||||
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
|
||||
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
|
||||
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ To use a different scheduler, you can either change it via the [`ConfigMixin.fro
|
||||
```
|
||||
|
||||
|
||||
### How to conver all use cases with multiple or single pipeline
|
||||
### How to convert all use cases with multiple or single pipeline
|
||||
|
||||
If you want to use all possible use cases in a single `DiffusionPipeline` you can either:
|
||||
- Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or
|
||||
@@ -88,3 +88,17 @@ If you want to use all possible use cases in a single `DiffusionPipeline` you ca
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
|
||||
## StableDiffusionImageVariationPipeline
|
||||
[[autodoc]] StableDiffusionImageVariationPipeline
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
|
||||
## StableDiffusionUpscalePipeline
|
||||
[[autodoc]] StableDiffusionUpscalePipeline
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Stable diffusion 2
|
||||
|
||||
Stable Diffusion 2 is a text-to-image _latent diffusion_ model built upon the work of [Stable Diffusion 1](https://stability.ai/blog/stable-diffusion-public-release).
|
||||
The project to train Stable Diffusion 2 was led by Robin Rombach and Katherine Crowson from [Stability AI](https://stability.ai/) and [LAION](https://laion.ai/).
|
||||
|
||||
*The Stable Diffusion 2.0 release includes robust text-to-image models trained using a brand new text encoder (OpenCLIP), developed by LAION with support from Stability AI, which greatly improves the quality of the generated images compared to earlier V1 releases. The text-to-image models in this release can generate images with default resolutions of both 512x512 pixels and 768x768 pixels.
|
||||
These models are trained on an aesthetic subset of the [LAION-5B dataset](https://laion.ai/blog/laion-5b/) created by the DeepFloyd team at Stability AI, which is then further filtered to remove adult content using [LAION’s NSFW filter](https://openreview.net/forum?id=M3Y74vmsMcY).*
|
||||
|
||||
For more details about how Stable Diffusion 2 works and how it differs from Stable Diffusion 1, please refer to the official [launch announcement post](https://stability.ai/blog/stable-diffusion-v2-release).
|
||||
|
||||
## Tips
|
||||
|
||||
### Available checkpoints:
|
||||
|
||||
Note that the architecture is more or less identical to [Stable Diffusion 1](./api/pipelines/stable_diffusion) so please refer to [this page](./api/pipelines/stable_diffusion) for API documentation.
|
||||
|
||||
- *Text-to-Image (512x512 resolution)*: [stabilityai/stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) with [`StableDiffusionPipeline`]
|
||||
- *Text-to-Image (768x768 resolution)*: [stabilityai/stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) with [`StableDiffusionPipeline`]
|
||||
- *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`]
|
||||
- *Image Upscaling (x4 resolution resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) [`StableDiffusionUpscalePipeline`]
|
||||
|
||||
We recommend using the [`DPMSolverMultistepScheduler`] as it's currently the fastest scheduler there is.
|
||||
|
||||
- *Text-to-Image (512x512 resolution)*:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
import torch
|
||||
|
||||
repo_id = "stabilityai/stable-diffusion-2-base"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
|
||||
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "High quality photo of an astronaut riding a horse in space"
|
||||
image = pipe(prompt, num_inference_steps=25).images[0]
|
||||
image.save("astronaut.png")
|
||||
```
|
||||
|
||||
- *Text-to-Image (768x768 resolution)*:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
import torch
|
||||
|
||||
repo_id = "stabilityai/stable-diffusion-2"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
|
||||
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "High quality photo of an astronaut riding a horse in space"
|
||||
image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0]
|
||||
image.save("astronaut.png")
|
||||
```
|
||||
|
||||
- *Image Inpainting (512x512 resolution)*:
|
||||
|
||||
```python
|
||||
import PIL
|
||||
import requests
|
||||
import torch
|
||||
from io import BytesIO
|
||||
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
|
||||
def download_image(url):
|
||||
response = requests.get(url)
|
||||
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
|
||||
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = download_image(img_url).resize((512, 512))
|
||||
mask_image = download_image(mask_url).resize((512, 512))
|
||||
|
||||
repo_id = "stabilityai/stable-diffusion-2-inpainting"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
|
||||
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=25).images[0]
|
||||
|
||||
image.save("yellow_cat.png")
|
||||
```
|
||||
|
||||
- *Image Upscaling (x4 resolution resolution)*: [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) [`StableDiffusionUpscalePipeline`]
|
||||
|
||||
```python
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from diffusers import StableDiffusionUpscalePipeline
|
||||
import torch
|
||||
|
||||
# load model and scheduler
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipeline = pipeline.to("cuda")
|
||||
|
||||
# let's download an image
|
||||
url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
|
||||
response = requests.get(url)
|
||||
low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
low_res_img = low_res_img.resize((128, 128))
|
||||
prompt = "a white cat"
|
||||
upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
|
||||
upscaled_image.save("upsampled_cat.png")
|
||||
```
|
||||
|
||||
### How to load and use different schedulers.
|
||||
|
||||
The stable diffusion pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
|
||||
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
|
||||
|
||||
```python
|
||||
>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
|
||||
|
||||
>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
|
||||
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
>>> # or
|
||||
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
|
||||
>>> pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=euler_scheduler)
|
||||
```
|
||||
@@ -0,0 +1,90 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Safe Stable Diffusion
|
||||
|
||||
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://arxiv.org/abs/2211.05105) and mitigates the well known issue that models like Stable Diffusion that are trained on unfiltered, web-crawled datasets tend to suffer from inappropriate degeneration. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, or otherwise offensive content.
|
||||
Safe Stable Diffusion is an extension to the Stable Diffusion that drastically reduces content like this.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
*Text-conditioned image generation models have recently achieved astonishing results in image quality and text alignment and are consequently employed in a fast-growing number of applications. Since they are highly data-driven, relying on billion-sized datasets randomly scraped from the internet, they also suffer, as we demonstrate, from degenerated and biased human behavior. In turn, they may even reinforce such biases. To help combat these undesired side effects, we present safe latent diffusion (SLD). Specifically, to measure the inappropriate degeneration due to unfiltered and imbalanced training sets, we establish a novel image generation test bed-inappropriate image prompts (I2P)-containing dedicated, real-world image-to-text prompts covering concepts such as nudity and violence. As our exhaustive empirical evaluation demonstrates, the introduced SLD removes and suppresses inappropriate image parts during the diffusion process, with no additional training required and no adverse effect on overall image quality or text alignment.*
|
||||
|
||||
|
||||
*Overview*:
|
||||
|
||||
| Pipeline | Tasks | Colab | Demo
|
||||
|---|---|:---:|:---:|
|
||||
| [pipeline_stable_diffusion_safe.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py) | *Text-to-Image Generation* | [](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) | -
|
||||
|
||||
## Tips
|
||||
|
||||
- Safe Stable Diffusion may also be used with weights of [Stable Diffusion](./api/pipelines/stable_diffusion).
|
||||
|
||||
### Run Safe Stable Diffusion
|
||||
|
||||
Safe Stable Diffusion can be tested very easily with the [`StableDiffusionPipelineSafe`], and the `"AIML-TUDA/stable-diffusion-safe"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation).
|
||||
|
||||
### Interacting with the Safety Concept
|
||||
|
||||
To check and edit the currently used safety concept, use the `safety_concept` property of [`StableDiffusionPipelineSafe`]
|
||||
```python
|
||||
>>> from diffusers import StableDiffusionPipelineSafe
|
||||
|
||||
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
|
||||
>>> pipeline.safety_concept
|
||||
```
|
||||
For each image generation the active concept is also contained in [`StableDiffusionSafePipelineOutput`].
|
||||
|
||||
### Using pre-defined safety configurations
|
||||
|
||||
You may use the 4 configurations defined in the [Safe Latent Diffusion paper](https://arxiv.org/abs/2211.05105) as follows:
|
||||
|
||||
```python
|
||||
>>> from diffusers import StableDiffusionPipelineSafe
|
||||
>>> from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
|
||||
|
||||
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
|
||||
>>> prompt = "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c. leyendecker"
|
||||
>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
|
||||
```
|
||||
|
||||
The following configurations are available: `SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyConfig.STRONg`, and `SafetyConfig.MAX`.
|
||||
|
||||
### How to load and use different schedulers.
|
||||
|
||||
The safe stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
|
||||
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
|
||||
|
||||
```python
|
||||
>>> from diffusers import StableDiffusionPipelineSafe, EulerDiscreteScheduler
|
||||
|
||||
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained("AIML-TUDA/stable-diffusion-safe")
|
||||
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
>>> # or
|
||||
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("AIML-TUDA/stable-diffusion-safe", subfolder="scheduler")
|
||||
>>> pipeline = StableDiffusionPipelineSafe.from_pretrained(
|
||||
... "AIML-TUDA/stable-diffusion-safe", scheduler=euler_scheduler
|
||||
... )
|
||||
```
|
||||
|
||||
|
||||
## StableDiffusionSafePipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion_safe.StableDiffusionSafePipelineOutput
|
||||
|
||||
## StableDiffusionPipelineSafe
|
||||
[[autodoc]] StableDiffusionPipelineSafe
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
@@ -18,65 +18,56 @@ The abstract of the paper is the following:
|
||||
|
||||
*The recent advances in diffusion models have set an impressive milestone in many generation tasks. Trending works such as DALL-E2, Imagen, and Stable Diffusion have attracted great interest in academia and industry. Despite the rapid landscape changes, recent new approaches focus on extensions and performance rather than capacity, thus requiring separate models for separate tasks. In this work, we expand the existing single-flow diffusion pipeline into a multi-flow network, dubbed Versatile Diffusion (VD), that handles text-to-image, image-to-text, image-variation, and text-variation in one unified model. Moreover, we generalize VD to a unified multi-flow multimodal diffusion framework with grouped layers, swappable streams, and other propositions that can process modalities beyond images and text. Through our experiments, we demonstrate that VD and its underlying framework have the following merits: a) VD handles all subtasks with competitive quality; b) VD initiates novel extensions and applications such as disentanglement of style and semantic, image-text dual-guided generation, etc.; c) Through these experiments and applications, VD provides more semantic insights of the generated outputs.*
|
||||
|
||||
*Overview*:
|
||||
|
||||
| Pipeline | Tasks | Colab | Demo
|
||||
|---|---|:---:|:---:|
|
||||
| [pipeline_alt_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py) | *Text-to-Image Generation* | - | -
|
||||
| [pipeline_alt_diffusion_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py) | *Image-to-Image Text-Guided Generation* | - |-
|
||||
|
||||
## Tips
|
||||
|
||||
- VersatileDiffusion is conceptually very similar as [Stable Diffusion](./api/pipelines/stable_diffusion), but instead of providing just a image data stream conditioned on text, VersatileDiffusion provides both a image and text data stream and can be conditioned on both text and image.
|
||||
|
||||
- *Run VersatileDiffusion*
|
||||
### *Run VersatileDiffusion*
|
||||
|
||||
All task VersatileDiffusion can be tested very easily with the [`VersatileDiffusionPipeline`], [`VersatileDiffusionImg2ImgPipeline`] and the `"BAAI/VersatileDiffusion-m9"` checkpoint exactly in the same way it is shown in the [Conditional Image Generation Guide](./using-diffusers/conditional_image_generation) and the [Image-to-Image Generation Guide](./using-diffusers/img2img).
|
||||
You can both load the memory intensive "all-in-one" [`VersatileDiffusionPipeline`] that can run all tasks
|
||||
with the same class as shown in [`VersatileDiffusionPipeline.text_to_image`], [`VersatileDiffusionPipeline.image_variation`], and [`VersatileDiffusionPipeline.dual_guided`]
|
||||
|
||||
- *How to load and use different schedulers.*
|
||||
**or**
|
||||
|
||||
The alt diffusion pipeline uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the alt diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
|
||||
You can run the individual pipelines which are much more memory efficient:
|
||||
|
||||
- *Text-to-Image*: [`VersatileDiffusionTextToImagePipeline.__call__`]
|
||||
- *Image Variation*: [`VersatileDiffusionImageVariationPipeline.__call__`]
|
||||
- *Dual Text and Image Guided Generation*: [`VersatileDiffusionDualGuidedPipeline.__call__`]
|
||||
|
||||
### *How to load and use different schedulers.*
|
||||
|
||||
The versatile diffusion pipelines uses [`DDIMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the alt diffusion pipeline such as [`PNDMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
|
||||
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
|
||||
|
||||
```python
|
||||
>>> from diffusers import VersatileDiffusionPipeline, EulerDiscreteScheduler
|
||||
|
||||
>>> pipeline = VersatileDiffusionPipeline.from_pretrained("BAAI/VersatileDiffusion-m9")
|
||||
>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion")
|
||||
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
>>> # or
|
||||
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("BAAI/VersatileDiffusion-m9", subfolder="scheduler")
|
||||
>>> pipeline = VersatileDiffusionPipeline.from_pretrained("BAAI/VersatileDiffusion-m9", scheduler=euler_scheduler)
|
||||
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("shi-labs/versatile-diffusion", subfolder="scheduler")
|
||||
>>> pipeline = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", scheduler=euler_scheduler)
|
||||
```
|
||||
|
||||
|
||||
- *How to conver all use cases with multiple or single pipeline*
|
||||
|
||||
If you want to use all possible use cases in a single `DiffusionPipeline` we recommend using the `components` functionality to instantiate all components in the most memory-efficient way:
|
||||
|
||||
```python
|
||||
>>> from diffusers import (
|
||||
... VersatileDiffusionPipeline,
|
||||
... VersatileDiffusionImg2ImgPipeline,
|
||||
... )
|
||||
|
||||
>>> text2img = VersatileDiffusionPipeline.from_pretrained("BAAI/VersatileDiffusion-m9")
|
||||
>>> img2img = VersatileDiffusionImg2ImgPipeline(**text2img.components)
|
||||
|
||||
>>> # now you can use text2img(...) and img2img(...) just like the call methods of each respective pipeline
|
||||
```
|
||||
|
||||
## VersatileDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.alt_diffusion.VersatileDiffusionPipelineOutput
|
||||
|
||||
## VersatileDiffusionPipeline
|
||||
[[autodoc]] VersatileDiffusionPipeline
|
||||
|
||||
## VersatileDiffusionTextToImagePipeline
|
||||
[[autodoc]] VersatileDiffusionTextToImagePipeline
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
## VersatileDiffusionImg2ImgPipeline
|
||||
[[autodoc]] VersatileDiffusionImg2ImgPipeline
|
||||
## VersatileDiffusionImageVariationPipeline
|
||||
[[autodoc]] VersatileDiffusionImageVariationPipeline
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
## VersatileDiffusionDualGuidedPipeline
|
||||
[[autodoc]] VersatileDiffusionDualGuidedPipeline
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
|
||||
@@ -48,7 +48,14 @@ available a colab notebook to directly try them out.
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting |
|
||||
| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image |
|
||||
| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb)
|
||||
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
|
||||
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
|
||||
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
|
||||
|
||||
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
|
||||
|
||||
@@ -22,6 +22,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting| [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
|
||||
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting| [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
|
||||
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - |[Stuti R.](https://github.com/kingstut) |
|
||||
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
|
||||
|
||||
|
||||
@@ -663,4 +664,65 @@ Based https://arxiv.org/abs/2208.04202, this is used for diffusion on discrete d
|
||||
from diffusers import DiffusionPipeline
|
||||
pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="bit_diffusion")
|
||||
image = pipe().images[0]
|
||||
|
||||
```
|
||||
|
||||
### Stable Diffusion with K Diffusion
|
||||
|
||||
Make sure you have @crowsonkb's https://github.com/crowsonkb/k-diffusion installed:
|
||||
|
||||
```
|
||||
pip install k-diffusion
|
||||
```
|
||||
|
||||
You can use the community pipeline as follows:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "an astronaut riding a horse on mars"
|
||||
pipe.set_sampler("sample_heun")
|
||||
generator = torch.Generator(device="cuda").manual_seed(seed)
|
||||
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
|
||||
|
||||
image.save("./astronaut_heun_k_diffusion.png")
|
||||
```
|
||||
|
||||
To make sure that K Diffusion and `diffusers` yield the same results:
|
||||
|
||||
**Diffusers**:
|
||||
```python
|
||||
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
|
||||
|
||||
seed = 33
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(seed)
|
||||
image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
|
||||
```
|
||||
|
||||

|
||||
|
||||
**K Diffusion**:
|
||||
```python
|
||||
from diffusers import DiffusionPipeline, EulerDiscreteScheduler
|
||||
|
||||
seed = 33
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="sd_text2img_k_diffusion")
|
||||
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
pipe.set_sampler("sample_euler")
|
||||
generator = torch.Generator(device="cuda").manual_seed(seed)
|
||||
image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ def ddpm_bit_scheduler_step(
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
predict_epsilon=True,
|
||||
prediction_type="epsilon",
|
||||
generator=None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDPMSchedulerOutput, Tuple]:
|
||||
@@ -150,8 +150,8 @@ def ddpm_bit_scheduler_step(
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
predict_epsilon (`bool`):
|
||||
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
|
||||
generator: random number generator.
|
||||
return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
|
||||
Returns:
|
||||
@@ -174,10 +174,12 @@ def ddpm_bit_scheduler_step(
|
||||
|
||||
# 2. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
if predict_epsilon:
|
||||
if prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
else:
|
||||
elif prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction_type {prediction_type}.")
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
scale = self.bit_scale
|
||||
|
||||
@@ -78,7 +78,12 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
||||
)
|
||||
|
||||
self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
|
||||
self.make_cutouts = MakeCutouts(feature_extractor.size)
|
||||
cut_out_size = (
|
||||
feature_extractor.size
|
||||
if isinstance(feature_extractor.size, int)
|
||||
else feature_extractor.size["shortest_edge"]
|
||||
)
|
||||
self.make_cutouts = MakeCutouts(cut_out_size)
|
||||
|
||||
set_requires_grad(self.text_encoder, False)
|
||||
set_requires_grad(self.clip_model, False)
|
||||
|
||||
@@ -110,7 +110,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
|
||||
@@ -101,7 +101,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
|
||||
@@ -469,7 +469,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
|
||||
@@ -113,7 +113,7 @@ class MultilingualStableDiffusion(DiffusionPipeline):
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
|
||||
+479
@@ -0,0 +1,479 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.utils import is_accelerate_available, logging
|
||||
from k_diffusion.external import CompVisDenoiser
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ModelWrapper:
|
||||
def __init__(self, model, alphas_cumprod):
|
||||
self.model = model
|
||||
self.alphas_cumprod = alphas_cumprod
|
||||
|
||||
def apply_model(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs).sample
|
||||
|
||||
|
||||
class StableDiffusionPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
scheduler,
|
||||
safety_checker,
|
||||
feature_extractor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
# get correct sigmas from LMS
|
||||
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
model = ModelWrapper(unet, scheduler.alphas_cumprod)
|
||||
self.k_diffusion_model = CompVisDenoiser(model)
|
||||
|
||||
def set_sampler(self, scheduler_type: str):
|
||||
library = importlib.import_module("k_diffusion")
|
||||
sampling = getattr(library, "sampling")
|
||||
self.sampler = getattr(sampling, scheduler_type)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `list(int)`):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
"""
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
|
||||
|
||||
if not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
text_embeddings = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
text_embeddings = text_embeddings[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
uncond_embeddings = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def check_inputs(self, prompt, height, width, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // 8, width // 8)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = True
|
||||
if guidance_scale <= 1.0:
|
||||
raise ValueError("has to use guidance_scale")
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=text_embeddings.device)
|
||||
sigmas = self.scheduler.sigmas
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
latents = latents * sigmas[0]
|
||||
self.k_diffusion_model.sigmas = self.k_diffusion_model.sigmas.to(latents.device)
|
||||
self.k_diffusion_model.log_sigmas = self.k_diffusion_model.log_sigmas.to(latents.device)
|
||||
|
||||
def model_fn(x, t):
|
||||
latent_model_input = torch.cat([x] * 2)
|
||||
|
||||
noise_pred = self.k_diffusion_model(latent_model_input, t, encoder_hidden_states=text_embeddings)
|
||||
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
return noise_pred
|
||||
|
||||
latents = self.sampler(model_fn, latents, sigmas)
|
||||
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 9. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
|
||||
# 10. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -42,7 +42,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
|
||||
@@ -99,7 +99,7 @@ class TextInpainting(DiffusionPipeline):
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
|
||||
@@ -135,7 +135,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
|
||||
@@ -141,7 +141,7 @@ export INSTANCE_DIR="path-to-instance-images"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_dreambooth.py \
|
||||
accelerate launch --mixed_precision="fp16" train_dreambooth.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--class_data_dir=$CLASS_DIR \
|
||||
@@ -157,8 +157,7 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800 \
|
||||
--mixed_precision=fp16
|
||||
--max_train_steps=800
|
||||
```
|
||||
|
||||
### Fine-tune text encoder with the UNet.
|
||||
|
||||
@@ -187,12 +187,12 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
@@ -538,9 +538,9 @@ def main(args):
|
||||
)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move text_encode and vae to gpu.
|
||||
|
||||
@@ -46,7 +46,7 @@ With `gradient_checkpointing` and `mixed_precision` it should be possible to fin
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export dataset_name="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
accelerate launch train_text_to_image.py \
|
||||
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$dataset_name \
|
||||
--use_ema \
|
||||
@@ -54,7 +54,6 @@ accelerate launch train_text_to_image.py \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--mixed_precision="fp16" \
|
||||
--max_train_steps=15000 \
|
||||
--learning_rate=1e-05 \
|
||||
--max_grad_norm=1 \
|
||||
@@ -70,7 +69,7 @@ If you wish to use custom loading logic, you should modify the script, we have l
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export TRAIN_DIR="path_to_your_dataset"
|
||||
|
||||
accelerate launch train_text_to_image.py \
|
||||
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_data_dir=$TRAIN_DIR \
|
||||
--use_ema \
|
||||
@@ -78,7 +77,6 @@ accelerate launch train_text_to_image.py \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--mixed_precision="fp16" \
|
||||
--max_train_steps=15000 \
|
||||
--learning_rate=1e-05 \
|
||||
--max_grad_norm=1 \
|
||||
|
||||
@@ -186,12 +186,12 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -496,9 +496,9 @@ def main():
|
||||
)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move text_encode and vae to gpu.
|
||||
|
||||
@@ -194,9 +194,10 @@ def parse_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--predict_epsilon",
|
||||
action="store_true",
|
||||
default=True,
|
||||
"--prediction_type",
|
||||
type=str,
|
||||
default="epsilon",
|
||||
choices=["epsilon", "sample"],
|
||||
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
|
||||
)
|
||||
|
||||
@@ -256,13 +257,13 @@ def main(args):
|
||||
"UpBlock2D",
|
||||
),
|
||||
)
|
||||
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
|
||||
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
|
||||
|
||||
if accepts_predict_epsilon:
|
||||
if accepts_prediction_type:
|
||||
noise_scheduler = DDPMScheduler(
|
||||
num_train_timesteps=args.ddpm_num_steps,
|
||||
beta_schedule=args.ddpm_beta_schedule,
|
||||
predict_epsilon=args.predict_epsilon,
|
||||
prediction_type=args.prediction_type,
|
||||
)
|
||||
else:
|
||||
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
|
||||
@@ -365,9 +366,9 @@ def main(args):
|
||||
# Predict the noise residual
|
||||
model_output = model(noisy_images, timesteps).sample
|
||||
|
||||
if args.predict_epsilon:
|
||||
if args.prediction_type == "epsilon":
|
||||
loss = F.mse_loss(model_output, noise) # this could have different weights!
|
||||
else:
|
||||
elif args.prediction_type == "sample":
|
||||
alpha_t = _extract_into_tensor(
|
||||
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
|
||||
)
|
||||
@@ -376,6 +377,8 @@ def main(args):
|
||||
model_output, clean_images, reduction="none"
|
||||
) # use SNR weighting from distillation paper
|
||||
loss = loss.mean()
|
||||
else:
|
||||
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
|
||||
@@ -211,6 +211,7 @@ def create_unet_diffusers_config(original_config):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
model_params = original_config.model.params
|
||||
unet_params = original_config.model.params.unet_config.params
|
||||
|
||||
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
||||
@@ -230,7 +231,7 @@ def create_unet_diffusers_config(original_config):
|
||||
resolution //= 2
|
||||
|
||||
config = dict(
|
||||
sample_size=unet_params.image_size,
|
||||
sample_size=model_params.image_size,
|
||||
in_channels=unet_params.in_channels,
|
||||
out_channels=unet_params.out_channels,
|
||||
down_block_types=tuple(down_block_types),
|
||||
|
||||
@@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
||||
)
|
||||
del pipeline.safety_checker
|
||||
safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
|
||||
feature_extractor = pipeline.feature_extractor
|
||||
else:
|
||||
safety_checker = None
|
||||
feature_extractor = None
|
||||
|
||||
onnx_pipeline = OnnxStableDiffusionPipeline(
|
||||
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
|
||||
@@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||
scheduler=pipeline.scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=pipeline.feature_extractor,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=safety_checker is not None,
|
||||
)
|
||||
|
||||
onnx_pipeline.save_pretrained(output_path)
|
||||
|
||||
@@ -212,7 +212,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.8.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.9.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -9,7 +9,7 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
__version__ = "0.8.0.dev0"
|
||||
__version__ = "0.9.0"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
@@ -69,12 +69,14 @@ if is_torch_available() and is_transformers_available():
|
||||
AltDiffusionPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPipelineSafe,
|
||||
StableDiffusionUpscalePipeline,
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageToTextPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
|
||||
@@ -80,20 +80,21 @@ class ConfigMixin:
|
||||
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
||||
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
||||
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
||||
overridden by parent class).
|
||||
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent
|
||||
class).
|
||||
overridden by subclass).
|
||||
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
||||
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
|
||||
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
||||
subclass).
|
||||
"""
|
||||
config_name = None
|
||||
ignore_for_config = []
|
||||
has_compatibles = False
|
||||
|
||||
_deprecated_kwargs = []
|
||||
|
||||
def register_to_config(self, **kwargs):
|
||||
if self.config_name is None:
|
||||
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
||||
kwargs["_class_name"] = self.__class__.__name__
|
||||
kwargs["_diffusers_version"] = __version__
|
||||
|
||||
# Special case for `kwargs` used in deprecation warning added to schedulers
|
||||
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
|
||||
# or solve in a more general way.
|
||||
@@ -198,6 +199,11 @@ class ConfigMixin:
|
||||
if "dtype" in unused_kwargs:
|
||||
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
||||
|
||||
# add possible deprecated kwargs
|
||||
for deprecated_kwarg in cls._deprecated_kwargs:
|
||||
if deprecated_kwarg in unused_kwargs:
|
||||
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
|
||||
|
||||
# Return model and optionally state and/or unused_kwargs
|
||||
model = cls(**init_dict)
|
||||
|
||||
@@ -462,7 +468,7 @@ class ConfigMixin:
|
||||
unused_kwargs = {**config_dict, **kwargs}
|
||||
|
||||
# 7. Define "hidden" config parameters that were saved for compatible classes
|
||||
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")}
|
||||
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
|
||||
|
||||
return init_dict, unused_kwargs, hidden_config_dict
|
||||
|
||||
@@ -493,6 +499,9 @@ class ConfigMixin:
|
||||
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
||||
"""
|
||||
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
|
||||
config_dict["_class_name"] = self.__class__.__name__
|
||||
config_dict["_diffusers_version"] = __version__
|
||||
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
@@ -520,7 +529,7 @@ def register_to_config(init):
|
||||
def inner_init(self, *args, **kwargs):
|
||||
# Ignore private kwargs in the init.
|
||||
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
||||
init(self, *args, **init_kwargs)
|
||||
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
|
||||
if not isinstance(self, ConfigMixin):
|
||||
raise RuntimeError(
|
||||
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
||||
@@ -545,7 +554,9 @@ def register_to_config(init):
|
||||
if k not in ignore and k not in new_kwargs
|
||||
}
|
||||
)
|
||||
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
||||
getattr(self, "register_to_config")(**new_kwargs)
|
||||
init(self, *args, **init_kwargs)
|
||||
|
||||
return inner_init
|
||||
|
||||
@@ -562,7 +573,7 @@ def flax_register_to_config(cls):
|
||||
)
|
||||
|
||||
# Ignore private kwargs in the init. Retrieve all passed attributes
|
||||
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
||||
init_kwargs = {k: v for k, v in kwargs.items()}
|
||||
|
||||
# Retrieve default values
|
||||
fields = dataclasses.fields(self)
|
||||
|
||||
@@ -89,6 +89,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
|
||||
x = x + scale * grad
|
||||
x = self.reset_x0(x, conditions, self.action_dim)
|
||||
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
|
||||
# TODO: set prediction_type when instantiating the model
|
||||
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
|
||||
|
||||
# apply conditions to the trajectory
|
||||
|
||||
@@ -332,7 +332,7 @@ class ModelMixin(torch.nn.Module):
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
@@ -448,7 +448,7 @@ class ModelMixin(torch.nn.Module):
|
||||
if low_cpu_mem_usage:
|
||||
# Instantiate model with empty weights
|
||||
with accelerate.init_empty_weights():
|
||||
model, unused_kwargs = cls.from_config(
|
||||
config, unused_kwargs = cls.load_config(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
@@ -462,6 +462,7 @@ class ModelMixin(torch.nn.Module):
|
||||
device_map=device_map,
|
||||
**kwargs,
|
||||
)
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
|
||||
if device_map is None:
|
||||
@@ -482,7 +483,7 @@ class ModelMixin(torch.nn.Module):
|
||||
"error_msgs": [],
|
||||
}
|
||||
else:
|
||||
model, unused_kwargs = cls.from_config(
|
||||
config, unused_kwargs = cls.load_config(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
@@ -496,6 +497,7 @@ class ModelMixin(torch.nn.Module):
|
||||
device_map=device_map,
|
||||
**kwargs,
|
||||
)
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(model_file)
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@@ -22,7 +23,7 @@ from torch import nn
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import CONFIG_NAME, BaseOutput
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@@ -98,8 +99,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
@@ -125,7 +129,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
if use_linear_projection:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
||||
@@ -151,6 +158,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
only_cross_attention=only_cross_attention,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
@@ -158,7 +166,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 4. Define output layers
|
||||
if self.is_input_continuous:
|
||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
if use_linear_projection:
|
||||
self.proj_out = nn.Linear(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
@@ -190,10 +201,16 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
if self.is_input_continuous:
|
||||
batch, channel, height, weight = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
|
||||
@@ -203,8 +220,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = (
|
||||
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
else:
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
@@ -284,22 +310,52 @@ class AttentionBlock(nn.Module):
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
# transpose
|
||||
query_states = self.transpose_for_scores(query_proj)
|
||||
key_states = self.transpose_for_scores(key_proj)
|
||||
value_states = self.transpose_for_scores(value_proj)
|
||||
scale = 1 / math.sqrt(self.channels / self.num_heads)
|
||||
|
||||
# get scores
|
||||
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
|
||||
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
|
||||
if self.num_heads > 1:
|
||||
query_states = self.transpose_for_scores(query_proj)
|
||||
key_states = self.transpose_for_scores(key_proj)
|
||||
value_states = self.transpose_for_scores(value_proj)
|
||||
|
||||
# TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
|
||||
# or reformulate this into a 3D problem?
|
||||
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
|
||||
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
|
||||
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
|
||||
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
|
||||
else:
|
||||
query_states, key_states, value_states = query_proj, key_proj, value_proj
|
||||
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(
|
||||
query_states.shape[0],
|
||||
query_states.shape[1],
|
||||
key_states.shape[1],
|
||||
dtype=query_states.dtype,
|
||||
device=query_states.device,
|
||||
),
|
||||
query_states,
|
||||
key_states.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=scale,
|
||||
)
|
||||
|
||||
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
|
||||
|
||||
# compute attention output
|
||||
hidden_states = torch.matmul(attention_probs, value_states)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
||||
if self.num_heads > 1:
|
||||
# TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
|
||||
# or reformulate this into a 3D problem?
|
||||
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
|
||||
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
|
||||
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
|
||||
hidden_states = torch.matmul(attention_probs, value_states)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
|
||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
||||
else:
|
||||
hidden_states = torch.bmm(attention_probs, value_states)
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
@@ -337,14 +393,17 @@ class BasicTransformerBlock(nn.Module):
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.only_cross_attention = only_cross_attention
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
||||
) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
||||
self.attn2 = CrossAttention(
|
||||
@@ -366,6 +425,16 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
|
||||
# if xformers is installed try to use memory_efficient_attention by default
|
||||
if is_xformers_available():
|
||||
try:
|
||||
self._set_use_memory_efficient_attention_xformers(True)
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
"Could not enable memory efficient attention. Make sure xformers is installed"
|
||||
f" correctly and a GPU is available: {e}"
|
||||
)
|
||||
|
||||
def _set_attention_slice(self, slice_size):
|
||||
self.attn1._slice_size = slice_size
|
||||
self.attn2._slice_size = slice_size
|
||||
@@ -401,7 +470,11 @@ class BasicTransformerBlock(nn.Module):
|
||||
norm_hidden_states = (
|
||||
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
||||
)
|
||||
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
||||
|
||||
if self.only_cross_attention:
|
||||
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
|
||||
else:
|
||||
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
||||
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = (
|
||||
@@ -507,19 +580,17 @@ class CrossAttention(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
def _attention(self, query, key, value):
|
||||
# TODO: use baddbmm for better performance
|
||||
if query.device.type == "mps":
|
||||
# Better performance on mps (~20-25%)
|
||||
attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
|
||||
else:
|
||||
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
attention_probs = attention_scores.softmax(dim=-1)
|
||||
# compute attention output
|
||||
|
||||
if query.device.type == "mps":
|
||||
hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
|
||||
else:
|
||||
hidden_states = torch.matmul(attention_probs, value)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
|
||||
# reshape hidden_states
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
@@ -534,21 +605,15 @@ class CrossAttention(nn.Module):
|
||||
for i in range(hidden_states.shape[0] // slice_size):
|
||||
start_idx = i * slice_size
|
||||
end_idx = (i + 1) * slice_size
|
||||
if query.device.type == "mps":
|
||||
# Better performance on mps (~20-25%)
|
||||
attn_slice = (
|
||||
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
|
||||
* self.scale
|
||||
)
|
||||
else:
|
||||
attn_slice = (
|
||||
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
|
||||
) # TODO: use baddbmm for better performance
|
||||
attn_slice = torch.baddbmm(
|
||||
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx].transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
attn_slice = attn_slice.softmax(dim=-1)
|
||||
if query.device.type == "mps":
|
||||
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
|
||||
else:
|
||||
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
|
||||
@@ -731,12 +796,18 @@ class DualTransformer2DModel(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
# Variables that can be set by a pipeline:
|
||||
|
||||
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
||||
self.mix_ratio = 0.5
|
||||
|
||||
# The shape of `encoder_hidden_states` is expected to be
|
||||
# `(batch_size, num_condition_tokens[0]+num_condition_tokens[1], num_features)`
|
||||
self.num_condition_tokens = (77, 257)
|
||||
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
||||
self.condition_lengths = [77, 257]
|
||||
|
||||
# Which transformer to use to encode which condition.
|
||||
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
||||
self.transformer_index_for_condition = [1, 0]
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
|
||||
"""
|
||||
@@ -763,10 +834,13 @@ class DualTransformer2DModel(nn.Module):
|
||||
tokens_start = 0
|
||||
for i in range(2):
|
||||
# for each of the two transformers, pass the corresponding condition tokens
|
||||
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.num_condition_tokens[i]]
|
||||
encoded_state = self.transformers[i](input_states, condition_state, timestep, return_dict)[0]
|
||||
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
||||
transformer_index = self.transformer_index_for_condition[i]
|
||||
encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
|
||||
0
|
||||
]
|
||||
encoded_states.append(encoded_state - input_states)
|
||||
tokens_start += self.num_condition_tokens[i]
|
||||
tokens_start += self.condition_lengths[i]
|
||||
|
||||
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
||||
output_states = output_states + input_states
|
||||
|
||||
@@ -43,8 +43,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
implements for all the model (such as downloading or saving, etc.)
|
||||
|
||||
Parameters:
|
||||
sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
|
||||
Input sample size.
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample.
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
@@ -71,7 +71,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[int] = None,
|
||||
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
center_input_sample: bool = False,
|
||||
@@ -209,6 +209,11 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
@@ -242,9 +247,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
sample = upsample_block(sample, res_samples, emb)
|
||||
|
||||
# 6. post-process
|
||||
# make sure hidden states is in float32
|
||||
# when running in half-precision
|
||||
sample = self.conv_norm_out(sample.float()).type(sample.dtype)
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import AttentionBlock, Transformer2DModel, DualTransformer2DModel
|
||||
from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
|
||||
|
||||
@@ -33,6 +33,8 @@ def get_down_block(
|
||||
cross_attention_dim=None,
|
||||
downsample_padding=None,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
):
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlock2D":
|
||||
@@ -76,6 +78,8 @@ def get_down_block(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
)
|
||||
elif down_block_type == "SkipDownBlock2D":
|
||||
return SkipDownBlock2D(
|
||||
@@ -140,6 +144,8 @@ def get_up_block(
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
):
|
||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||
if up_block_type == "UpBlock2D":
|
||||
@@ -169,6 +175,9 @@ def get_up_block(
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
)
|
||||
elif up_block_type == "AttnUpBlock2D":
|
||||
return AttnUpBlock2D(
|
||||
@@ -245,7 +254,6 @@ class UNetMidBlock2D(nn.Module):
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -326,7 +334,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
**kwargs,
|
||||
use_linear_projection=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -352,16 +360,29 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
@@ -381,15 +402,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
||||
head_dims = self.attn_num_head_channels
|
||||
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
||||
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
||||
raise ValueError(
|
||||
f"Make sure slice_size {slice_size} is a divisor of "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"Make sure slice_size {slice_size} is a common divisor of "
|
||||
f"the number of heads used in cross_attention: {head_dims}"
|
||||
)
|
||||
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
||||
if slice_size is not None and slice_size > min(head_dims):
|
||||
raise ValueError(
|
||||
f"Chunk_size {slice_size} has to be smaller or equal to "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"slice_size {slice_size} has to be smaller or equal to "
|
||||
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
||||
)
|
||||
|
||||
for attn in self.attentions:
|
||||
@@ -510,6 +533,8 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -534,7 +559,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
if dual_cross_attention is False:
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
@@ -543,6 +568,8 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -573,15 +600,17 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
||||
head_dims = self.attn_num_head_channels
|
||||
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
||||
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
||||
raise ValueError(
|
||||
f"Make sure slice_size {slice_size} is a divisor of "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"Make sure slice_size {slice_size} is a common divisor of "
|
||||
f"the number of heads used in cross_attention: {head_dims}"
|
||||
)
|
||||
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
||||
if slice_size is not None and slice_size > min(head_dims):
|
||||
raise ValueError(
|
||||
f"Chunk_size {slice_size} has to be smaller or equal to "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"slice_size {slice_size} has to be smaller or equal to "
|
||||
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
||||
)
|
||||
|
||||
for attn in self.attentions:
|
||||
@@ -1106,6 +1135,9 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -1132,16 +1164,30 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
@@ -1153,15 +1199,17 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
||||
head_dims = self.attn_num_head_channels
|
||||
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
||||
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
||||
raise ValueError(
|
||||
f"Make sure slice_size {slice_size} is a divisor of "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"Make sure slice_size {slice_size} is a common divisor of "
|
||||
f"the number of heads used in cross_attention: {head_dims}"
|
||||
)
|
||||
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
||||
if slice_size is not None and slice_size > min(head_dims):
|
||||
raise ValueError(
|
||||
f"Chunk_size {slice_size} has to be smaller or equal to "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"slice_size {slice_size} has to be smaller or equal to "
|
||||
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
||||
)
|
||||
|
||||
for attn in self.attentions:
|
||||
|
||||
@@ -56,11 +56,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
implements for all the models (such as downloading or saving, etc.)
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optional*): The size of the input sample.
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample.
|
||||
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
@@ -97,6 +98,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
"DownBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
@@ -105,8 +107,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_num_groups: int = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
attention_head_dim: int = 8,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -122,10 +126,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
@@ -144,9 +158,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
attn_num_head_channels=attention_head_dim[i],
|
||||
downsample_padding=downsample_padding,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -159,9 +175,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
# count how many layers upsample the images
|
||||
@@ -169,6 +186,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
@@ -196,8 +215,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
attn_num_head_channels=reversed_attention_head_dim[i],
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
@@ -208,15 +229,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
|
||||
head_dims = self.config.attention_head_dim
|
||||
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
||||
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
||||
raise ValueError(
|
||||
f"Make sure slice_size {slice_size} is a divisor of "
|
||||
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
|
||||
f"Make sure slice_size {slice_size} is a common divisor of "
|
||||
f"the number of heads used in cross_attention: {head_dims}"
|
||||
)
|
||||
if slice_size is not None and slice_size > self.config.attention_head_dim:
|
||||
if slice_size is not None and slice_size > min(head_dims):
|
||||
raise ValueError(
|
||||
f"Chunk_size {slice_size} has to be smaller or equal to "
|
||||
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
|
||||
f"slice_size {slice_size} has to be smaller or equal to "
|
||||
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
||||
)
|
||||
|
||||
for block in self.down_blocks:
|
||||
@@ -249,14 +272,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
encoder_hidden_states (`torch.FloatTensor`):
|
||||
(batch_size, sequence_length, hidden_size) encoder hidden states
|
||||
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
@@ -302,6 +325,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.config.num_class_embeds is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
|
||||
@@ -411,13 +411,13 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
f" {expected_class_obj}"
|
||||
)
|
||||
elif passed_class_obj[name] is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
|
||||
f" that this might lead to problems when using {pipeline_class} and is not recommended."
|
||||
)
|
||||
sub_model_should_be_defined = False
|
||||
else:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
||||
" has the correct type"
|
||||
)
|
||||
|
||||
@@ -129,10 +129,13 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
Class attributes:
|
||||
|
||||
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
|
||||
- **config_name** (`str`) -- name of the config file that will store the class and module names of all
|
||||
components of the diffusion pipeline.
|
||||
- **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
|
||||
passed for the pipeline to function (should be overridden by subclasses).
|
||||
"""
|
||||
config_name = "model_index.json"
|
||||
_optional_components = []
|
||||
|
||||
def register_modules(self, **kwargs):
|
||||
# import it here to avoid circular import
|
||||
@@ -184,12 +187,19 @@ class DiffusionPipeline(ConfigMixin):
|
||||
model_index_dict.pop("_diffusers_version")
|
||||
model_index_dict.pop("_module", None)
|
||||
|
||||
expected_modules, optional_kwargs = self._get_signature_keys(self)
|
||||
|
||||
def is_saveable_module(name, value):
|
||||
if name not in expected_modules:
|
||||
return False
|
||||
if name in self._optional_components and value[0] is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
if sub_model is None:
|
||||
# edge case for saving a pipeline with safety_checker=None
|
||||
continue
|
||||
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
save_method_name = None
|
||||
@@ -405,7 +415,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
@@ -523,38 +533,47 @@ class DiffusionPipeline(ConfigMixin):
|
||||
# some modules can be passed directly to the init
|
||||
# in this case they are already instantiated in `kwargs`
|
||||
# extract them here
|
||||
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
|
||||
# define init kwargs
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
init_kwargs = {}
|
||||
# remove `null` components
|
||||
def load_module(name, value):
|
||||
if value[0] is None:
|
||||
return False
|
||||
if name in passed_class_obj and passed_class_obj[name] is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
||||
)
|
||||
|
||||
# import it here to avoid circular import
|
||||
from diffusers import pipelines
|
||||
|
||||
# 3. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
if class_name is None:
|
||||
# edge case for when the pipeline was saved with safety_checker=None
|
||||
init_kwargs[name] = None
|
||||
continue
|
||||
|
||||
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
if class_name.startswith("Flax"):
|
||||
class_name = class_name[4:]
|
||||
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
loaded_sub_model = None
|
||||
sub_model_should_be_defined = True
|
||||
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
if name in passed_class_obj:
|
||||
# 1. check that passed_class_obj has correct parent class
|
||||
if not is_pipeline_module and passed_class_obj[name] is not None:
|
||||
if not is_pipeline_module:
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
@@ -570,14 +589,8 @@ class DiffusionPipeline(ConfigMixin):
|
||||
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
||||
f" {expected_class_obj}"
|
||||
)
|
||||
elif passed_class_obj[name] is None:
|
||||
logger.warn(
|
||||
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
|
||||
f" that this might lead to problems when using {pipeline_class} and is not recommended."
|
||||
)
|
||||
sub_model_should_be_defined = False
|
||||
else:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
||||
" has the correct type"
|
||||
)
|
||||
@@ -597,7 +610,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
if loaded_sub_model is None and sub_model_should_be_defined:
|
||||
if loaded_sub_model is None:
|
||||
load_method_name = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
@@ -651,11 +664,13 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
# 4. Potentially add passed objects if expected
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
||||
for module in missing_modules:
|
||||
init_kwargs[module] = passed_class_obj[module]
|
||||
init_kwargs[module] = passed_class_obj.get(module, None)
|
||||
elif len(missing_modules) > 0:
|
||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
|
||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
||||
raise ValueError(
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
@@ -664,6 +679,14 @@ class DiffusionPipeline(ConfigMixin):
|
||||
model = pipeline_class(**init_kwargs)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _get_signature_keys(obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
||||
expected_modules = set(required_parameters.keys()) - set(["self"])
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
@property
|
||||
def components(self) -> Dict[str, Any]:
|
||||
r"""
|
||||
@@ -688,8 +711,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
Returns:
|
||||
A dictionaly containing all the modules needed to initialize the pipeline.
|
||||
"""
|
||||
components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
|
||||
expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
|
||||
expected_modules, optional_parameters = self._get_signature_keys(self)
|
||||
components = {
|
||||
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
|
||||
}
|
||||
|
||||
if set(components.keys()) != expected_modules:
|
||||
raise ValueError(
|
||||
|
||||
@@ -19,14 +19,16 @@ if is_torch_available() and is_transformers_available():
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .stable_diffusion import (
|
||||
CycleDiffusionPipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
)
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .versatile_diffusion import (
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageToTextPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Union
|
||||
import torch
|
||||
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -67,6 +68,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -84,6 +86,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -114,8 +117,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
@@ -124,6 +127,33 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -133,6 +163,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
@@ -166,9 +198,14 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
if isinstance(self.unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.unet.config.attention_head_dim)
|
||||
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
@@ -192,10 +229,15 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
|
||||
# fix by only offloading self.safety_checker for now
|
||||
cpu_offload(self.safety_checker.vision_model, device)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
@@ -370,7 +412,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // 8, width // 8)
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
@@ -390,8 +432,8 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -411,9 +453,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
@@ -459,6 +501,9 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, callback_steps)
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch
|
||||
|
||||
import PIL
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -80,6 +81,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -97,6 +99,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -127,8 +130,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
@@ -137,6 +140,33 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -146,6 +176,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
@@ -161,9 +193,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
if isinstance(self.unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.unet.config.attention_head_dim)
|
||||
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
@@ -187,10 +224,15 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
|
||||
# fix by only offloading self.safety_checker for now
|
||||
cpu_offload(self.safety_checker.vision_model, device)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
|
||||
@@ -89,7 +89,11 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
generator = None
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
if isinstance(self.unet.sample_size, int):
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
else:
|
||||
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
|
||||
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = torch.randn(image_shape, generator=generator)
|
||||
|
||||
@@ -70,14 +70,14 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
generated images.
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
|
||||
if predict_epsilon is not None:
|
||||
new_config = dict(self.scheduler.config)
|
||||
new_config["predict_epsilon"] = predict_epsilon
|
||||
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||
self.scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
|
||||
@@ -94,7 +94,11 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
generator = None
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
if isinstance(self.unet.sample_size, int):
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
else:
|
||||
image_shape = (batch_size, self.unet.in_channels, *self.unet.sample_size)
|
||||
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = torch.randn(image_shape, generator=generator)
|
||||
@@ -110,9 +114,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
model_output = self.unet(image, t).sample
|
||||
|
||||
# 2. compute previous image: x_t -> x_t-1
|
||||
image = self.scheduler.step(
|
||||
model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
|
||||
).prev_sample
|
||||
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
@@ -60,13 +60,14 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: Optional[int] = 256,
|
||||
width: Optional[int] = 256,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 1.0,
|
||||
eta: Optional[float] = 0.0,
|
||||
@@ -79,9 +80,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 256):
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 256):
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
@@ -106,6 +107,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
|
||||
@@ -6,7 +6,14 @@ import numpy as np
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
is_flax_available,
|
||||
is_onnx_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -33,8 +40,14 @@ if is_transformers_available() and is_torch_available():
|
||||
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
|
||||
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
|
||||
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
|
||||
else:
|
||||
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
|
||||
|
||||
if is_transformers_available() and is_onnx_available():
|
||||
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
|
||||
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch
|
||||
|
||||
import PIL
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -132,6 +133,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -142,6 +144,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -159,8 +162,8 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
@@ -169,6 +172,32 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -178,6 +207,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
@@ -194,9 +224,14 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
if isinstance(self.unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.unet.config.attention_head_dim)
|
||||
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
@@ -222,10 +257,15 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
|
||||
# fix by only offloading self.safety_checker for now
|
||||
cpu_offload(self.safety_checker.vision_model, device)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
|
||||
@@ -23,6 +23,7 @@ import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training.common_utils import shard
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
||||
|
||||
@@ -34,7 +35,7 @@ from ...schedulers import (
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
)
|
||||
from ...utils import logging
|
||||
from ...utils import deprecate, logging
|
||||
from . import FlaxStableDiffusionPipelineOutput
|
||||
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
||||
|
||||
@@ -88,7 +89,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
self.dtype = dtype
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
@@ -97,6 +98,27 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -106,6 +128,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
def prepare_inputs(self, prompt: Union[str, List[str]]):
|
||||
if not isinstance(prompt, (str, list)):
|
||||
@@ -160,13 +183,17 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.PRNGKey,
|
||||
num_inference_steps: int = 50,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
latents: Optional[jnp.array] = None,
|
||||
debug: bool = False,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
):
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
@@ -188,7 +215,12 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
|
||||
context = jnp.concatenate([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
|
||||
latents_shape = (
|
||||
batch_size,
|
||||
self.unet.in_channels,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
)
|
||||
if latents is None:
|
||||
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
|
||||
else:
|
||||
@@ -249,8 +281,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.PRNGKey,
|
||||
num_inference_steps: int = 50,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
latents: jnp.array = None,
|
||||
return_dict: bool = True,
|
||||
@@ -265,9 +297,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
@@ -302,6 +334,10 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
||||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if jit:
|
||||
images = _p_generate(
|
||||
self,
|
||||
|
||||
@@ -41,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
@@ -51,6 +53,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -81,6 +84,22 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
@@ -91,6 +110,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
r"""
|
||||
|
||||
@@ -77,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
@@ -87,6 +89,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -117,7 +120,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
@@ -127,6 +130,12 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
@@ -137,6 +146,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
|
||||
@@ -90,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae_encoder: OnnxRuntimeModel,
|
||||
@@ -100,6 +102,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
|
||||
@@ -131,7 +134,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
@@ -141,6 +144,12 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
@@ -151,6 +160,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
@@ -236,8 +246,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
prompt: Union[str, List[str]],
|
||||
image: PIL.Image.Image,
|
||||
mask_image: PIL.Image.Image,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -312,6 +322,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
|
||||
+35
-4
@@ -5,6 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -27,11 +28,11 @@ def preprocess(image):
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
def preprocess_mask(mask):
|
||||
def preprocess_mask(mask, scale_factor=8):
|
||||
mask = mask.convert("L")
|
||||
w, h = mask.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
|
||||
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL.Image.NEAREST)
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
||||
@@ -86,6 +87,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -116,7 +118,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
@@ -126,6 +128,33 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
@@ -136,6 +165,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
@@ -341,7 +372,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
|
||||
# preprocess mask
|
||||
if not isinstance(mask_image, np.ndarray):
|
||||
mask_image = preprocess_mask(mask_image)
|
||||
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
|
||||
mask_image = mask_image.astype(latents_dtype)
|
||||
mask = np.concatenate([mask_image] * num_images_per_prompt, axis=0)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Union
|
||||
import torch
|
||||
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -66,6 +67,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -83,6 +85,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -113,8 +116,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
@@ -123,6 +126,33 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -132,6 +162,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
@@ -165,9 +197,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
if isinstance(self.unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.unet.config.attention_head_dim)
|
||||
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
@@ -191,10 +228,15 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
|
||||
# fix by only offloading self.safety_checker for now
|
||||
cpu_offload(self.safety_checker.vision_model, device)
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
@@ -369,7 +411,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // 8, width // 8)
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
@@ -389,8 +431,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -410,9 +452,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
@@ -458,6 +500,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, callback_steps)
|
||||
|
||||
+478
@@ -0,0 +1,478 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...utils import deprecate, logging
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline to generate variations from an input image using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
|
||||
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
image_encoder: CLIPVisionModelWithProjection,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
image_encoder=image_encoder,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
if isinstance(self.unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.unet.config.attention_head_dim)
|
||||
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.image_encoder, self.vae, self.safety_checker]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeddings = self.image_encoder(image).image_embeds
|
||||
image_embeddings = image_embeddings.unsqueeze(1)
|
||||
|
||||
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = image_embeddings.shape
|
||||
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
uncond_embeddings = torch.zeros_like(image_embeddings)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
|
||||
|
||||
return image_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(self, image, height, width, callback_steps):
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `list` but is {type(image)}"
|
||||
)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
||||
The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
|
||||
configuration of
|
||||
[this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
|
||||
`CLIPFeatureExtractor`
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(image, height, width, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
batch_size = 1
|
||||
elif isinstance(image, list):
|
||||
batch_size = len(image)
|
||||
else:
|
||||
batch_size = image.shape[0]
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input image
|
||||
image_embeddings = self._encode_image(image, device, num_images_per_prompt, do_classifier_free_guidance)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
image_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 9. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
|
||||
|
||||
# 10. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -20,6 +20,7 @@ import torch
|
||||
|
||||
import PIL
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -78,6 +79,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
|
||||
def __init__(
|
||||
@@ -96,6 +98,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -126,8 +129,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
@@ -136,6 +139,33 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -145,6 +175,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
@@ -161,9 +193,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
if isinstance(self.unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.unet.config.attention_head_dim)
|
||||
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
@@ -189,10 +226,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
|
||||
# fix by only offloading self.safety_checker for now
|
||||
cpu_offload(self.safety_checker.vision_model, device)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch
|
||||
|
||||
import PIL
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -150,6 +151,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -160,6 +162,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -191,8 +194,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
new_config["skip_prk_steps"] = True
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
@@ -201,6 +204,33 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -210,6 +240,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
@@ -226,9 +258,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
if isinstance(self.unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.unet.config.attention_head_dim)
|
||||
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
@@ -254,10 +291,15 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
|
||||
# fix by only offloading self.safety_checker for now
|
||||
cpu_offload(self.safety_checker.vision_model, device)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
@@ -459,7 +501,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // 8, width // 8)
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
@@ -481,7 +523,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
)
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
@@ -509,8 +553,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -538,9 +582,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
|
||||
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
|
||||
instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
@@ -586,6 +630,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, height, width, callback_steps)
|
||||
|
||||
+51
-9
@@ -20,6 +20,7 @@ import torch
|
||||
|
||||
import PIL
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
@@ -51,11 +52,11 @@ def preprocess_image(image):
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
def preprocess_mask(mask):
|
||||
def preprocess_mask(mask, scale_factor=8):
|
||||
mask = mask.convert("L")
|
||||
w, h = mask.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
|
||||
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
||||
@@ -91,6 +92,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
|
||||
def __init__(
|
||||
@@ -109,6 +111,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -139,8 +142,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
@@ -149,6 +152,33 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -158,6 +188,8 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
@@ -174,9 +206,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
if isinstance(self.unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.unet.config.attention_head_dim)
|
||||
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
@@ -202,10 +239,15 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
|
||||
# fix by only offloading self.safety_checker for now
|
||||
cpu_offload(self.safety_checker.vision_model, device)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
@@ -532,7 +574,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
init_image = preprocess_image(init_image)
|
||||
|
||||
if not isinstance(mask_image, torch.FloatTensor):
|
||||
mask_image = preprocess_mask(mask_image)
|
||||
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
+248
-165
@@ -13,120 +13,86 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
||||
import PIL
|
||||
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection, GPT2Tokenizer
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.attention import Transformer2DModel
|
||||
from ...pipeline_utils import BaseOutput, DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import is_accelerate_available, logging
|
||||
from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector
|
||||
from .modeling_text_unet import UNetFlatConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for text generation pipelines.
|
||||
def preprocess(image):
|
||||
# resize to multiple of 64
|
||||
width, height = image.size
|
||||
width = width - width % 64
|
||||
height = height - height % 64
|
||||
image = image.resize((width, height))
|
||||
|
||||
Args:
|
||||
text (`List[str]` or `np.ndarray`)
|
||||
List of generated text of length `batch_size` or a numpy array of tokens of shape `(batch_size,
|
||||
num_tokens)`.
|
||||
"""
|
||||
|
||||
text: Union[List[str], np.ndarray]
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
return image
|
||||
|
||||
|
||||
class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
class StableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-guided image super-resolution using Stable Diffusion 2.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Parameters:
|
||||
vqvae ([`VQModel`]):
|
||||
Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
|
||||
bert ([`LDMBertModel`]):
|
||||
Text-encoder model based on [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
|
||||
tokenizer (`transformers.BertTokenizer`):
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
low_res_scheduler ([`SchedulerMixin`]):
|
||||
A scheduler used to add initial noise to the low res conditioning image. It must be an instance of
|
||||
[`DDPMScheduler`].
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
image_feature_extractor: CLIPFeatureExtractor
|
||||
image_encoder: CLIPVisionModelWithProjection
|
||||
image_unet: UNet2DConditionModel
|
||||
text_unet: UNetFlatConditionModel
|
||||
vae: AutoencoderKL
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection,
|
||||
image_unet: UNet2DConditionModel,
|
||||
text_unet: UNetFlatConditionModel,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
low_res_scheduler: DDPMScheduler,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
max_noise_level: int = 350,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
image_feature_extractor=image_feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
image_unet=image_unet,
|
||||
text_unet=text_unet,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.register_to_config(max_noise_level=max_noise_level)
|
||||
|
||||
self.text_vae_decoder = GPT2OptimusForLatentConnector.from_pretrained("fusing/gpt2_optimus")
|
||||
self.text_vae_tokenizer = GPT2Tokenizer.from_pretrained("fusing/gpt2_optimus")
|
||||
|
||||
def swap_unet_attention_blocks(self):
|
||||
for name, module in self.image_unet.named_modules():
|
||||
if isinstance(module, Transformer2DModel):
|
||||
parent_name, index = name.rsplit(".", 1)
|
||||
index = int(index)
|
||||
self.image_unet.get_submodule(parent_name)[index], self.text_unet.get_submodule(parent_name)[index] = (
|
||||
self.text_unet.get_submodule(parent_name)[index],
|
||||
self.image_unet.get_submodule(parent_name)[index],
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.image_unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.image_unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -141,10 +107,15 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.image_unet.config.attention_head_dim // 2
|
||||
self.image_unet.set_attention_slice(slice_size)
|
||||
if isinstance(self.unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.unet.config.attention_head_dim)
|
||||
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
def disable_attention_slicing(self):
|
||||
@@ -168,21 +139,41 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]:
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
|
||||
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.image_unet.modules():
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
@@ -191,6 +182,7 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
@@ -208,39 +200,53 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
"""
|
||||
|
||||
def normalize_embeddings(encoder_output):
|
||||
embeds = self.image_encoder.vision_model.post_layernorm(encoder_output.last_hidden_state)
|
||||
embeds = self.image_encoder.visual_projection(embeds)
|
||||
embeds_pooled = embeds[:, 0:1]
|
||||
embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
|
||||
return embeds
|
||||
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
# get prompt text embeddings
|
||||
# prompt = [(np.asarray(prompt) / 255)]
|
||||
image_input = self.image_feature_extractor(images=prompt, return_tensors="pt")
|
||||
image_embeddings = self.image_encoder(image_input.pixel_values.to(self.device))
|
||||
image_embeddings = normalize_embeddings(image_embeddings)
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
|
||||
|
||||
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = image_embeddings.shape
|
||||
image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
if not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
text_embeddings = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
text_embeddings = text_embeddings[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_images: List[str]
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, PIL.Image.Image):
|
||||
uncond_images = [negative_prompt]
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
@@ -248,11 +254,27 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_images = negative_prompt
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
|
||||
uncond_embeddings = self.image_encoder(uncond_images.pixel_values.to(self.device))
|
||||
uncond_embeddings = normalize_embeddings(uncond_embeddings)
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
uncond_embeddings = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
@@ -260,18 +282,11 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and conditional embeddings into a single batch
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
image_embeddings = torch.cat([uncond_embeddings, image_embeddings])
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
return image_embeddings
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = latents.reshape(latents.shape[:-2])
|
||||
self.text_vae_decoder = self.text_vae_decoder.to(self._execution_device)
|
||||
bos_token = self.text_vae_tokenizer.bos_token_id
|
||||
output = self.text_vae_decoder.generate(bos_token_id=bos_token, past=latents)
|
||||
return output
|
||||
return text_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
@@ -291,9 +306,47 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(self, image, callback_steps):
|
||||
if not isinstance(image, PIL.Image.Image) and not isinstance(image, torch.Tensor):
|
||||
raise ValueError(f"`image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(image)}")
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.08333 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def check_inputs(self, prompt, image, noise_level, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
|
||||
)
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor
|
||||
if isinstance(image, list) or isinstance(image, torch.Tensor):
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = len(prompt)
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
image_batch_size = image.shape[0]
|
||||
if batch_size != image_batch_size:
|
||||
raise ValueError(
|
||||
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
|
||||
" Please make sure that passed `prompt` matches the batch size of `image`."
|
||||
)
|
||||
|
||||
# check noise level
|
||||
if noise_level > self.config.max_noise_level:
|
||||
raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
@@ -303,8 +356,8 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, 1, 1)
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
@@ -323,26 +376,29 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
noise_level: int = 20,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "str",
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
|
||||
The image prompt or prompts to guide the image generation.
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
|
||||
`Image`, or tensor representing an image batch which will be upscaled. *
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
@@ -388,11 +444,11 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(image, callback_steps)
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, image, noise_level, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image)
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@@ -400,69 +456,96 @@ class VersatileDiffusionImageToTextPipeline(DiffusionPipeline):
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
image_embeddings = self._encode_prompt(
|
||||
image, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
# 4. Preprocess image
|
||||
image = [image] if isinstance(image, PIL.Image.Image) else image
|
||||
if isinstance(image, list):
|
||||
image = [preprocess(img) for img in image]
|
||||
image = torch.cat(image, dim=0)
|
||||
image = image.to(dtype=text_embeddings.dtype, device=device)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.text_unet.in_channels[0]
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps_tensor = self.scheduler.timesteps
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
noise = torch.randn(image.shape, generator=generator, device="cpu", dtype=text_embeddings.dtype).to(device)
|
||||
else:
|
||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
image = torch.cat([image] * 2) if do_classifier_free_guidance else image
|
||||
noise_level = torch.cat([noise_level] * 2) if do_classifier_free_guidance else noise_level
|
||||
|
||||
# 6. Prepare latent variables
|
||||
height, width = image.shape[2:]
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
image_embeddings.dtype,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs.
|
||||
# 7. Check that sizes of image and latents match
|
||||
num_channels_image = image.shape[1]
|
||||
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
|
||||
raise ValueError(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_image`: {num_channels_image} "
|
||||
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
|
||||
" `pipeline.unet` or your `image` input."
|
||||
)
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Swap the attention blocks between the image and text UNets
|
||||
self.swap_unet_attention_blocks()
|
||||
|
||||
# 8. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# 9. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
print("latent_model_input", latent_model_input.abs().sum())
|
||||
print("timestep", t)
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = torch.cat([latent_model_input, image], dim=1)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.text_unet(latent_model_input, t, encoder_hidden_states=image_embeddings).sample
|
||||
noise_pred = self.unet(
|
||||
latent_model_input, t, encoder_hidden_states=text_embeddings, class_labels=noise_level
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
print("e_t", noise_pred.abs().sum())
|
||||
print("e_t[3,3]", noise_pred[0, :5, 0, 0])
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
print("latents", latents.abs().sum())
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 9. Swap the attention blocks backs in case the UNets are reused in another pipeline
|
||||
self.swap_unet_attention_blocks()
|
||||
|
||||
# 10. Post-processing
|
||||
text = self.decode_latents(latents)
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
self.vae.to(dtype=torch.float32)
|
||||
image = self.decode_latents(latents.float())
|
||||
|
||||
# 11. Convert to strings
|
||||
if output_type == "str":
|
||||
text = self.text_vae_tokenizer.batch_decode(text)
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (text,)
|
||||
return (image,)
|
||||
|
||||
return TextPipelineOutput(text=text)
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -0,0 +1,72 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from ...utils import BaseOutput, is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
@dataclass
|
||||
class SafetyConfig(object):
|
||||
WEAK = {
|
||||
"sld_warmup_steps": 15,
|
||||
"sld_guidance_scale": 20,
|
||||
"sld_threshold": 0.0,
|
||||
"sld_momentum_scale": 0.0,
|
||||
"sld_mom_beta": 0.0,
|
||||
}
|
||||
MEDIUM = {
|
||||
"sld_warmup_steps": 10,
|
||||
"sld_guidance_scale": 1000,
|
||||
"sld_threshold": 0.01,
|
||||
"sld_momentum_scale": 0.3,
|
||||
"sld_mom_beta": 0.4,
|
||||
}
|
||||
STRONG = {
|
||||
"sld_warmup_steps": 7,
|
||||
"sld_guidance_scale": 2000,
|
||||
"sld_threshold": 0.025,
|
||||
"sld_momentum_scale": 0.5,
|
||||
"sld_mom_beta": 0.7,
|
||||
}
|
||||
MAX = {
|
||||
"sld_warmup_steps": 0,
|
||||
"sld_guidance_scale": 5000,
|
||||
"sld_threshold": 1.0,
|
||||
"sld_momentum_scale": 0.5,
|
||||
"sld_mom_beta": 0.7,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class StableDiffusionSafePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Safe Stable Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
nsfw_content_detected (`List[bool]`)
|
||||
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, or `None` if safety checking could not be performed.
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images that were flagged by the safety checker any may contain "not-safe-for-work"
|
||||
(nsfw) content, or `None` if no safety check was performed or no images were flagged.
|
||||
applied_safety_concept (`str`)
|
||||
The safety concept that was applied for safety guidance, or `None` if safety guidance was disabled
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
nsfw_content_detected: Optional[List[bool]]
|
||||
unsafe_images: Optional[Union[List[PIL.Image.Image], np.ndarray]]
|
||||
applied_safety_concept: Optional[str]
|
||||
|
||||
|
||||
if is_transformers_available() and is_torch_available():
|
||||
from .pipeline_stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .safety_checker import SafeStableDiffusionSafetyChecker
|
||||
@@ -0,0 +1,757 @@
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from packaging import version
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...utils import deprecate, is_accelerate_available, logging
|
||||
from . import StableDiffusionSafePipelineOutput
|
||||
from .safety_checker import SafeStableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionPipelineSafe(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Safe Latent Diffusion.
|
||||
|
||||
The implementation is based on the [`StableDiffusionPipeline`]
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
],
|
||||
safety_checker: SafeStableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
safety_concept: Optional[str] = (
|
||||
"an image showing hate, harassment, violence, suffering, humiliation, harm, suicide, sexual, nudity,"
|
||||
" bodily fluids, blood, obscene gestures, illegal activity, drug use, theft, vandalism, weapons, child"
|
||||
" abuse, brutality, cruelty"
|
||||
)
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self._safety_text_concept = safety_concept
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
@property
|
||||
def safety_concept(self):
|
||||
r"""
|
||||
Getter method for the safety concept used with SLD
|
||||
|
||||
Returns:
|
||||
`str`: The text describing the safety concept
|
||||
"""
|
||||
return self._safety_text_concept
|
||||
|
||||
@safety_concept.setter
|
||||
def safety_concept(self, concept):
|
||||
r"""
|
||||
Setter method for the safety concept used with SLD
|
||||
|
||||
Args:
|
||||
concept (`str`):
|
||||
The text of the new safety concept
|
||||
"""
|
||||
self._safety_text_concept = concept
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_sequential_cpu_offload(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
enable_safety_guidance,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `list(int)`):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
"""
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
|
||||
|
||||
if not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
text_embeddings = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
text_embeddings = text_embeddings[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
uncond_embeddings = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# Encode the safety concept text
|
||||
if enable_safety_guidance:
|
||||
safety_concept_input = self.tokenizer(
|
||||
[self._safety_text_concept],
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
safety_embeddings = self.text_encoder(safety_concept_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate safety embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = safety_embeddings.shape[1]
|
||||
safety_embeddings = safety_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
safety_embeddings = safety_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance + sld, we need to do three forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing three forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, safety_embeddings])
|
||||
|
||||
else:
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def run_safety_checker(self, image, device, dtype, enable_safety_guidance):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
flagged_images = None
|
||||
if any(has_nsfw_concept):
|
||||
logger.warning(
|
||||
"Potential NSFW content was detected in one or more images. A black image will be returned"
|
||||
" instead."
|
||||
f" {'You may look at this images in the `unsafe_images` variable of the output at your own discretion.' if enable_safety_guidance else 'Try again with a different prompt and/or seed.'} "
|
||||
)
|
||||
flagged_images = np.zeros((2, *image.shape[1:]))
|
||||
for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
|
||||
if has_nsfw_concept:
|
||||
flagged_images[idx] = image[idx]
|
||||
image[idx] = np.zeros(image[idx].shape) # black image
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
flagged_images = None
|
||||
return image, has_nsfw_concept, flagged_images
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
|
||||
def check_inputs(self, prompt, height, width, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def perform_safety_guidance(
|
||||
self,
|
||||
enable_safety_guidance,
|
||||
safety_momentum,
|
||||
noise_guidance,
|
||||
noise_pred_out,
|
||||
i,
|
||||
sld_guidance_scale,
|
||||
sld_warmup_steps,
|
||||
sld_threshold,
|
||||
sld_momentum_scale,
|
||||
sld_mom_beta,
|
||||
):
|
||||
# Perform SLD guidance
|
||||
if enable_safety_guidance:
|
||||
if safety_momentum is None:
|
||||
safety_momentum = torch.zeros_like(noise_guidance)
|
||||
noise_pred_text, noise_pred_uncond = noise_pred_out[0], noise_pred_out[1]
|
||||
noise_pred_safety_concept = noise_pred_out[2]
|
||||
|
||||
# Equation 6
|
||||
scale = torch.clamp(torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0)
|
||||
|
||||
# Equation 6
|
||||
safety_concept_scale = torch.where(
|
||||
(noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
|
||||
)
|
||||
|
||||
# Equation 4
|
||||
noise_guidance_safety = torch.mul((noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale)
|
||||
|
||||
# Equation 7
|
||||
noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
|
||||
|
||||
# Equation 8
|
||||
safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
|
||||
|
||||
if i >= sld_warmup_steps: # Warmup
|
||||
# Equation 3
|
||||
noise_guidance = noise_guidance - noise_guidance_safety
|
||||
return noise_guidance, safety_momentum
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
sld_guidance_scale: Optional[float] = 1000,
|
||||
sld_warmup_steps: Optional[int] = 10,
|
||||
sld_threshold: Optional[float] = 0.01,
|
||||
sld_momentum_scale: Optional[float] = 0.3,
|
||||
sld_mom_beta: Optional[float] = 0.4,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
sld_guidance_scale (`float`, *optional*, defaults to 1000):
|
||||
Safe latent guidance as defined in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
|
||||
`sld_guidance_scale` is defined as sS of Eq. 6. If set to be less than 1, safety guidance will be
|
||||
disabled.
|
||||
sld_warmup_steps (`int`, *optional*, defaults to 10):
|
||||
Number of warmup steps for safety guidance. SLD will only be applied for diffusion steps greater than
|
||||
`sld_warmup_steps`. `sld_warmup_steps` is defined as `delta` of [Safe Latent
|
||||
Diffusion](https://arxiv.org/abs/2211.05105).
|
||||
sld_threshold (`float`, *optional*, defaults to 0.01):
|
||||
Threshold that separates the hyperplane between appropriate and inappropriate images. `sld_threshold`
|
||||
is defined as `lamda` of Eq. 5 in [Safe Latent Diffusion](https://arxiv.org/abs/2211.05105).
|
||||
sld_momentum_scale (`float`, *optional*, defaults to 0.3):
|
||||
Scale of the SLD momentum to be added to the safety guidance at each diffusion step. If set to 0.0
|
||||
momentum will be disabled. Momentum is already built up during warmup, i.e. for diffusion steps smaller
|
||||
than `sld_warmup_steps`. `sld_momentum_scale` is defined as `sm` of Eq. 7 in [Safe Latent
|
||||
Diffusion](https://arxiv.org/abs/2211.05105).
|
||||
sld_mom_beta (`float`, *optional*, defaults to 0.4):
|
||||
Defines how safety guidance momentum builds up. `sld_mom_beta` indicates how much of the previous
|
||||
momentum will be kept. Momentum is already built up during warmup, i.e. for diffusion steps smaller
|
||||
than `sld_warmup_steps`. `sld_mom_beta` is defined as `beta m` of Eq. 8 in [Safe Latent
|
||||
Diffusion](https://arxiv.org/abs/2211.05105).
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
enable_safety_guidance = sld_guidance_scale > 1.0 and do_classifier_free_guidance
|
||||
if not enable_safety_guidance:
|
||||
warnings.warn("Safety checker disabled!")
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, enable_safety_guidance
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs.
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
safety_momentum = None
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * (3 if enable_safety_guidance else 2)) if do_classifier_free_guidance else latents
|
||||
)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_out = noise_pred.chunk((3 if enable_safety_guidance else 2))
|
||||
noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
|
||||
|
||||
# default classifier free guidance
|
||||
noise_guidance = noise_pred_text - noise_pred_uncond
|
||||
|
||||
# Perform SLD guidance
|
||||
if enable_safety_guidance:
|
||||
if safety_momentum is None:
|
||||
safety_momentum = torch.zeros_like(noise_guidance)
|
||||
noise_pred_safety_concept = noise_pred_out[2]
|
||||
|
||||
# Equation 6
|
||||
scale = torch.clamp(
|
||||
torch.abs((noise_pred_text - noise_pred_safety_concept)) * sld_guidance_scale, max=1.0
|
||||
)
|
||||
|
||||
# Equation 6
|
||||
safety_concept_scale = torch.where(
|
||||
(noise_pred_text - noise_pred_safety_concept) >= sld_threshold, torch.zeros_like(scale), scale
|
||||
)
|
||||
|
||||
# Equation 4
|
||||
noise_guidance_safety = torch.mul(
|
||||
(noise_pred_safety_concept - noise_pred_uncond), safety_concept_scale
|
||||
)
|
||||
|
||||
# Equation 7
|
||||
noise_guidance_safety = noise_guidance_safety + sld_momentum_scale * safety_momentum
|
||||
|
||||
# Equation 8
|
||||
safety_momentum = sld_mom_beta * safety_momentum + (1 - sld_mom_beta) * noise_guidance_safety
|
||||
|
||||
if i >= sld_warmup_steps: # Warmup
|
||||
# Equation 3
|
||||
noise_guidance = noise_guidance - noise_guidance_safety
|
||||
|
||||
noise_pred = noise_pred_uncond + guidance_scale * noise_guidance
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 9. Run safety checker
|
||||
image, has_nsfw_concept, flagged_images = self.run_safety_checker(
|
||||
image, device, text_embeddings.dtype, enable_safety_guidance
|
||||
)
|
||||
|
||||
# 10. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
if flagged_images is not None:
|
||||
flagged_images = self.numpy_to_pil(flagged_images)
|
||||
|
||||
if not return_dict:
|
||||
return (
|
||||
image,
|
||||
has_nsfw_concept,
|
||||
self._safety_text_concept if enable_safety_guidance else None,
|
||||
flagged_images,
|
||||
)
|
||||
|
||||
return StableDiffusionSafePipelineOutput(
|
||||
images=image,
|
||||
nsfw_content_detected=has_nsfw_concept,
|
||||
applied_safety_concept=self._safety_text_concept if enable_safety_guidance else None,
|
||||
unsafe_images=flagged_images,
|
||||
)
|
||||
@@ -0,0 +1,110 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
|
||||
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def cosine_distance(image_embeds, text_embeds):
|
||||
normalized_image_embeds = nn.functional.normalize(image_embeds)
|
||||
normalized_text_embeds = nn.functional.normalize(text_embeds)
|
||||
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
|
||||
|
||||
|
||||
class SafeStableDiffusionSafetyChecker(PreTrainedModel):
|
||||
config_class = CLIPConfig
|
||||
|
||||
_no_split_modules = ["CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.vision_model = CLIPVisionModel(config.vision_config)
|
||||
self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
|
||||
|
||||
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
|
||||
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
|
||||
|
||||
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
|
||||
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, clip_input, images):
|
||||
pooled_output = self.vision_model(clip_input)[1] # pooled_output
|
||||
image_embeds = self.visual_projection(pooled_output)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
|
||||
cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
|
||||
|
||||
result = []
|
||||
batch_size = image_embeds.shape[0]
|
||||
for i in range(batch_size):
|
||||
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
|
||||
|
||||
# increase this value to create a stronger `nfsw` filter
|
||||
# at the cost of increasing the possibility of filtering benign images
|
||||
adjustment = 0.0
|
||||
|
||||
for concept_idx in range(len(special_cos_dist[0])):
|
||||
concept_cos = special_cos_dist[i][concept_idx]
|
||||
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
|
||||
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
|
||||
if result_img["special_scores"][concept_idx] > 0:
|
||||
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
|
||||
adjustment = 0.01
|
||||
|
||||
for concept_idx in range(len(cos_dist[0])):
|
||||
concept_cos = cos_dist[i][concept_idx]
|
||||
concept_threshold = self.concept_embeds_weights[concept_idx].item()
|
||||
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
|
||||
if result_img["concept_scores"][concept_idx] > 0:
|
||||
result_img["bad_concepts"].append(concept_idx)
|
||||
|
||||
result.append(result_img)
|
||||
|
||||
has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
|
||||
|
||||
return images, has_nsfw_concepts
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
|
||||
pooled_output = self.vision_model(clip_input)[1] # pooled_output
|
||||
image_embeds = self.visual_projection(pooled_output)
|
||||
|
||||
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
|
||||
cos_dist = cosine_distance(image_embeds, self.concept_embeds)
|
||||
|
||||
# increase this value to create a stronger `nsfw` filter
|
||||
# at the cost of increasing the possibility of filtering benign images
|
||||
adjustment = 0.0
|
||||
|
||||
special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
|
||||
# special_scores = special_scores.round(decimals=3)
|
||||
special_care = torch.any(special_scores > 0, dim=1)
|
||||
special_adjustment = special_care * 0.01
|
||||
special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
|
||||
|
||||
concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
|
||||
# concept_scores = concept_scores.round(decimals=3)
|
||||
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
|
||||
|
||||
return images, has_nsfw_concepts
|
||||
@@ -1,11 +1,16 @@
|
||||
from ...utils import is_torch_available, is_transformers_available
|
||||
from ...utils import is_torch_available, is_transformers_available, is_transformers_version
|
||||
|
||||
|
||||
if is_transformers_available() and is_torch_available():
|
||||
from .modeling_gpt2_optimus import GPT2OptimusForLatentConnector
|
||||
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
|
||||
from .modeling_text_unet import UNetFlatConditionModel
|
||||
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
|
||||
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
|
||||
from .pipeline_versatile_diffusion_image_to_text import VersatileDiffusionImageToTextPipeline
|
||||
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
|
||||
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
|
||||
else:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
)
|
||||
|
||||
@@ -1,345 +0,0 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2PreTrainedModel
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
|
||||
class GPT2OptimusAttention(nn.Module):
|
||||
def __init__(self, nx, n_ctx, config, scale=False):
|
||||
super().__init__()
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||
assert n_state % config.n_head == 0
|
||||
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||
self.n_head = config.n_head
|
||||
self.split_size = n_state
|
||||
self.scale = scale
|
||||
|
||||
self.c_attn = Conv1D(n_state * 3, nx)
|
||||
self.c_proj = Conv1D(n_state, nx)
|
||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
self.pruned_heads = set()
|
||||
|
||||
def _attn(self, q, k, v, attention_mask=None, head_mask=None):
|
||||
w = torch.matmul(q, k)
|
||||
if self.scale:
|
||||
w = w / math.sqrt(v.size(-1))
|
||||
nd, ns = w.size(-2), w.size(-1)
|
||||
b = self.bias[:, :, ns - nd : ns, :ns]
|
||||
w = w * b - 1e4 * (1 - b)
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
w = w + attention_mask
|
||||
|
||||
w = nn.Softmax(dim=-1)(w)
|
||||
w = self.attn_dropout(w)
|
||||
|
||||
# Mask heads if we want to
|
||||
if head_mask is not None:
|
||||
w = w * head_mask
|
||||
|
||||
outputs = [torch.matmul(w, v)]
|
||||
if self.output_attentions:
|
||||
outputs.append(w)
|
||||
return outputs
|
||||
|
||||
def merge_heads(self, x):
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
||||
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
|
||||
|
||||
def split_heads(self, x, k=False):
|
||||
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
|
||||
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
|
||||
if k:
|
||||
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
|
||||
else:
|
||||
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
||||
|
||||
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
|
||||
x = self.c_attn(x)
|
||||
query, key, value = x.split(self.split_size, dim=2)
|
||||
query = self.split_heads(query)
|
||||
key = self.split_heads(key, k=True)
|
||||
value = self.split_heads(value)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past[0], layer_past[1] # transpose back cf below
|
||||
|
||||
past_key = self.split_heads(past_key, k=True)
|
||||
past_value = self.split_heads(past_value)
|
||||
# pdb.set_trace()
|
||||
key = torch.cat((past_key, key), dim=-1)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
||||
|
||||
attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
|
||||
a = attn_outputs[0]
|
||||
|
||||
a = self.merge_heads(a)
|
||||
a = self.c_proj(a)
|
||||
a = self.resid_dropout(a)
|
||||
|
||||
outputs = [a, present] + attn_outputs[1:]
|
||||
return outputs # a, present, (attentions)
|
||||
|
||||
|
||||
class GPT2OptimusBlock(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
nx = config.n_embd
|
||||
self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPT2OptimusAttention(nx, config.n_ctx, config, scale=True)
|
||||
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPT2MLP(4 * nx, config)
|
||||
|
||||
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
|
||||
output_attn = self.attn(
|
||||
self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
|
||||
)
|
||||
a = output_attn[0] # output_attn: a, present, (attentions)
|
||||
|
||||
x = x + a
|
||||
m = self.mlp(self.ln_2(x))
|
||||
x = x + m
|
||||
|
||||
outputs = [x] + output_attn[1:]
|
||||
return outputs # x, present, (attentions)
|
||||
|
||||
|
||||
class GPT2OptimusModel(GPT2PreTrainedModel):
|
||||
def __init__(self, config, latent_as_gpt_emb, latent_as_gpt_memory, latent_size):
|
||||
super().__init__(config)
|
||||
self.latent_as_gpt_emb = latent_as_gpt_emb
|
||||
self.latent_as_gpt_memory = latent_as_gpt_memory
|
||||
self.latent_size = latent_size
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
||||
self.drop = nn.Dropout(config.embd_pdrop)
|
||||
self.h = nn.ModuleList([GPT2OptimusBlock(config) for i in range(config.n_layer)])
|
||||
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.linear = nn.Linear(
|
||||
self.latent_size, config.hidden_size * config.n_layer, bias=False
|
||||
) # different latent vector for each layer
|
||||
self.linear_emb = nn.Linear(
|
||||
self.latent_size, config.hidden_size, bias=False
|
||||
) # share the same latent vector as the embeddings
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
past=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
):
|
||||
if past is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
else:
|
||||
if self.latent_as_gpt_emb:
|
||||
past_emb = self.linear_emb(past) # used as embeddings to add on other three embeddings
|
||||
|
||||
if self.latent_as_gpt_memory:
|
||||
past = self.linear(past)
|
||||
|
||||
# different latent vectors for each layer
|
||||
past_split = torch.split(past.unsqueeze(1), self.config.hidden_size, dim=2)
|
||||
past = list(zip(past_split, past_split))
|
||||
past_length = 1 # past[0][0].size(-2)
|
||||
else:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(
|
||||
past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = (
|
||||
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
||||
) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(
|
||||
dtype=next(self.parameters()).dtype
|
||||
) # switch to fload if need + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.n_layer
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1))
|
||||
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
position_embeds = self.wpe(position_ids)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
|
||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||
if self.latent_as_gpt_emb:
|
||||
hidden_states = hidden_states + past_emb.unsqueeze(1)
|
||||
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
presents = ()
|
||||
all_attentions = []
|
||||
all_hidden_states = ()
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past)):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
|
||||
|
||||
outputs = block(
|
||||
hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
|
||||
)
|
||||
|
||||
hidden_states, present = outputs[:2]
|
||||
presents = presents + (present,)
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions.append(outputs[2])
|
||||
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(*output_shape)
|
||||
# Add last hidden state
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states, presents)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
|
||||
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
|
||||
outputs = outputs + (all_attentions,)
|
||||
|
||||
return outputs # last hidden state, presents, (all hidden_states), (attentions)
|
||||
|
||||
|
||||
class GPT2OptimusForLatentConnector(GPT2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.latent_as_gpt_emb = True
|
||||
self.latent_as_gpt_memory = True
|
||||
self.latent_size = getattr(config, "latent_size", 32)
|
||||
self.transformer = GPT2OptimusModel(
|
||||
config,
|
||||
latent_as_gpt_emb=self.latent_as_gpt_emb,
|
||||
latent_as_gpt_memory=self.latent_as_gpt_memory,
|
||||
latent_size=self.latent_size,
|
||||
)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||
self.init_weights()
|
||||
self.tie_weights()
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
self.tie_weights()
|
||||
|
||||
def _tie_or_clone_weights(self, first_module, second_module):
|
||||
"""Tie or clone module weights depending of weither we are using TorchScript or not"""
|
||||
if self.config.torchscript:
|
||||
first_module.weight = nn.Parameter(second_module.weight.clone())
|
||||
else:
|
||||
first_module.weight = second_module.weight
|
||||
|
||||
if hasattr(first_module, "bias") and first_module.bias is not None:
|
||||
first_module.bias.data = torch.nn.functional.pad(
|
||||
first_module.bias.data,
|
||||
(0, first_module.weight.shape[0] - first_module.bias.shape[0]),
|
||||
"constant",
|
||||
0,
|
||||
)
|
||||
|
||||
def tie_weights(self):
|
||||
"""Make sure we are sharing the input and output embeddings.
|
||||
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
|
||||
"""
|
||||
self._tie_or_clone_weights(self.lm_head, self.transformer.wte)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=True,
|
||||
):
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=None,
|
||||
logits=lm_logits,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=None,
|
||||
attentions=None,
|
||||
cross_attentions=None,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past,
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...modeling_utils import ModelMixin
|
||||
from ...models.attention import Transformer2DModel
|
||||
from ...models.attention import DualTransformer2DModel, Transformer2DModel
|
||||
from ...models.embeddings import TimestepEmbedding, Timesteps
|
||||
from ...models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import logging
|
||||
@@ -28,6 +28,9 @@ def get_down_block(
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
downsample_padding=None,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
):
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlockFlat":
|
||||
@@ -57,6 +60,9 @@ def get_down_block(
|
||||
downsample_padding=downsample_padding,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
)
|
||||
raise ValueError(f"{down_block_type} is not supported.")
|
||||
|
||||
@@ -74,6 +80,9 @@ def get_up_block(
|
||||
attn_num_head_channels,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
):
|
||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||
if up_block_type == "UpBlockFlat":
|
||||
@@ -103,6 +112,9 @@ def get_up_block(
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} is not supported.")
|
||||
|
||||
@@ -117,11 +129,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
implements for all the models (such as downloading or saving, etc.)
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optional*): The size of the input sample.
|
||||
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
||||
Height and width of input/output sample.
|
||||
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
|
||||
@@ -163,6 +176,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
"CrossAttnUpBlockFlat",
|
||||
"CrossAttnUpBlockFlat",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
@@ -171,7 +185,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_num_groups: int = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
attention_head_dim: int = 8,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -187,10 +204,20 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
@@ -209,8 +236,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
attn_num_head_channels=attention_head_dim[i],
|
||||
downsample_padding=downsample_padding,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -223,8 +253,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
# count how many layers upsample the images
|
||||
@@ -232,6 +264,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
@@ -259,7 +293,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
attn_num_head_channels=reversed_attention_head_dim[i],
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
@@ -270,15 +307,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
|
||||
head_dims = self.config.attention_head_dim
|
||||
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
||||
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
||||
raise ValueError(
|
||||
f"Make sure slice_size {slice_size} is a divisor of "
|
||||
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
|
||||
f"Make sure slice_size {slice_size} is a common divisor of "
|
||||
f"the number of heads used in cross_attention: {head_dims}"
|
||||
)
|
||||
if slice_size is not None and slice_size > self.config.attention_head_dim:
|
||||
if slice_size is not None and slice_size > min(head_dims):
|
||||
raise ValueError(
|
||||
f"Chunk_size {slice_size} has to be smaller or equal to "
|
||||
f"the number of heads used in cross_attention {self.config.attention_head_dim}"
|
||||
f"slice_size {slice_size} has to be smaller or equal to "
|
||||
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
||||
)
|
||||
|
||||
for block in self.down_blocks:
|
||||
@@ -311,14 +350,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
encoder_hidden_states (`torch.FloatTensor`):
|
||||
(batch_size, sequence_length, hidden_size) encoder hidden states
|
||||
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
@@ -364,6 +403,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.config.num_class_embeds is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
@@ -632,6 +677,9 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -656,16 +704,30 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
@@ -683,15 +745,17 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
||||
head_dims = self.attn_num_head_channels
|
||||
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
||||
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
||||
raise ValueError(
|
||||
f"Make sure slice_size {slice_size} is a divisor of "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"Make sure slice_size {slice_size} is a common divisor of "
|
||||
f"the number of heads used in cross_attention: {head_dims}"
|
||||
)
|
||||
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
||||
if slice_size is not None and slice_size > min(head_dims):
|
||||
raise ValueError(
|
||||
f"Chunk_size {slice_size} has to be smaller or equal to "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"slice_size {slice_size} has to be smaller or equal to "
|
||||
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
||||
)
|
||||
|
||||
for attn in self.attentions:
|
||||
@@ -830,6 +894,9 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -856,16 +923,30 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
@@ -877,15 +958,17 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
||||
head_dims = self.attn_num_head_channels
|
||||
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
||||
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
||||
raise ValueError(
|
||||
f"Make sure slice_size {slice_size} is a divisor of "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"Make sure slice_size {slice_size} is a common divisor of "
|
||||
f"the number of heads used in cross_attention: {head_dims}"
|
||||
)
|
||||
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
||||
if slice_size is not None and slice_size > min(head_dims):
|
||||
raise ValueError(
|
||||
f"Chunk_size {slice_size} has to be smaller or equal to "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"slice_size {slice_size} has to be smaller or equal to "
|
||||
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
||||
)
|
||||
|
||||
for attn in self.attentions:
|
||||
@@ -954,7 +1037,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
**kwargs,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -980,16 +1064,29 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlockFlat(
|
||||
in_channels=in_channels,
|
||||
@@ -1009,15 +1106,17 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
|
||||
head_dims = self.attn_num_head_channels
|
||||
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
||||
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
||||
raise ValueError(
|
||||
f"Make sure slice_size {slice_size} is a divisor of "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"Make sure slice_size {slice_size} is a common divisor of "
|
||||
f"the number of heads used in cross_attention: {head_dims}"
|
||||
)
|
||||
if slice_size is not None and slice_size > self.attn_num_head_channels:
|
||||
if slice_size is not None and slice_size > min(head_dims):
|
||||
raise ValueError(
|
||||
f"Chunk_size {slice_size} has to be smaller or equal to "
|
||||
f"the number of heads used in cross_attention {self.attn_num_head_channels}"
|
||||
f"slice_size {slice_size} has to be smaller or equal to "
|
||||
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
||||
)
|
||||
|
||||
for attn in self.attentions:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -9,6 +10,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import logging
|
||||
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
|
||||
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
|
||||
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
|
||||
|
||||
@@ -76,10 +78,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
@property
|
||||
def components(self) -> Dict[str, Any]:
|
||||
return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
@@ -113,8 +112,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
def image_variation(
|
||||
self,
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -127,7 +126,88 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
):
|
||||
return VersatileDiffusionImageVariationPipeline(**self.components)(
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
|
||||
The image prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import VersatileDiffusionPipeline
|
||||
>>> import torch
|
||||
>>> import requests
|
||||
>>> from io import BytesIO
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> # let's download an initial image
|
||||
>>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
|
||||
|
||||
>>> response = requests.get(url)
|
||||
>>> image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
>>> pipe = VersatileDiffusionPipeline.from_pretrained(
|
||||
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
>>> image = pipe.image_variation(image, generator=generator).images[0]
|
||||
>>> image.save("./car_variation.png")
|
||||
```
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
expected_components = inspect.signature(VersatileDiffusionImageVariationPipeline.__init__).parameters.keys()
|
||||
components = {name: component for name, component in self.components.items() if name in expected_components}
|
||||
return VersatileDiffusionImageVariationPipeline(**components)(
|
||||
image=image,
|
||||
height=height,
|
||||
width=width,
|
||||
@@ -148,8 +228,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
def text_to_image(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -162,7 +242,80 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
):
|
||||
return VersatileDiffusionTextToImagePipeline(**self.components)(
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import VersatileDiffusionPipeline
|
||||
>>> import torch
|
||||
|
||||
>>> pipe = VersatileDiffusionPipeline.from_pretrained(
|
||||
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
>>> image = pipe.text_to_image("an astronaut riding on a horse on mars", generator=generator).images[0]
|
||||
>>> image.save("./astronaut.png")
|
||||
```
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
expected_components = inspect.signature(VersatileDiffusionTextToImagePipeline.__init__).parameters.keys()
|
||||
components = {name: component for name, component in self.components.items() if name in expected_components}
|
||||
temp_pipeline = VersatileDiffusionTextToImagePipeline(**components)
|
||||
output = temp_pipeline(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
@@ -178,3 +331,133 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
)
|
||||
# swap the attention blocks back to the original state
|
||||
temp_pipeline._swap_unet_attention_blocks()
|
||||
|
||||
return output
|
||||
|
||||
@torch.no_grad()
|
||||
def dual_guided(
|
||||
self,
|
||||
prompt: Union[PIL.Image.Image, List[PIL.Image.Image]],
|
||||
image: Union[str, List[str]],
|
||||
text_to_image_strength: float = 0.5,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import VersatileDiffusionPipeline
|
||||
>>> import torch
|
||||
>>> import requests
|
||||
>>> from io import BytesIO
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> # let's download an initial image
|
||||
>>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
|
||||
|
||||
>>> response = requests.get(url)
|
||||
>>> image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
>>> text = "a red car in the sun"
|
||||
|
||||
>>> pipe = VersatileDiffusionPipeline.from_pretrained(
|
||||
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
>>> text_to_image_strength = 0.75
|
||||
|
||||
>>> image = pipe.dual_guided(
|
||||
... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
|
||||
... ).images[0]
|
||||
>>> image.save("./car_variation.png")
|
||||
```
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
expected_components = inspect.signature(VersatileDiffusionDualGuidedPipeline.__init__).parameters.keys()
|
||||
components = {name: component for name, component in self.components.items() if name in expected_components}
|
||||
temp_pipeline = VersatileDiffusionDualGuidedPipeline(**components)
|
||||
output = temp_pipeline(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
text_to_image_strength=text_to_image_strength,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
eta=eta,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
output_type=output_type,
|
||||
return_dict=return_dict,
|
||||
callback=callback,
|
||||
callback_steps=callback_steps,
|
||||
)
|
||||
temp_pipeline._revert_dual_attention()
|
||||
|
||||
return output
|
||||
|
||||
+111
-72
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -65,6 +65,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
vae: AutoencoderKL
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
|
||||
|
||||
_optional_components = ["text_unet"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
@@ -87,8 +89,22 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
def convert_to_dual_attention(self, mix_ratio=0.5, condition_types=("image", "text")):
|
||||
if self.text_unet is not None and (
|
||||
"dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention
|
||||
):
|
||||
# if loading from a universal checkpoint rather than a saved dual-guided pipeline
|
||||
self._convert_to_dual_attention()
|
||||
|
||||
def remove_unused_weights(self):
|
||||
self.register_modules(text_unet=None)
|
||||
|
||||
def _convert_to_dual_attention(self):
|
||||
"""
|
||||
Replace image_unet's `Transformer2DModel` blocks with `DualTransformer2DModel` that contains transformer blocks
|
||||
from both `image_unet` and `text_unet`
|
||||
"""
|
||||
for name, module in self.image_unet.named_modules():
|
||||
if isinstance(module, Transformer2DModel):
|
||||
parent_name, index = name.rsplit(".", 1)
|
||||
@@ -112,22 +128,25 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
activation_fn=config.activation_fn,
|
||||
num_embeds_ada_norm=config.num_embeds_ada_norm,
|
||||
)
|
||||
for i, type in enumerate(condition_types):
|
||||
if type == "image":
|
||||
dual_transformer.transformers[i] = image_transformer
|
||||
else:
|
||||
dual_transformer.transformers[i] = text_transformer
|
||||
dual_transformer.transformers[0] = image_transformer
|
||||
dual_transformer.transformers[1] = text_transformer
|
||||
|
||||
dual_transformer.mix_ratio = mix_ratio
|
||||
self.image_unet.get_submodule(parent_name)[index] = dual_transformer
|
||||
self.image_unet.register_to_config(dual_cross_attention=True)
|
||||
|
||||
def remove_dual_attention(self):
|
||||
def _revert_dual_attention(self):
|
||||
"""
|
||||
Revert the image_unet `DualTransformer2DModel` blocks back to `Transformer2DModel` with image_unet weights Call
|
||||
this function if you reuse `image_unet` in another pipeline, e.g. `VersatileDiffusionPipeline`
|
||||
"""
|
||||
for name, module in self.image_unet.named_modules():
|
||||
if isinstance(module, DualTransformer2DModel):
|
||||
parent_name, index = name.rsplit(".", 1)
|
||||
index = int(index)
|
||||
self.image_unet.get_submodule(parent_name)[index] = module.transformers[0]
|
||||
|
||||
self.image_unet.register_to_config(dual_cross_attention=False)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
@@ -163,9 +182,14 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.image_unet.config.attention_head_dim // 2
|
||||
if isinstance(self.image_unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.image_unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.image_unet.config.attention_head_dim)
|
||||
|
||||
self.image_unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
@@ -330,7 +354,8 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
|
||||
# get prompt text embeddings
|
||||
image_input = self.image_feature_extractor(images=prompt, return_tensors="pt")
|
||||
image_embeddings = self.image_encoder(image_input.pixel_values.to(device))
|
||||
pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype)
|
||||
image_embeddings = self.image_encoder(pixel_values)
|
||||
image_embeddings = normalize_embeddings(image_embeddings)
|
||||
|
||||
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
||||
@@ -340,9 +365,10 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_images = [np.zeros((512, 512, 3))] * batch_size
|
||||
uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
|
||||
uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
|
||||
uncond_embeddings = self.image_encoder(uncond_images.pixel_values.to(device))
|
||||
pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
|
||||
uncond_embeddings = self.image_encoder(pixel_values)
|
||||
uncond_embeddings = normalize_embeddings(uncond_embeddings)
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
@@ -384,23 +410,11 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(self, first_prompt, second_prompt, height, width, callback_steps):
|
||||
if (
|
||||
not isinstance(first_prompt, str)
|
||||
and not isinstance(first_prompt, PIL.Image.Image)
|
||||
and not isinstance(first_prompt, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`first_prompt` has to be of type `str` `PIL.Image` or `list` but is {type(first_prompt)}"
|
||||
)
|
||||
if (
|
||||
not isinstance(second_prompt, str)
|
||||
and not isinstance(second_prompt, PIL.Image.Image)
|
||||
and not isinstance(second_prompt, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`second_prompt` has to be of type `str` `PIL.Image` or `list` but is {type(second_prompt)}"
|
||||
)
|
||||
def check_inputs(self, prompt, image, height, width, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, PIL.Image.Image) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` `PIL.Image` or `list` but is {type(prompt)}")
|
||||
if not isinstance(image, str) and not isinstance(image, PIL.Image.Image) and not isinstance(image, list):
|
||||
raise ValueError(f"`image` has to be of type `str` `PIL.Image` or `list` but is {type(image)}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
@@ -415,7 +429,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // 8, width // 8)
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
@@ -431,19 +445,27 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def set_mix_ratio(self, mix_ratio):
|
||||
def set_transformer_params(self, mix_ratio: float = 0.5, condition_types: Tuple = ("text", "image")):
|
||||
for name, module in self.image_unet.named_modules():
|
||||
if isinstance(module, DualTransformer2DModel):
|
||||
module.mix_ratio = mix_ratio
|
||||
|
||||
for i, type in enumerate(condition_types):
|
||||
if type == "text":
|
||||
module.condition_lengths[i] = self.text_encoder.config.max_position_embeddings
|
||||
module.transformer_index_for_condition[i] = 1 # use the second (text) transformer
|
||||
else:
|
||||
module.condition_lengths[i] = 257
|
||||
module.transformer_index_for_condition[i] = 0 # use the first (image) transformer
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
first_prompt: Union[str, List[str], PIL.Image.Image, List[PIL.Image.Image]],
|
||||
second_prompt: Union[str, List[str], PIL.Image.Image, List[PIL.Image.Image]],
|
||||
prompt_mix_ratio: float = 0.5,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
prompt: Union[PIL.Image.Image, List[PIL.Image.Image]],
|
||||
image: Union[str, List[str]],
|
||||
text_to_image_strength: float = 0.5,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
@@ -462,9 +484,9 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
@@ -503,21 +525,53 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import VersatileDiffusionDualGuidedPipeline
|
||||
>>> import torch
|
||||
>>> import requests
|
||||
>>> from io import BytesIO
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> # let's download an initial image
|
||||
>>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
|
||||
|
||||
>>> response = requests.get(url)
|
||||
>>> image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
>>> text = "a red car in the sun"
|
||||
|
||||
>>> pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(
|
||||
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.remove_unused_weights()
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
>>> text_to_image_strength = 0.75
|
||||
|
||||
>>> image = pipe(
|
||||
... prompt=text, image=image, text_to_image_strength=text_to_image_strength, generator=generator
|
||||
... ).images[0]
|
||||
>>> image.save("./car_variation.png")
|
||||
```
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
[`~pipelines.stable_diffusion.ImagePipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple. When
|
||||
returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.image_unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.image_unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(first_prompt, second_prompt, height, width, callback_steps)
|
||||
self.check_inputs(prompt, image, height, width, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
first_prompt = [first_prompt] if not isinstance(first_prompt, list) else first_prompt
|
||||
second_prompt = [second_prompt] if not isinstance(second_prompt, list) else second_prompt
|
||||
batch_size = len(first_prompt)
|
||||
prompt = [prompt] if not isinstance(prompt, list) else prompt
|
||||
image = [image] if not isinstance(image, list) else image
|
||||
batch_size = len(prompt)
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@@ -525,21 +579,10 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompts
|
||||
dual_prompt_embeddings = []
|
||||
prompt_types = []
|
||||
for prompt in [first_prompt, second_prompt]:
|
||||
if isinstance(prompt[0], str):
|
||||
embeddings = self._encode_text_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance
|
||||
)
|
||||
prompt_types.append("text")
|
||||
else:
|
||||
embeddings = self._encode_image_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance
|
||||
)
|
||||
prompt_types.append("image")
|
||||
dual_prompt_embeddings.append(embeddings)
|
||||
dual_prompt_embeddings = torch.cat(dual_prompt_embeddings, dim=1)
|
||||
text_embeddings = self._encode_text_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance)
|
||||
image_embeddings = self._encode_image_prompt(image, device, num_images_per_prompt, do_classifier_free_guidance)
|
||||
dual_prompt_embeddings = torch.cat([text_embeddings, image_embeddings], dim=1)
|
||||
prompt_types = ("text", "image")
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -562,8 +605,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Combine the attention blocks of the image and text UNets
|
||||
self.convert_to_dual_attention(prompt_mix_ratio, prompt_types)
|
||||
self.set_mix_ratio(prompt_mix_ratio)
|
||||
self.set_transformer_params(text_to_image_strength, prompt_types)
|
||||
|
||||
# 8. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
@@ -586,13 +628,10 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 9. Return the image unet to its original state
|
||||
self.remove_dual_attention()
|
||||
|
||||
# 10. Post-processing
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 11. Convert to PIL
|
||||
# 10. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
|
||||
+47
-11
@@ -71,6 +71,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
@@ -107,9 +108,14 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.image_unet.config.attention_head_dim // 2
|
||||
if isinstance(self.image_unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.image_unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.image_unet.config.attention_head_dim)
|
||||
|
||||
self.image_unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
@@ -186,7 +192,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
|
||||
# get prompt text embeddings
|
||||
image_input = self.image_feature_extractor(images=prompt, return_tensors="pt")
|
||||
image_embeddings = self.image_encoder(image_input.pixel_values.to(device))
|
||||
pixel_values = image_input.pixel_values.to(device).to(self.image_encoder.dtype)
|
||||
image_embeddings = self.image_encoder(pixel_values)
|
||||
image_embeddings = normalize_embeddings(image_embeddings)
|
||||
|
||||
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
||||
@@ -198,7 +205,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_images: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_images = [np.zeros((512, 512, 3))] * batch_size
|
||||
uncond_images = [np.zeros((512, 512, 3)) + 0.5] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -216,7 +223,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
uncond_images = negative_prompt
|
||||
|
||||
uncond_images = self.image_feature_extractor(images=uncond_images, return_tensors="pt")
|
||||
uncond_embeddings = self.image_encoder(uncond_images.pixel_values.to(device))
|
||||
pixel_values = uncond_images.pixel_values.to(device).to(self.image_encoder.dtype)
|
||||
uncond_embeddings = self.image_encoder(pixel_values)
|
||||
uncond_embeddings = normalize_embeddings(uncond_embeddings)
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
@@ -275,7 +283,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // 8, width // 8)
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
@@ -295,8 +303,8 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -316,9 +324,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `List[PIL.Image.Image]` or `torch.Tensor`):
|
||||
The image prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
@@ -357,6 +365,31 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import VersatileDiffusionImageVariationPipeline
|
||||
>>> import torch
|
||||
>>> import requests
|
||||
>>> from io import BytesIO
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> # let's download an initial image
|
||||
>>> url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"
|
||||
|
||||
>>> response = requests.get(url)
|
||||
>>> image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
>>> pipe = VersatileDiffusionImageVariationPipeline.from_pretrained(
|
||||
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
>>> image = pipe(image, generator=generator).images[0]
|
||||
>>> image.save("./car_variation.png")
|
||||
```
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
@@ -364,6 +397,9 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.image_unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.image_unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(image, height, width, callback_steps)
|
||||
|
||||
+49
-18
@@ -57,6 +57,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
vae: AutoencoderKL
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
|
||||
|
||||
_optional_components = ["text_unet"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
@@ -75,8 +77,15 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
def swap_unet_attention_blocks(self):
|
||||
if self.text_unet is not None:
|
||||
self._swap_unet_attention_blocks()
|
||||
|
||||
def _swap_unet_attention_blocks(self):
|
||||
"""
|
||||
Swap the `Transformer2DModel` blocks between the image and text UNets
|
||||
"""
|
||||
for name, module in self.image_unet.named_modules():
|
||||
if isinstance(module, Transformer2DModel):
|
||||
parent_name, index = name.rsplit(".", 1)
|
||||
@@ -86,6 +95,9 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
self.image_unet.get_submodule(parent_name)[index],
|
||||
)
|
||||
|
||||
def remove_unused_weights(self):
|
||||
self.register_modules(text_unet=None)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
@@ -121,9 +133,14 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.image_unet.config.attention_head_dim // 2
|
||||
if isinstance(self.image_unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.image_unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.image_unet.config.attention_head_dim)
|
||||
|
||||
self.image_unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
@@ -328,7 +345,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // 8, width // 8)
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
@@ -348,8 +365,8 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -369,9 +386,9 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
@@ -410,6 +427,23 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import VersatileDiffusionTextToImagePipeline
|
||||
>>> import torch
|
||||
|
||||
>>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(
|
||||
... "shi-labs/versatile-diffusion", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.remove_unused_weights()
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
>>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0]
|
||||
>>> image.save("./astronaut.png")
|
||||
```
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
@@ -417,6 +451,9 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.image_unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.image_unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, callback_steps)
|
||||
@@ -454,10 +491,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
# 6. Prepare extra step kwargs.
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Swap the attention blocks between the image and text UNets
|
||||
self.swap_unet_attention_blocks()
|
||||
|
||||
# 8. Denoising loop
|
||||
# 7. Denoising loop
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
@@ -478,13 +512,10 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 9. Swap the attention blocks backs in case the UNets are reused in another pipeline
|
||||
self.swap_unet_attention_blocks()
|
||||
|
||||
# 10. Post-processing
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 11. Convert to PIL
|
||||
# 10. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@@ -106,10 +106,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||
stable diffusion.
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
|
||||
`v-prediction` is not supported for this scheduler.
|
||||
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -122,7 +126,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -258,7 +272,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
elif self.config.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||
# predict V
|
||||
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction`"
|
||||
)
|
||||
|
||||
# 4. Clip "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
|
||||
@@ -23,6 +23,7 @@ import flax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils_flax import (
|
||||
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
|
||||
FlaxSchedulerMixin,
|
||||
@@ -108,9 +109,14 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||
stable diffusion.
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
|
||||
`v-prediction` is not supported for this scheduler.
|
||||
|
||||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
@@ -125,7 +131,17 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
beta_schedule: str = "linear",
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if beta_schedule == "linear":
|
||||
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
@@ -259,7 +275,19 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
elif self.config.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||
# predict V
|
||||
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction`"
|
||||
)
|
||||
|
||||
# 4. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
|
||||
@@ -99,12 +99,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
predict_epsilon (`bool`):
|
||||
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
|
||||
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
|
||||
`v-prediction` is not supported for this scheduler.
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -116,8 +117,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
variance_type: str = "fixed_small",
|
||||
clip_sample: bool = True,
|
||||
predict_epsilon: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -241,13 +251,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
|
||||
if predict_epsilon is not None:
|
||||
new_config = dict(self.config)
|
||||
new_config["predict_epsilon"] = predict_epsilon
|
||||
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||
self._internal_dict = FrozenDict(new_config)
|
||||
|
||||
t = timestep
|
||||
@@ -265,10 +275,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# 2. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
|
||||
" for the DDPMScheduler."
|
||||
)
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
|
||||
@@ -103,12 +103,13 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
predict_epsilon (`bool`):
|
||||
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
|
||||
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
|
||||
`v-prediction` is not supported for this scheduler.
|
||||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
@@ -124,8 +125,17 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
trained_betas: Optional[jnp.ndarray] = None,
|
||||
variance_type: str = "fixed_small",
|
||||
clip_sample: bool = True,
|
||||
predict_epsilon: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = jnp.asarray(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -204,7 +214,6 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
key: random.KeyArray,
|
||||
predict_epsilon: bool = True,
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
|
||||
@@ -227,13 +236,13 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" FlaxDDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
|
||||
if predict_epsilon is not None:
|
||||
new_config = dict(self.config)
|
||||
new_config["predict_epsilon"] = predict_epsilon
|
||||
new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
|
||||
self._internal_dict = FrozenDict(new_config)
|
||||
|
||||
t = timestep
|
||||
@@ -251,10 +260,15 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
|
||||
# 2. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` "
|
||||
" for the FlaxDDPMScheduler."
|
||||
)
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
|
||||
@@ -21,7 +21,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
|
||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
|
||||
|
||||
@@ -87,10 +87,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
solver_order (`int`, default `2`):
|
||||
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
|
||||
sampling, and `solver_order=3` for unconditional sampling.
|
||||
predict_epsilon (`bool`, default `True`):
|
||||
we currently support both the noise prediction model and the data prediction model. If the model predicts
|
||||
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
|
||||
`predict_epsilon` to `False`.
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
|
||||
or `v-prediction`.
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
|
||||
@@ -118,6 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -128,14 +128,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
solver_order: int = 2,
|
||||
predict_epsilon: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" DPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -203,7 +212,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
|
||||
|
||||
DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to
|
||||
DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
|
||||
discretize an integral of the data prediction model. So we need to first convert the model output to the
|
||||
corresponding type to match the algorithm.
|
||||
|
||||
@@ -221,11 +230,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction` for the DPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = torch.quantile(
|
||||
@@ -239,12 +257,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
return model_output
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction` for the DPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
def dpm_solver_first_order_update(
|
||||
self,
|
||||
|
||||
@@ -23,6 +23,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import deprecate
|
||||
from .scheduling_utils_flax import (
|
||||
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
|
||||
FlaxSchedulerMixin,
|
||||
@@ -118,10 +119,9 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
solver_order (`int`, default `2`):
|
||||
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
|
||||
sampling, and `solver_order=3` for unconditional sampling.
|
||||
predict_epsilon (`bool`, default `True`):
|
||||
we currently support both the noise prediction model and the data prediction model. If the model predicts
|
||||
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
|
||||
`predict_epsilon` to `False`.
|
||||
prediction_type (`str`, default `epsilon`):
|
||||
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
|
||||
or `v-prediction`.
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
|
||||
@@ -149,6 +149,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
|
||||
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||
_deprecated_kwargs = ["predict_epsilon"]
|
||||
|
||||
@property
|
||||
def has_state(self):
|
||||
@@ -163,14 +164,23 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[jnp.ndarray] = None,
|
||||
solver_order: int = 2,
|
||||
predict_epsilon: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
|
||||
" FlaxDPMSolverMultistepScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
if predict_epsilon is not None:
|
||||
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = jnp.asarray(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -242,7 +252,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
|
||||
|
||||
DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to
|
||||
DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
|
||||
discretize an integral of the data prediction model. So we need to first convert the model output to the
|
||||
corresponding type to match the algorithm.
|
||||
|
||||
@@ -260,11 +270,20 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||
if self.config.algorithm_type == "dpmsolver++":
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
x0_pred = model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
|
||||
" or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = jnp.percentile(
|
||||
@@ -277,12 +296,21 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
if self.config.predict_epsilon:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
return model_output
|
||||
else:
|
||||
elif self.config.prediction_type == "sample":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = (sample - alpha_t * model_output) / sigma_t
|
||||
return epsilon
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
epsilon = alpha_t * model_output + sigma_t * sample
|
||||
return epsilon
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
|
||||
" or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
|
||||
)
|
||||
|
||||
def dpm_solver_first_order_update(
|
||||
self, model_output: jnp.ndarray, timestep: int, prev_timestep: int, sample: jnp.ndarray
|
||||
|
||||
@@ -189,7 +189,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
||||
"See `StableDiffusionPipeline` for a usage example."
|
||||
)
|
||||
|
||||
@@ -78,6 +78,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
@@ -198,7 +199,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
||||
"See `StableDiffusionPipeline` for a usage example."
|
||||
)
|
||||
@@ -229,7 +230,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
# * c_out + input * c_skip
|
||||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
)
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
|
||||
@@ -33,6 +33,7 @@ from .import_utils import (
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
is_unidecode_available,
|
||||
requires_backends,
|
||||
)
|
||||
|
||||
@@ -32,7 +32,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn
|
||||
|
||||
if warning is not None:
|
||||
warning = warning + " " if standard_warn else ""
|
||||
warnings.warn(warning + message, DeprecationWarning)
|
||||
warnings.warn(warning + message, FutureWarning)
|
||||
|
||||
if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
|
||||
call_frame = inspect.getouterframes(inspect.currentframe())[1]
|
||||
|
||||
@@ -64,6 +64,21 @@ class LDMTextToImagePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionImageVariationPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -124,6 +139,51 @@ class StableDiffusionPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionPipelineSafe(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionUpscalePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class VersatileDiffusionImageVariationPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -303,6 +303,17 @@ def requires_backends(obj, backends):
|
||||
if failed:
|
||||
raise ImportError("".join(failed))
|
||||
|
||||
if name in [
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
"VersatileDiffusionPipeline",
|
||||
"VersatileDiffusionDualGuidedPipeline",
|
||||
"StableDiffusionImageVariationPipeline",
|
||||
] and is_transformers_version("<", "4.25.0.dev0"):
|
||||
raise ImportError(
|
||||
f"You need to install `transformers` from 'main' in order to use {name}: \n```\n pip install"
|
||||
" git+https://github.com/huggingface/transformers \n```"
|
||||
)
|
||||
|
||||
|
||||
class DummyObject(type):
|
||||
"""
|
||||
@@ -347,3 +358,17 @@ def is_torch_version(operation: str, version: str):
|
||||
A string version of PyTorch
|
||||
"""
|
||||
return compare_versions(parse(_torch_version), operation, version)
|
||||
|
||||
|
||||
def is_transformers_version(operation: str, version: str):
|
||||
"""
|
||||
Args:
|
||||
Compares the current Transformers version to a given reference with an operation.
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A string version of PyTorch
|
||||
"""
|
||||
if not _transformers_available:
|
||||
return False
|
||||
return compare_versions(parse(_transformers_version), operation, version)
|
||||
|
||||
@@ -296,6 +296,44 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
for name, param in named_params.items():
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
|
||||
|
||||
def test_model_with_attention_head_dim_tuple(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_use_linear_projection(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["use_linear_projection"] = True
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
|
||||
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
|
||||
@@ -171,9 +171,9 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[0.49249017, 0.46064827, 0.4790093, 0.50883967, 0.4811985, 0.51540506, 0.5084924, 0.4860553, 0.47318557]
|
||||
[0.5748162, 0.60447145, 0.48821217, 0.50100636, 0.5431185, 0.45763683, 0.49657696, 0.48132733, 0.47573093]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
@@ -220,9 +220,9 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[0.4786532, 0.45791715, 0.47507674, 0.50763345, 0.48375353, 0.515062, 0.51244247, 0.48673993, 0.47105807]
|
||||
[0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
@@ -259,7 +259,7 @@ class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -68,7 +68,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_inference_predict_epsilon(self):
|
||||
def test_inference_deprecated_predict_epsilon(self):
|
||||
deprecate("remove this test", "0.10.0", "remove")
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDPMScheduler(predict_epsilon=False)
|
||||
@@ -98,6 +98,35 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
||||
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
|
||||
|
||||
def test_inference_predict_sample(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDPMScheduler(prediction_type="sample")
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
_ = ddpm(num_inference_steps=1)
|
||||
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
generator = generator.manual_seed(0)
|
||||
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_eps_slice = image_eps[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
||||
assert np.abs(image_slice.flatten() - image_eps_slice.flatten()).max() < tolerance
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -111,8 +111,8 @@ class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
|
||||
assert image.shape == (1, 16, 16, 3)
|
||||
expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@@ -87,6 +87,27 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_inference_superresolution_fp16(self):
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDIMScheduler()
|
||||
vqvae = self.dummy_vq_model
|
||||
|
||||
# put models in fp16
|
||||
unet = unet.half()
|
||||
vqvae = vqvae.half()
|
||||
|
||||
ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
|
||||
ldm.to(torch_device)
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
init_image = self.dummy_image.to(torch_device)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
@@ -209,8 +209,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.5643956661224365,
|
||||
0.6017904281616211,
|
||||
0.4799129366874695,
|
||||
0.5267305374145508,
|
||||
0.5584856271743774,
|
||||
0.46413588523864746,
|
||||
0.5159522294998169,
|
||||
0.4963662028312683,
|
||||
0.47919973731040955,
|
||||
]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
@@ -250,8 +262,8 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
height=536,
|
||||
width=536,
|
||||
height=136,
|
||||
width=136,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
)
|
||||
@@ -259,8 +271,8 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 134, 134, 3)
|
||||
expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557])
|
||||
assert image.shape == (1, 136, 136, 3)
|
||||
expected_slice = np.array([0.5524, 0.5626, 0.6069, 0.4727, 0.386, 0.3995, 0.4613, 0.4328, 0.4269])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -304,8 +316,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738])
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.5094760060310364,
|
||||
0.5674174427986145,
|
||||
0.46675148606300354,
|
||||
0.5125715136528015,
|
||||
0.5696930289268494,
|
||||
0.4674668312072754,
|
||||
0.5277683734893799,
|
||||
0.4964486062526703,
|
||||
0.494540274143219,
|
||||
]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -370,8 +394,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.47082293033599854,
|
||||
0.5371589064598083,
|
||||
0.4562119245529175,
|
||||
0.5220914483070374,
|
||||
0.5733777284622192,
|
||||
0.4795039892196655,
|
||||
0.5465868711471558,
|
||||
0.5074326395988464,
|
||||
0.5042197108268738,
|
||||
]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -415,8 +451,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.4707113206386566,
|
||||
0.5372191071510315,
|
||||
0.4563021957874298,
|
||||
0.5220003724098206,
|
||||
0.5734264850616455,
|
||||
0.4794946610927582,
|
||||
0.5463782548904419,
|
||||
0.5074145197868347,
|
||||
0.504422664642334,
|
||||
]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -460,8 +508,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.47082313895225525,
|
||||
0.5371587872505188,
|
||||
0.4562119245529175,
|
||||
0.5220913887023926,
|
||||
0.5733776688575745,
|
||||
0.47950395941734314,
|
||||
0.546586811542511,
|
||||
0.5074326992034912,
|
||||
0.5042197108268738,
|
||||
]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -533,8 +593,20 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.4851, 0.4617, 0.4765, 0.5127, 0.4845, 0.5153, 0.5141, 0.4886, 0.4719])
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.5108221173286438,
|
||||
0.5688379406929016,
|
||||
0.4685141146183014,
|
||||
0.5098261833190918,
|
||||
0.5657756328582764,
|
||||
0.4631010890007019,
|
||||
0.5226285457611084,
|
||||
0.49129390716552734,
|
||||
0.4899061322212219,
|
||||
]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_num_images_per_prompt(self):
|
||||
@@ -563,13 +635,13 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
# test num_images_per_prompt=1 (default)
|
||||
images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
|
||||
|
||||
assert images.shape == (1, 128, 128, 3)
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt=1 (default) for batch of prompts
|
||||
batch_size = 2
|
||||
images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images
|
||||
|
||||
assert images.shape == (batch_size, 128, 128, 3)
|
||||
assert images.shape == (batch_size, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for single prompt
|
||||
num_images_per_prompt = 2
|
||||
@@ -577,7 +649,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
|
||||
).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 128, 128, 3)
|
||||
assert images.shape == (num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
@@ -585,7 +657,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
[prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
|
||||
).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 128, 128, 3)
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_fp16(self):
|
||||
@@ -618,7 +690,7 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
def test_stable_diffusion_long_prompt(self):
|
||||
unet = self.dummy_cond_unet
|
||||
@@ -671,6 +743,43 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
assert cap_logger.out.count("@") == 25
|
||||
assert cap_logger_3.out == ""
|
||||
|
||||
def test_stable_diffusion_height_width_opt(self):
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "hey"
|
||||
|
||||
output = sd_pipe(prompt, number_of_steps=1, output_type="np")
|
||||
image_shape = output.images[0].shape[:2]
|
||||
assert image_shape == (64, 64)
|
||||
|
||||
output = sd_pipe(prompt, number_of_steps=1, height=96, width=96, output_type="np")
|
||||
image_shape = output.images[0].shape[:2]
|
||||
assert image_shape == (96, 96)
|
||||
|
||||
config = dict(sd_pipe.unet.config)
|
||||
config["sample_size"] = 96
|
||||
sd_pipe.unet = UNet2DConditionModel.from_config(config).to(torch_device)
|
||||
output = sd_pipe(prompt, number_of_steps=1, output_type="np")
|
||||
image_shape = output.images[0].shape[:2]
|
||||
assert image_shape == (192, 192)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@@ -839,7 +948,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[1.8285, 1.2857, -0.1024, 1.2406, -2.3068, 1.0747, -0.0818, -0.6520, -2.9506]
|
||||
)
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
|
||||
elif step == 50:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
|
||||
@@ -0,0 +1,423 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def dummy_image(self):
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
|
||||
return image
|
||||
|
||||
@property
|
||||
def dummy_cond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vae(self):
|
||||
torch.manual_seed(0)
|
||||
model = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_image_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPVisionConfig(
|
||||
hidden_size=32,
|
||||
projection_dim=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
image_size=32,
|
||||
patch_size=4,
|
||||
)
|
||||
return CLIPVisionModelWithProjection(config)
|
||||
|
||||
@property
|
||||
def dummy_extractor(self):
|
||||
def extract(*args, **kwargs):
|
||||
class Out:
|
||||
def __init__(self):
|
||||
self.pixel_values = torch.ones([0])
|
||||
|
||||
def to(self, device):
|
||||
self.pixel_values.to(device)
|
||||
return self
|
||||
|
||||
return Out()
|
||||
|
||||
return extract
|
||||
|
||||
def test_stable_diffusion_img_variation_default_case(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
image_encoder = self.dummy_image_encoder
|
||||
|
||||
init_image = self.dummy_image.to(device)
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionImageVariationPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
image_encoder=image_encoder,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
init_image,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
init_image,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.5093, 0.5717, 0.4806, 0.4891, 0.5552, 0.4594, 0.5177, 0.4894, 0.4904])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_img_variation_multiple_images(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
image_encoder = self.dummy_image_encoder
|
||||
|
||||
init_image = self.dummy_image.to(device).repeat(2, 1, 1, 1)
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionImageVariationPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
image_encoder=image_encoder,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
init_image,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
image = output.images
|
||||
|
||||
image_slice = image[-1, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (2, 64, 64, 3)
|
||||
expected_slice = np.array([0.6427, 0.5452, 0.5602, 0.5478, 0.5968, 0.6211, 0.5538, 0.5514, 0.5281])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_img_variation_num_images_per_prompt(self):
|
||||
device = "cpu"
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
image_encoder = self.dummy_image_encoder
|
||||
|
||||
init_image = self.dummy_image.to(device)
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionImageVariationPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
image_encoder=image_encoder,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# test num_images_per_prompt=1 (default)
|
||||
images = sd_pipe(
|
||||
init_image,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt=1 (default) for batch of images
|
||||
batch_size = 2
|
||||
images = sd_pipe(
|
||||
init_image.repeat(batch_size, 1, 1, 1),
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
assert images.shape == (batch_size, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for single prompt
|
||||
num_images_per_prompt = 2
|
||||
images = sd_pipe(
|
||||
init_image,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
images = sd_pipe(
|
||||
init_image.repeat(batch_size, 1, 1, 1),
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_img_variation_fp16(self):
|
||||
"""Test that stable diffusion img2img works with fp16"""
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
image_encoder = self.dummy_image_encoder
|
||||
|
||||
init_image = self.dummy_image.to(torch_device).float()
|
||||
|
||||
# put models in fp16
|
||||
unet = unet.half()
|
||||
vae = vae.half()
|
||||
image_encoder = image_encoder.half()
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionImageVariationPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
image_encoder=image_encoder,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe(
|
||||
init_image,
|
||||
generator=generator,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_diffusion_img_variation_pipeline_default(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/vermeer.jpg"
|
||||
)
|
||||
init_image = init_image.resize((512, 512))
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/vermeer.npy"
|
||||
)
|
||||
|
||||
model_id = "fusing/sd-image-variations-diffusers"
|
||||
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
|
||||
model_id,
|
||||
safety_checker=None,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(
|
||||
init_image,
|
||||
guidance_scale=7.5,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
# img2img is flaky across GPUs even in fp32, so using MAE here
|
||||
assert np.abs(expected_image - image).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_img_variation_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
|
||||
test_callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 0:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([1.83, 1.293, -0.09705, 1.256, -2.293, 1.091, -0.0809, -0.65, -2.953])
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
elif step == 37:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([2.285, 2.703, 1.969, 0.696, -1.323, 0.9253, -0.5464, -1.521, -2.537])
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((512, 512))
|
||||
|
||||
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
|
||||
"fusing/sd-image-variations-diffusers",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
pipe(
|
||||
init_image,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=7.5,
|
||||
generator=generator,
|
||||
callback=test_callback_fn,
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 51
|
||||
|
||||
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((512, 512))
|
||||
|
||||
model_id = "fusing/sd-image-variations-diffusers"
|
||||
lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = StableDiffusionImageVariationPipeline.from_pretrained(
|
||||
model_id, scheduler=lms, safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing(1)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
_ = pipe(
|
||||
init_image,
|
||||
guidance_scale=7.5,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
num_inference_steps=5,
|
||||
)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 2.6 GB is allocated
|
||||
assert mem_bytes < 2.6 * 10**9
|
||||
@@ -167,8 +167,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipeline(
|
||||
@@ -212,8 +212,9 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.5075, 0.4485, 0.4558, 0.5369, 0.5369, 0.5236, 0.5127, 0.4983, 0.4776])
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.4723, 0.5731, 0.3939, 0.5441, 0.5922, 0.4392, 0.5059, 0.4651, 0.4474])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -226,8 +227,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipeline(
|
||||
@@ -268,8 +269,8 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((128, 128))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
|
||||
|
||||
# put models in fp16
|
||||
unet = unet.half()
|
||||
@@ -300,7 +301,7 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
mask_image=mask_image,
|
||||
).images
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -168,7 +168,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
@@ -227,7 +227,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
@@ -273,7 +273,7 @@ class StableDiffusionInpaintLegacyPipelineFastTests(PipelineTesterMixin, unittes
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
|
||||
@@ -0,0 +1,733 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
logging,
|
||||
)
|
||||
from diffusers.utils import load_numpy, slow, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def dummy_cond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
# SD2-specific config below
|
||||
attention_head_dim=(2, 4, 8, 8),
|
||||
use_linear_projection=True,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vae(self):
|
||||
torch.manual_seed(0)
|
||||
model = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
sample_size=128,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
# SD2-specific config below
|
||||
hidden_act="gelu",
|
||||
projection_dim=512,
|
||||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
def test_save_pretrained_from_pretrained(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
image = output.images
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
sd_pipe.save_pretrained(tmpdirname)
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
||||
sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = generator.manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
new_image = output.images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
|
||||
|
||||
def test_stable_diffusion_ddim(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.5649, 0.6022, 0.4804, 0.5270, 0.5585, 0.4643, 0.5159, 0.4963, 0.4793])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_pndm(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.5099, 0.5677, 0.4671, 0.5128, 0.5697, 0.4676, 0.5277, 0.4964, 0.4946])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_k_lms(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_k_euler_ancestral(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.4715, 0.5376, 0.4569, 0.5224, 0.5734, 0.4797, 0.5465, 0.5074, 0.5046])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_k_euler(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.4717, 0.5376, 0.4568, 0.5225, 0.5734, 0.4797, 0.5467, 0.5074, 0.5043])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_attention_chunk(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output_1 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
|
||||
# make sure chunking the attention yields the same result
|
||||
sd_pipe.enable_attention_slicing(slice_size=1)
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output_2 = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
|
||||
assert np.abs(output_2.images.flatten() - output_1.images.flatten()).max() < 1e-4
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_fp16(self):
|
||||
"""Test that stable diffusion works with fp16"""
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# put models in fp16
|
||||
unet = unet.half()
|
||||
vae = vae.half()
|
||||
bert = bert.half()
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
def test_stable_diffusion_long_prompt(self):
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
do_classifier_free_guidance = True
|
||||
negative_prompt = None
|
||||
num_images_per_prompt = 1
|
||||
logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")
|
||||
|
||||
prompt = 25 * "@"
|
||||
with CaptureLogger(logger) as cap_logger_3:
|
||||
text_embeddings_3 = sd_pipe._encode_prompt(
|
||||
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
prompt = 100 * "@"
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
text_embeddings = sd_pipe._encode_prompt(
|
||||
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
negative_prompt = "Hello"
|
||||
with CaptureLogger(logger) as cap_logger_2:
|
||||
text_embeddings_2 = sd_pipe._encode_prompt(
|
||||
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
|
||||
assert text_embeddings.shape[1] == 77
|
||||
|
||||
assert cap_logger.out == cap_logger_2.out
|
||||
# 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
|
||||
assert cap_logger.out.count("@") == 25
|
||||
assert cap_logger_3.out == ""
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusion2PipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_diffusion(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np")
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0788, 0.0823, 0.1091, 0.1165, 0.1263, 0.1459, 0.1317, 0.1507, 0.1551])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_ddim(self):
|
||||
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-base", subfolder="scheduler")
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base", scheduler=scheduler)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
|
||||
output = sd_pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy")
|
||||
image = output.images
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0642, 0.0382, 0.0408, 0.0395, 0.0227, 0.0942, 0.0749, 0.0669, 0.0248])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_k_lms(self):
|
||||
scheduler = LMSDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2-base", subfolder="scheduler")
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base", scheduler=scheduler)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "a photograph of an astronaut riding a horse"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe(
|
||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=5, output_type="numpy"
|
||||
).images
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0548, 0.0626, 0.0612, 0.0611, 0.0706, 0.0586, 0.0843, 0.0333, 0.1197])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_attention_slicing(self):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
model_id = "stabilityai/stable-diffusion-2-base"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "a photograph of an astronaut riding a horse"
|
||||
|
||||
# make attention efficient
|
||||
pipe.enable_attention_slicing()
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
output_chunked = pipe(
|
||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||
)
|
||||
image_chunked = output_chunked.images
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
# make sure that less than 3.75 GB is allocated
|
||||
assert mem_bytes < 3.75 * 10**9
|
||||
|
||||
# disable chunking
|
||||
pipe.disable_attention_slicing()
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
output = pipe(
|
||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||
)
|
||||
image = output.images
|
||||
|
||||
# make sure that more than 3.75 GB is allocated
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
assert mem_bytes > 3.75 * 10**9
|
||||
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_same_quality(self):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
model_id = "stabilityai/stable-diffusion-2-base"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.enable_attention_slicing()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "a photograph of an astronaut riding a horse"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output_chunked = pipe(
|
||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||
)
|
||||
image_chunked = output_chunked.images
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id)
|
||||
pipe = pipe.to(torch_device)
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")
|
||||
image = output.images
|
||||
|
||||
# Make sure results are close enough
|
||||
diff = np.abs(image_chunked.flatten() - image.flatten())
|
||||
# They ARE different since ops are not run always at the same precision
|
||||
# however, they should be extremely close.
|
||||
assert diff.mean() < 5e-2
|
||||
|
||||
def test_stable_diffusion_text2img_pipeline_default(self):
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-text2img/astronaut_riding_a_horse.npy"
|
||||
)
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-2-base"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "astronaut riding a horse"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
assert np.abs(expected_image - image).max() < 5e-3
|
||||
|
||||
def test_stable_diffusion_text2img_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
|
||||
test_callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 0:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([1.8606, 1.3169, -0.0691, 1.2374, -2.309, 1.077, -0.1084, -0.6774, -2.9594])
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
|
||||
elif step == 20:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([1.0757, 1.1860, 1.1410, 0.4645, -0.2476, 0.6100, -0.7755, -0.8841, -0.9497])
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-base", revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=20,
|
||||
guidance_scale=7.5,
|
||||
generator=generator,
|
||||
callback=test_callback_fn,
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 21
|
||||
|
||||
def test_stable_diffusion_low_cpu_mem_usage(self):
|
||||
pipeline_id = "stabilityai/stable-diffusion-2-base"
|
||||
|
||||
start_time = time.time()
|
||||
pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline_low_cpu_mem_usage.to(torch_device)
|
||||
low_cpu_mem_usage_time = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, low_cpu_mem_usage=False
|
||||
)
|
||||
normal_load_time = time.time() - start_time
|
||||
|
||||
assert 2 * low_cpu_mem_usage_time < normal_load_time
|
||||
|
||||
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
pipeline_id = "stabilityai/stable-diffusion-2-base"
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipeline = pipeline.to(torch_device)
|
||||
pipeline.enable_attention_slicing(1)
|
||||
pipeline.enable_sequential_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
_ = pipeline(prompt, generator=generator, num_inference_steps=5)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 2.8 GB is allocated
|
||||
assert mem_bytes < 2.8 * 10**9
|
||||
@@ -0,0 +1,345 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def dummy_image(self):
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
|
||||
return image
|
||||
|
||||
@property
|
||||
def dummy_cond_unet_inpaint(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=9,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
# SD2-specific config below
|
||||
attention_head_dim=(2, 4, 8, 8),
|
||||
use_linear_projection=True,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vae(self):
|
||||
torch.manual_seed(0)
|
||||
model = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
# SD2-specific config below
|
||||
hidden_act="gelu",
|
||||
projection_dim=512,
|
||||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
@property
|
||||
def dummy_extractor(self):
|
||||
def extract(*args, **kwargs):
|
||||
class Out:
|
||||
def __init__(self):
|
||||
self.pixel_values = torch.ones([0])
|
||||
|
||||
def to(self, device):
|
||||
self.pixel_values.to(device)
|
||||
return self
|
||||
|
||||
return Out()
|
||||
|
||||
return extract
|
||||
|
||||
def test_stable_diffusion_inpaint(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet_inpaint
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.4727, 0.5735, 0.3941, 0.5446, 0.5926, 0.4394, 0.5062, 0.4654, 0.4476])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_inpaint_fp16(self):
|
||||
"""Test that stable diffusion inpaint works with fp16"""
|
||||
unet = self.dummy_cond_unet_inpaint
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
|
||||
|
||||
# put models in fp16
|
||||
unet = unet.half()
|
||||
vae = vae.half()
|
||||
text_encoder = text_encoder.half()
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
).images
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
|
||||
# @slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_diffusion_inpaint_pipeline(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/sd2-inpaint/init_image.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
|
||||
)
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint"
|
||||
"/yellow_cat_sitting_on_a_park_bench.npy"
|
||||
)
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-2-inpainting"
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
assert np.abs(expected_image - image).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_inpaint_pipeline_fp16(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/sd2-inpaint/init_image.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
|
||||
)
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint"
|
||||
"/yellow_cat_sitting_on_a_park_bench_fp16.npy"
|
||||
)
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-2-inpainting"
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
model_id,
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
safety_checker=None,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
assert np.abs(expected_image - image).max() < 5e-1
|
||||
|
||||
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/sd2-inpaint/init_image.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
|
||||
)
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-2-inpainting"
|
||||
pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
model_id,
|
||||
safety_checker=None,
|
||||
scheduler=pndm,
|
||||
device_map="auto",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing(1)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
_ = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
generator=generator,
|
||||
num_inference_steps=5,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 2.65 GB is allocated
|
||||
assert mem_bytes < 2.65 * 10**9
|
||||
@@ -0,0 +1,315 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionUpscalePipeline, UNet2DConditionModel
|
||||
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class StableDiffusionUpscalePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def dummy_image(self):
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
|
||||
return image
|
||||
|
||||
@property
|
||||
def dummy_cond_unet_upscale(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=7,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
# SD2-specific config below
|
||||
attention_head_dim=8,
|
||||
use_linear_projection=True,
|
||||
only_cross_attention=(True, True, False),
|
||||
num_class_embeds=100,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vae(self):
|
||||
torch.manual_seed(0)
|
||||
model = AutoencoderKL(
|
||||
block_out_channels=[32, 32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
# SD2-specific config below
|
||||
hidden_act="gelu",
|
||||
projection_dim=512,
|
||||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
def test_stable_diffusion_upscale(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet_upscale
|
||||
low_res_scheduler = DDPMScheduler()
|
||||
scheduler = DDIMScheduler(prediction_type="v_prediction")
|
||||
vae = self.dummy_vae
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionUpscalePipeline(
|
||||
unet=unet,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
max_noise_level=350,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
image=low_res_image,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
noise_level=20,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
image=low_res_image,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
noise_level=20,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
expected_height_width = low_res_image.size[0] * 4
|
||||
assert image.shape == (1, expected_height_width, expected_height_width, 3)
|
||||
expected_slice = np.array([0.2562, 0.3606, 0.4204, 0.4469, 0.4822, 0.4647, 0.5315, 0.5748, 0.5606])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_upscale_fp16(self):
|
||||
"""Test that stable diffusion upscale works with fp16"""
|
||||
unet = self.dummy_cond_unet_upscale
|
||||
low_res_scheduler = DDPMScheduler()
|
||||
scheduler = DDIMScheduler(prediction_type="v_prediction")
|
||||
vae = self.dummy_vae
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
|
||||
# put models in fp16, except vae as it overflows in fp16
|
||||
unet = unet.half()
|
||||
text_encoder = text_encoder.half()
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionUpscalePipeline(
|
||||
unet=unet,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
max_noise_level=350,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe(
|
||||
[prompt],
|
||||
image=low_res_image,
|
||||
generator=generator,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
|
||||
expected_height_width = low_res_image.size[0] * 4
|
||||
assert image.shape == (1, expected_height_width, expected_height_width, 3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_diffusion_upscale_pipeline(self):
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/sd2-upscale/low_res_cat.png"
|
||||
)
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale"
|
||||
"/upsampled_cat.npy"
|
||||
)
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "a cat sitting on a park bench"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
assert np.abs(expected_image - image).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_upscale_pipeline_fp16(self):
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/sd2-upscale/low_res_cat.png"
|
||||
)
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale"
|
||||
"/upsampled_cat_fp16.npy"
|
||||
)
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
model_id,
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "a cat sitting on a park bench"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
assert np.abs(expected_image - image).max() < 5e-1
|
||||
|
||||
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/sd2-upscale/low_res_cat.png"
|
||||
)
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
||||
model_id,
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing(1)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
prompt = "a cat sitting on a park bench"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
_ = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
generator=generator,
|
||||
num_inference_steps=5,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 2.65 GB is allocated
|
||||
assert mem_bytes < 2.65 * 10**9
|
||||
@@ -0,0 +1,474 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import load_numpy, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class StableDiffusion2VPredictionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def dummy_cond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
# SD2-specific config below
|
||||
attention_head_dim=(2, 4, 8, 8),
|
||||
use_linear_projection=True,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vae(self):
|
||||
torch.manual_seed(0)
|
||||
model = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
sample_size=128,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
# SD2-specific config below
|
||||
hidden_act="gelu",
|
||||
projection_dim=64,
|
||||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
def test_stable_diffusion_v_pred_ddim(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
prediction_type="v_prediction",
|
||||
)
|
||||
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.6424, 0.6109, 0.494, 0.5088, 0.4984, 0.4525, 0.5059, 0.5068, 0.4474])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_v_pred_k_euler(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="v_prediction"
|
||||
)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.4616, 0.5184, 0.4887, 0.5111, 0.4839, 0.48, 0.5119, 0.5263, 0.4776])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_v_pred_fp16(self):
|
||||
"""Test that stable diffusion v-prediction works with fp16"""
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
prediction_type="v_prediction",
|
||||
)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# put models in fp16
|
||||
unet = unet.half()
|
||||
vae = vae.half()
|
||||
bert = bert.half()
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_stable_diffusion_v_pred_default(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.enable_attention_slicing()
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=20, output_type="np")
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
|
||||
assert image.shape == (1, 768, 768, 3)
|
||||
expected_slice = np.array([0.0567, 0.057, 0.0416, 0.0463, 0.0433, 0.06, 0.0517, 0.0526, 0.0866])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_v_pred_euler(self):
|
||||
scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2", subfolder="scheduler")
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.enable_attention_slicing()
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
|
||||
output = sd_pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy")
|
||||
image = output.images
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
|
||||
assert image.shape == (1, 768, 768, 3)
|
||||
expected_slice = np.array([0.0351, 0.0376, 0.0505, 0.0424, 0.0551, 0.0656, 0.0471, 0.0276, 0.0596])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_v_pred_dpm(self):
|
||||
"""
|
||||
TODO: update this test after making DPM compatible with V-prediction!
|
||||
"""
|
||||
scheduler = DPMSolverMultistepScheduler.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", subfolder="scheduler"
|
||||
)
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2", scheduler=scheduler)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.enable_attention_slicing()
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "a photograph of an astronaut riding a horse"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe(
|
||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=5, output_type="numpy"
|
||||
).images
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
assert image.shape == (1, 768, 768, 3)
|
||||
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_attention_slicing_v_pred(self):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
model_id = "stabilityai/stable-diffusion-2"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "a photograph of an astronaut riding a horse"
|
||||
|
||||
# make attention efficient
|
||||
pipe.enable_attention_slicing()
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
output_chunked = pipe(
|
||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||
)
|
||||
image_chunked = output_chunked.images
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
# make sure that less than 5.5 GB is allocated
|
||||
assert mem_bytes < 5.5 * 10**9
|
||||
|
||||
# disable slicing
|
||||
pipe.disable_attention_slicing()
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
output = pipe(
|
||||
[prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
|
||||
)
|
||||
image = output.images
|
||||
|
||||
# make sure that more than 5.5 GB is allocated
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
assert mem_bytes > 5.5 * 10**9
|
||||
assert np.abs(image_chunked.flatten() - image.flatten()).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_text2img_pipeline_v_pred_default(self):
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
|
||||
"sd2-text2img/astronaut_riding_a_horse_v_pred.npy"
|
||||
)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
|
||||
pipe.to(torch_device)
|
||||
pipe.enable_attention_slicing()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "astronaut riding a horse"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (768, 768, 3)
|
||||
assert np.abs(expected_image - image).max() < 5e-3
|
||||
|
||||
def test_stable_diffusion_text2img_pipeline_v_pred_fp16(self):
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
|
||||
"sd2-text2img/astronaut_riding_a_horse_v_pred_fp16.npy"
|
||||
)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "astronaut riding a horse"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipe(prompt=prompt, guidance_scale=7.5, generator=generator, output_type="np")
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (768, 768, 3)
|
||||
assert np.abs(expected_image - image).max() < 5e-3
|
||||
|
||||
def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
|
||||
test_callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 0:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 96, 96)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[-0.2543, -1.2755, 0.4261, -0.9555, -1.173, -0.5892, 2.4159, 0.1554, -1.2098]
|
||||
)
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-3
|
||||
elif step == 19:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 96, 96)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[-0.9572, -0.967, -0.6152, 0.0894, -0.699, -0.2344, 1.5465, -0.0357, -0.1141]
|
||||
)
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
test_callback_fn.has_been_called = False
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2", revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
with torch.autocast(torch_device):
|
||||
pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=20,
|
||||
guidance_scale=7.5,
|
||||
generator=generator,
|
||||
callback=test_callback_fn,
|
||||
callback_steps=1,
|
||||
)
|
||||
assert test_callback_fn.has_been_called
|
||||
assert number_of_steps == 20
|
||||
|
||||
def test_stable_diffusion_low_cpu_mem_usage_v_pred(self):
|
||||
pipeline_id = "stabilityai/stable-diffusion-2"
|
||||
|
||||
start_time = time.time()
|
||||
pipeline_low_cpu_mem_usage = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline_low_cpu_mem_usage.to(torch_device)
|
||||
low_cpu_mem_usage_time = time.time() - start_time
|
||||
|
||||
start_time = time.time()
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
pipeline_id, revision="fp16", torch_dtype=torch.float16, low_cpu_mem_usage=False
|
||||
)
|
||||
normal_load_time = time.time() - start_time
|
||||
|
||||
assert 2 * low_cpu_mem_usage_time < normal_load_time
|
||||
|
||||
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading_v_pred(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
pipeline_id = "stabilityai/stable-diffusion-2"
|
||||
prompt = "Andromeda galaxy in a bottle"
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipeline = pipeline.to(torch_device)
|
||||
pipeline.enable_attention_slicing(1)
|
||||
pipeline.enable_sequential_cpu_offload()
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
_ = pipeline(prompt, generator=generator, num_inference_steps=5)
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
# make sure that less than 2.8 GB is allocated
|
||||
assert mem_bytes < 2.8 * 10**9
|
||||
@@ -0,0 +1,435 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
|
||||
from diffusers.utils import floats_tensor, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class SafeDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def dummy_image(self):
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
|
||||
return image
|
||||
|
||||
@property
|
||||
def dummy_cond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vae(self):
|
||||
torch.manual_seed(0)
|
||||
model = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
@property
|
||||
def dummy_extractor(self):
|
||||
def extract(*args, **kwargs):
|
||||
class Out:
|
||||
def __init__(self):
|
||||
self.pixel_values = torch.ones([0])
|
||||
|
||||
def to(self, device):
|
||||
self.pixel_values.to(device)
|
||||
return self
|
||||
|
||||
return Out()
|
||||
|
||||
return extract
|
||||
|
||||
def test_safe_diffusion_ddim(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.5644, 0.6018, 0.4799, 0.5267, 0.5585, 0.4641, 0.516, 0.4964, 0.4792])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_pndm(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.5095, 0.5674, 0.4668, 0.5126, 0.5697, 0.4675, 0.5278, 0.4964, 0.4945])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_no_safety_checker(self):
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
|
||||
)
|
||||
assert isinstance(pipe, StableDiffusionPipeline)
|
||||
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
|
||||
assert pipe.safety_checker is None
|
||||
|
||||
image = pipe("example prompt", num_inference_steps=2).images[0]
|
||||
assert image is not None
|
||||
|
||||
# check that there's no error when saving a pipeline with one of the models being None
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
||||
|
||||
# sanity check that the pipeline still works
|
||||
assert pipe.safety_checker is None
|
||||
image = pipe("example prompt", num_inference_steps=2).images[0]
|
||||
assert image is not None
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_fp16(self):
|
||||
"""Test that stable diffusion works with fp16"""
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# put models in fp16
|
||||
unet = unet.half()
|
||||
vae = vae.half()
|
||||
bert = bert.half()
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_harm_safe_stable_diffusion(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
|
||||
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = (
|
||||
"portrait of girl with smokey eyes makeup in abandoned hotel, grange clothes, redshift, wide high angle"
|
||||
" coloured polaroid photograph with flash, kodak film, hyper real, stunning moody cinematography, with"
|
||||
" anamorphic lenses, by maripol, fallen angels by wong kar - wai, style of suspiria and neon demon and"
|
||||
" children from bahnhof zoo, detailed "
|
||||
)
|
||||
seed = 4003660346
|
||||
guidance_scale = 7
|
||||
|
||||
# without safety guidance (sld_guidance_scale = 0)
|
||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=50,
|
||||
output_type="np",
|
||||
width=512,
|
||||
height=512,
|
||||
sld_guidance_scale=0,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
# without safety guidance (strong configuration)
|
||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=50,
|
||||
output_type="np",
|
||||
width=512,
|
||||
height=512,
|
||||
sld_guidance_scale=2000,
|
||||
sld_warmup_steps=7,
|
||||
sld_threshold=0.025,
|
||||
sld_momentum_scale=0.5,
|
||||
sld_mom_beta=0.7,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_nudity_safe_stable_diffusion(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
|
||||
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "padme amidala taking a bath artwork, safe for work, no nudity"
|
||||
seed = 2734971755
|
||||
guidance_scale = 7
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=50,
|
||||
output_type="np",
|
||||
width=512,
|
||||
height=512,
|
||||
sld_guidance_scale=0,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=50,
|
||||
output_type="np",
|
||||
width=512,
|
||||
height=512,
|
||||
sld_guidance_scale=2000,
|
||||
sld_warmup_steps=7,
|
||||
sld_threshold=0.025,
|
||||
sld_momentum_scale=0.5,
|
||||
sld_mom_beta=0.7,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_nudity_safetychecker_safe_stable_diffusion(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = (
|
||||
"the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c."
|
||||
" leyendecker"
|
||||
)
|
||||
seed = 1044355234
|
||||
guidance_scale = 12
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=50,
|
||||
output_type="np",
|
||||
width=512,
|
||||
height=512,
|
||||
sld_guidance_scale=0,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(seed)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=50,
|
||||
output_type="np",
|
||||
width=512,
|
||||
height=512,
|
||||
sld_guidance_scale=2000,
|
||||
sld_warmup_steps=7,
|
||||
sld_threshold=0.025,
|
||||
sld_momentum_scale=0.5,
|
||||
sld_mom_beta=0.7,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561])
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
@@ -42,16 +42,22 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
|
||||
def test_remove_unused_weights_save_load(self):
|
||||
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion")
|
||||
# remove text_unet
|
||||
pipe.remove_unused_weights()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
second_prompt = load_image(
|
||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
||||
)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = pipe(
|
||||
first_prompt="first prompt",
|
||||
second_prompt="second prompt",
|
||||
prompt_mix_ratio=0.75,
|
||||
prompt="first prompt",
|
||||
image=second_prompt,
|
||||
text_to_image_strength=0.75,
|
||||
generator=generator,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=2,
|
||||
@@ -61,14 +67,15 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(tmpdirname)
|
||||
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = generator.manual_seed(0)
|
||||
new_image = pipe(
|
||||
first_prompt="first prompt",
|
||||
second_prompt="second prompt",
|
||||
prompt_mix_ratio=0.75,
|
||||
prompt="first prompt",
|
||||
image=second_prompt,
|
||||
text_to_image_strength=0.75,
|
||||
generator=generator,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=2,
|
||||
@@ -77,8 +84,9 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
|
||||
|
||||
def test_inference_image_variations(self):
|
||||
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("diffusers/vd-official-test")
|
||||
def test_inference_dual_guided(self):
|
||||
pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained("shi-labs/versatile-diffusion")
|
||||
pipe.remove_unused_weights()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -88,9 +96,9 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = pipe(
|
||||
first_prompt=first_prompt,
|
||||
second_prompt=second_prompt,
|
||||
prompt_mix_ratio=0.75,
|
||||
prompt=first_prompt,
|
||||
image=second_prompt,
|
||||
text_to_image_strength=0.75,
|
||||
generator=generator,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=50,
|
||||
@@ -100,5 +108,5 @@ class VersatileDiffusionDualGuidedPipelineIntegrationTests(unittest.TestCase):
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.1811, 0.0430, 0.0433, 0.1082, 0.0144, 0.0306, 0.0683, 0.0248, 0.0876])
|
||||
expected_slice = np.array([0.014, 0.0112, 0.0136, 0.0145, 0.0107, 0.0113, 0.0272, 0.0215, 0.0216])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import VersatileDiffusionImageToTextPipeline, DDIMScheduler
|
||||
from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class VersatileDiffusionImageToTextPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class VersatileDiffusionImageToTextPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_image_to_text(self):
|
||||
pipe = VersatileDiffusionImageToTextPipeline.from_pretrained("diffusers/vd-official-test")
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
image_prompt = load_image(
|
||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/boy_and_girl.jpg"
|
||||
)
|
||||
# generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
np.random.seed(8)
|
||||
torch.manual_seed(108)
|
||||
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
text = pipe(
|
||||
image=image_prompt,
|
||||
# generator=generator,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=50,
|
||||
output_type="str",
|
||||
).text
|
||||
|
||||
assert text == "Corret me"
|
||||
@@ -35,7 +35,7 @@ class VersatileDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, uni
|
||||
@require_torch_gpu
|
||||
class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_image_variations(self):
|
||||
pipe = VersatileDiffusionImageVariationPipeline.from_pretrained("diffusers/vd-official-test")
|
||||
pipe = VersatileDiffusionImageVariationPipeline.from_pretrained("shi-labs/versatile-diffusion")
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -54,5 +54,5 @@ class VersatileDiffusionImageVariationPipelineIntegrationTests(unittest.TestCase
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.1811, 0.0430, 0.0433, 0.1082, 0.0144, 0.0306, 0.0683, 0.0248, 0.0876])
|
||||
expected_slice = np.array([0.1205, 0.1914, 0.2289, 0.0883, 0.1595, 0.1683, 0.0703, 0.1493, 0.1298])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import VersatileDiffusionPipeline
|
||||
from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class VersatileDiffusionMegaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class VersatileDiffusionMegaPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt_image = load_image(
|
||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
||||
)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = pipe.dual_guided(
|
||||
prompt="first prompt",
|
||||
image=prompt_image,
|
||||
text_to_image_strength=0.75,
|
||||
generator=generator,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=2,
|
||||
output_type="numpy",
|
||||
).images
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
pipe = VersatileDiffusionPipeline.from_pretrained(tmpdirname, torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = generator.manual_seed(0)
|
||||
new_image = pipe.dual_guided(
|
||||
prompt="first prompt",
|
||||
image=prompt_image,
|
||||
text_to_image_strength=0.75,
|
||||
generator=generator,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=2,
|
||||
output_type="numpy",
|
||||
).images
|
||||
|
||||
assert np.abs(image - new_image).sum() < 1e-5, "Models don't have the same forward pass"
|
||||
|
||||
def test_inference_dual_guided_then_text_to_image(self):
|
||||
pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "cyberpunk 2077"
|
||||
init_image = load_image(
|
||||
"https://raw.githubusercontent.com/SHI-Labs/Versatile-Diffusion/master/assets/benz.jpg"
|
||||
)
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = pipe.dual_guided(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
text_to_image_strength=0.75,
|
||||
generator=generator,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=50,
|
||||
output_type="numpy",
|
||||
).images
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0081, 0.0032, 0.0002, 0.0056, 0.0027, 0.0000, 0.0051, 0.0020, 0.0007])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger "
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = pipe.text_to_image(
|
||||
prompt=prompt, generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
|
||||
).images
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.0408, 0.0181, 0.0, 0.0388, 0.0046, 0.0461, 0.0411, 0.0, 0.0222])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
image = pipe.image_variation(init_image, generator=generator, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, 253:256, 253:256, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.3403, 0.1809, 0.0938, 0.3855, 0.2393, 0.1243, 0.4028, 0.3110, 0.1799])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user