Compare commits
47 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c013139aee | |||
| d4aaee48fb | |||
| 78507bda24 | |||
| d2a5247a1f | |||
| 309d8cf9ab | |||
| b285d94e10 | |||
| 55660cfb6d | |||
| 46bef6e31d | |||
| 22a31760c4 | |||
| f0b661b8fb | |||
| 8552fd7efa | |||
| e09a7d01c8 | |||
| d3ce6f4b1e | |||
| ff91f154ee | |||
| 62bea2df36 | |||
| 9136be14a7 | |||
| 7004ff55d5 | |||
| ca7ca11bcd | |||
| c7da8fd233 | |||
| b8bfef2ab9 | |||
| f3f626d556 | |||
| b7b4683bdc | |||
| 56958e1177 | |||
| ec021923d2 | |||
| 1598a57958 | |||
| 63805f8af7 | |||
| 9920c333c6 | |||
| f38e3626cd | |||
| 5f826a35fb | |||
| f7278638e4 | |||
| b36cbd4fba | |||
| 2e3541d7f4 | |||
| 2b4f849db9 | |||
| e4c356d3f6 | |||
| 2ea1da89ab | |||
| fa6d52d594 | |||
| a72a057d62 | |||
| 2f489571a7 | |||
| e75eae3711 | |||
| 5e5ce13e2f | |||
| 7f0f7e1e91 | |||
| 3d2648d743 | |||
| 1f4deb697f | |||
| f20c8f5a1a | |||
| 5b6582cf73 | |||
| 4f0141a67d | |||
| 1021929313 |
@@ -13,6 +13,7 @@ jobs:
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: diffusers
|
||||
notebook_folder: diffusers_doc
|
||||
languages: en ko
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
|
||||
@@ -47,3 +47,4 @@ jobs:
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
make deps_table_check_updated
|
||||
|
||||
@@ -31,6 +31,11 @@ jobs:
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: Fast ONNXRuntime CPU tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-onnxruntime-cpu
|
||||
report: onnx_cpu
|
||||
- name: PyTorch Example CPU tests on Ubuntu
|
||||
framework: pytorch_examples
|
||||
runner: docker-cpu
|
||||
|
||||
@@ -29,6 +29,11 @@ jobs:
|
||||
runner: docker-tpu
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
report: flax_tpu
|
||||
- name: Slow ONNXRuntime CUDA tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
report: onnx_cuda
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
|
||||
@@ -29,6 +29,11 @@ jobs:
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: Fast ONNXRuntime CPU tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-onnxruntime-cpu
|
||||
report: onnx_cpu
|
||||
- name: PyTorch Example CPU tests on Ubuntu
|
||||
framework: pytorch_examples
|
||||
runner: docker-cpu
|
||||
|
||||
@@ -467,12 +467,12 @@ image.save("ddpm_generated_image.png")
|
||||
- [Unconditional Diffusion with continuous scheduler](https://huggingface.co/google/ncsnpp-ffhq-1024)
|
||||
|
||||
**Other Image Notebooks**:
|
||||
* [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ,
|
||||
* [tweak images via repeated Stable Diffusion seeds](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) ,
|
||||
* [image-to-image generation with Stable Diffusion](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) ,
|
||||
* [tweak images via repeated Stable Diffusion seeds](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) ,
|
||||
|
||||
**Diffusers for Other Modalities**:
|
||||
* [Molecule conformation generation](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) ,
|
||||
* [Model-based reinforcement learning](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb) ,
|
||||
* [Molecule conformation generation](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/geodiff_molecule_conformation.ipynb) ,
|
||||
* [Model-based reinforcement learning](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb) ,
|
||||
|
||||
### Web Demos
|
||||
If you just want to play around with some web demos, you can try out the following 🚀 Spaces:
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
# docstyle-ignore
|
||||
INSTALL_CONTENT = """
|
||||
# Diffusers installation
|
||||
! pip install diffusers transformers datasets accelerate
|
||||
# To install from source instead of the last release, comment the command above and uncomment the following one.
|
||||
# ! pip install git+https://github.com/huggingface/diffusers.git
|
||||
"""
|
||||
|
||||
notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}]
|
||||
@@ -8,6 +8,10 @@
|
||||
- local: installation
|
||||
title: Installation
|
||||
title: Get started
|
||||
- sections:
|
||||
- local: tutorials/basic_training
|
||||
title: Train a diffusion model
|
||||
title: Tutorials
|
||||
- sections:
|
||||
- sections:
|
||||
- local: using-diffusers/loading
|
||||
@@ -44,6 +48,8 @@
|
||||
title: How to contribute a Pipeline
|
||||
- local: using-diffusers/using_safetensors
|
||||
title: Using safetensors
|
||||
- local: using-diffusers/weighted_prompts
|
||||
title: Weighting Prompts
|
||||
title: Pipelines for Inference
|
||||
- sections:
|
||||
- local: using-diffusers/rl
|
||||
@@ -78,11 +84,11 @@
|
||||
- local: training/text_inversion
|
||||
title: Textual Inversion
|
||||
- local: training/dreambooth
|
||||
title: Dreambooth
|
||||
title: DreamBooth
|
||||
- local: training/text2image
|
||||
title: Text-to-image fine-tuning
|
||||
title: Text-to-image
|
||||
- local: training/lora
|
||||
title: LoRA Support in Diffusers
|
||||
title: Low-Rank Adaptation of Large Language Models (LoRA)
|
||||
title: Training
|
||||
- sections:
|
||||
- local: conceptual/philosophy
|
||||
|
||||
@@ -46,7 +46,7 @@ available a colab notebook to directly try them out.
|
||||
|---|---|:---:|:---:|
|
||||
| [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -
|
||||
| [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation |
|
||||
| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/drive/1AiR7Q-sBqO88NCyswpfiuwXZc7DfMyKA?usp=sharing)
|
||||
| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
|
||||
| [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
|
||||
| [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
|
||||
@@ -33,7 +33,7 @@ Resources:
|
||||
|
||||
| Pipeline | Tasks | Demo
|
||||
|---|---|:---:|
|
||||
| [StableDiffusionControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py) | *Text-to-Image Generation with ControlNet Conditioning* | [Colab Example](https://colab.research.google.com/drive/1AiR7Q-sBqO88NCyswpfiuwXZc7DfMyKA?usp=sharing) |
|
||||
| [StableDiffusionControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py) | *Text-to-Image Generation with ControlNet Conditioning* | [Colab Example](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb)
|
||||
|
||||
## Usage example
|
||||
|
||||
@@ -65,6 +65,12 @@ First, we need to install opencv:
|
||||
pip install opencv-contrib-python
|
||||
```
|
||||
|
||||
Next, let's also install all required Hugging Face libraries:
|
||||
|
||||
```
|
||||
pip install diffusers transformers git+https://github.com/huggingface/accelerate.git
|
||||
```
|
||||
|
||||
Then we can retrieve the canny edges of the image.
|
||||
|
||||
```python
|
||||
@@ -145,10 +151,11 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
|
||||
|[lllyasviel/sd-controlnet-hed](https://huggingface.co/lllyasviel/sd-controlnet-hed)<br/> *Trained with HED edge detection (soft edge)* |A monochrome image with white soft edges on a black background.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_bird_hed.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_bird_hed.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_bird_hed_1.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_bird_hed_1.png"/></a> |
|
||||
|[lllyasviel/sd-controlnet-mlsd](https://huggingface.co/lllyasviel/sd-controlnet-mlsd)<br/> *Trained with M-LSD line detection* |A monochrome image composed only of white straight lines on a black background.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_room_mlsd.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_room_mlsd.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_mlsd_0.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_mlsd_0.png"/></a>|
|
||||
|[lllyasviel/sd-controlnet-normal](https://huggingface.co/lllyasviel/sd-controlnet-normal)<br/> *Trained with normal map* |A [normal mapped](https://en.wikipedia.org/wiki/Normal_mapping) image.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_human_normal.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_human_normal.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_human_normal_1.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_human_normal_1.png"/></a>|
|
||||
|[lllyasviel/sd-controlnet_openpose](https://huggingface.co/lllyasviel/sd-controlnet_openpose)<br/> *Trained with OpenPose bone image* |A [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_human_openpose.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_human_openpose.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_human_openpose_0.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_human_openpose_0.png"/></a>|
|
||||
|[lllyasviel/sd-controlnet_scribble](https://huggingface.co/lllyasviel/sd-controlnet_scribble)<br/> *Trained with human scribbles* |A hand-drawn monochrome image with white outlines on a black background.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_vermeer_scribble.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_vermeer_scribble.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_vermeer_scribble_0.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_vermeer_scribble_0.png"/></a> |
|
||||
|[lllyasviel/sd-controlnet_seg](https://huggingface.co/lllyasviel/sd-controlnet_seg)<br/>*Trained with semantic segmentation* |An [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/)'s segmentation protocol image.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_room_seg.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_room_seg.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_seg_1.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_seg_1.png"/></a> |
|
||||
|[lllyasviel/sd-controlnet-openpose](https://huggingface.co/lllyasviel/sd-controlnet_openpose)<br/> *Trained with OpenPose bone image* |A [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_human_openpose.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_human_openpose.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_human_openpose_0.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_human_openpose_0.png"/></a>|
|
||||
|[lllyasviel/sd-controlnet-scribble](https://huggingface.co/lllyasviel/sd-controlnet_scribble)<br/> *Trained with human scribbles* |A hand-drawn monochrome image with white outlines on a black background.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_vermeer_scribble.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_vermeer_scribble.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_vermeer_scribble_0.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_vermeer_scribble_0.png"/></a> |
|
||||
|[lllyasviel/sd-controlnet-seg](https://huggingface.co/lllyasviel/sd-controlnet_seg)<br/>*Trained with semantic segmentation* |An [ADE20K](https://groups.csail.mit.edu/vision/datasets/ADE20K/)'s segmentation protocol image.|<a href="https://huggingface.co/takuma104/controlnet_dev/blob/main/gen_compare/control_images/converted/control_room_seg.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/control_images/converted/control_room_seg.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_seg_1.png"><img width="64" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare/output_images/diffusers/output_room_seg_1.png"/></a> |
|
||||
|
||||
## StableDiffusionControlNetPipeline
|
||||
[[autodoc]] StableDiffusionControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -36,7 +36,7 @@ available a colab notebook to directly try them out.
|
||||
|---|---|:---:|:---:|
|
||||
| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation |
|
||||
| [audio_diffusion](./api/pipelines/audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation | [](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/audio_diffusion_pipeline.ipynb)
|
||||
| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/drive/1AiR7Q-sBqO88NCyswpfiuwXZc7DfMyKA?usp=sharing)
|
||||
| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [
|
||||
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
|
||||
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
|
||||
@@ -21,13 +21,13 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
## Stable Diffusion Inference
|
||||
|
||||
The snippet below demonstrates how to use the ONNX runtime. You need to use `StableDiffusionOnnxPipeline` instead of `StableDiffusionPipeline`. You also need to download the weights from the `onnx` branch of the repository, and indicate the runtime provider you want to use.
|
||||
The snippet below demonstrates how to use the ONNX runtime. You need to use `OnnxStableDiffusionPipeline` instead of `StableDiffusionPipeline`. You also need to download the weights from the `onnx` branch of the repository, and indicate the runtime provider you want to use.
|
||||
|
||||
```python
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
from diffusers import StableDiffusionOnnxPipeline
|
||||
from diffusers import OnnxStableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionOnnxPipeline.from_pretrained(
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="onnx",
|
||||
provider="CUDAExecutionProvider",
|
||||
@@ -37,6 +37,37 @@ prompt = "a photo of an astronaut riding a horse on mars"
|
||||
image = pipe(prompt).images[0]
|
||||
```
|
||||
|
||||
The snippet below demonstrates how to use the ONNX runtime with the Stable Diffusion upscaling pipeline.
|
||||
|
||||
```python
|
||||
from diffusers import OnnxStableDiffusionPipeline, OnnxStableDiffusionUpscalePipeline
|
||||
|
||||
prompt = "a photo of an astronaut riding a horse on mars"
|
||||
steps = 50
|
||||
|
||||
txt2img = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="onnx",
|
||||
provider="CUDAExecutionProvider",
|
||||
)
|
||||
small_image = txt2img(
|
||||
prompt,
|
||||
num_inference_steps=steps,
|
||||
).images[0]
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
upscale = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
"ssube/stable-diffusion-x4-upscaler-onnx",
|
||||
provider="CUDAExecutionProvider",
|
||||
)
|
||||
large_image = upscale(
|
||||
prompt,
|
||||
small_image,
|
||||
generator=generator,
|
||||
num_inference_steps=steps,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Known Issues
|
||||
|
||||
- Generating multiple prompts in a batch seems to take too much memory. While we look into it, you may need to iterate instead of batching.
|
||||
|
||||
+230
-47
@@ -10,10 +10,25 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Quicktour
|
||||
|
||||
Get up and running with 🧨 Diffusers quickly!
|
||||
Whether you're a developer or an everyday user, this quick tour will help you get started and show you how to use [`DiffusionPipeline`] for inference.
|
||||
Diffusion models are trained to denoise random Gaussian noise step-by-step to generate a sample of interest, such as an image or audio. This has sparked a tremendous amount of interest in generative AI, and you have probably seen examples of diffusion generated images on the internet. 🧨 Diffusers is a library aimed at making diffusion models widely accessible to everyone.
|
||||
|
||||
Whether you're a developer or an everyday user, this quicktour will introduce you to 🧨 Diffusers and help you get up and generating quickly! There are three main components of the library to know about:
|
||||
|
||||
* The [`DiffusionPipeline`] is a high-level end-to-end class designed to rapidly generate samples from pretrained diffusion models for inference.
|
||||
* Popular pretrained [model](./api/models) architectures and modules that can be used as building blocks for creating diffusion systems.
|
||||
* Many different [schedulers](./api/schedulers/overview) - algorithms that control how noise is added for training, and how to generate denoised images during inference.
|
||||
|
||||
The quicktour will show you how to use the [`DiffusionPipeline`] for inference, and then walk you through how to combine a model and scheduler to replicate what's happening inside the [`DiffusionPipeline`].
|
||||
|
||||
<Tip>
|
||||
|
||||
The quicktour is a simplified version of the introductory 🧨 Diffusers [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) to help you get started quickly. If you want to learn more about 🧨 Diffusers goal, design philosophy, and additional details about it's core API, check out the notebook!
|
||||
|
||||
</Tip>
|
||||
|
||||
Before you begin, make sure you have all the necessary libraries installed:
|
||||
|
||||
@@ -21,32 +36,32 @@ Before you begin, make sure you have all the necessary libraries installed:
|
||||
pip install --upgrade diffusers accelerate transformers
|
||||
```
|
||||
|
||||
- [`accelerate`](https://huggingface.co/docs/accelerate/index) speeds up model loading for inference and training
|
||||
- [`transformers`](https://huggingface.co/docs/transformers/index) is required to run the most popular diffusion models, such as [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview)
|
||||
- [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) speeds up model loading for inference and training.
|
||||
- [🤗 Transformers](https://huggingface.co/docs/transformers/index) is required to run the most popular diffusion models, such as [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).
|
||||
|
||||
## DiffusionPipeline
|
||||
|
||||
The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference. You can use the [`DiffusionPipeline`] out-of-the-box for many tasks across different modalities. Take a look at the table below for some supported tasks:
|
||||
The [`DiffusionPipeline`] is the easiest way to use a pretrained diffusion system for inference. It is an end-to-end system containing the model and the scheduler. You can use the [`DiffusionPipeline`] out-of-the-box for many tasks. Take a look at the table below for some supported tasks, and for a complete list of supported tasks, check out the [🧨 Diffusers Summary](./api/pipelines/overview#diffusers-summary) table.
|
||||
|
||||
| **Task** | **Description** | **Pipeline**
|
||||
|------------------------------|--------------------------------------------------------------------------------------------------------------|-----------------|
|
||||
| Unconditional Image Generation | generate an image from gaussian noise | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |
|
||||
| Unconditional Image Generation | generate an image from Gaussian noise | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |
|
||||
| Text-Guided Image Generation | generate an image given a text prompt | [conditional_image_generation](./using-diffusers/conditional_image_generation) |
|
||||
| Text-Guided Image-to-Image Translation | adapt an image guided by a text prompt | [img2img](./using-diffusers/img2img) |
|
||||
| Text-Guided Image-Inpainting | fill the masked part of an image given the image, the mask and a text prompt | [inpaint](./using-diffusers/inpaint) |
|
||||
| Text-Guided Depth-to-Image Translation | adapt parts of an image guided by a text prompt while preserving structure via depth estimation | [depth2img](./using-diffusers/depth2img) |
|
||||
|
||||
For more in-detail information on how diffusion pipelines function for the different tasks, please have a look at the [**Using Diffusers**](./using-diffusers/overview) section.
|
||||
Start by creating an instance of a [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
|
||||
You can use the [`DiffusionPipeline`] for any [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) stored on the Hugging Face Hub.
|
||||
In this quicktour, you'll load the [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) checkpoint for text-to-image generation.
|
||||
|
||||
As an example, start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
|
||||
You can use the [`DiffusionPipeline`] for any [Diffusers' checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads).
|
||||
In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generation with [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion).
|
||||
<Tip warning={true}>
|
||||
|
||||
For [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion), please carefully read its [license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) before running the model.
|
||||
This is due to the improved image generation capabilities of the model and the potentially harmful content that could be produced with it.
|
||||
Please, head over to your stable diffusion model of choice, *e.g.* [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5), and read the license.
|
||||
For [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) models, please carefully read the [license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) first before running the model. 🧨 Diffusers implements a [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) to prevent offensive or harmful content, but the model's improved image generation capabilities can still produce potentially harmful content.
|
||||
|
||||
You can load the model as follows:
|
||||
</Tip>
|
||||
|
||||
Load the model with the [`~DiffusionPipeline.from_pretrained`] method:
|
||||
|
||||
```python
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
@@ -54,77 +69,245 @@ You can load the model as follows:
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
```
|
||||
|
||||
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
|
||||
Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on GPU.
|
||||
You can move the generator object to GPU, just like you would in PyTorch.
|
||||
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. You'll see that the Stable Diffusion pipeline is composed of the [`UNet2DConditionModel`] and [`PNDMScheduler`] among other things:
|
||||
|
||||
```py
|
||||
>>> pipeline
|
||||
StableDiffusionPipeline {
|
||||
"_class_name": "StableDiffusionPipeline",
|
||||
"_diffusers_version": "0.13.1",
|
||||
...,
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"PNDMScheduler"
|
||||
],
|
||||
...,
|
||||
"unet": [
|
||||
"diffusers",
|
||||
"UNet2DConditionModel"
|
||||
],
|
||||
"vae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
We strongly recommend running the pipeline on a GPU because the model consists of roughly 1.4 billion parameters.
|
||||
You can move the generator object to a GPU, just like you would in PyTorch:
|
||||
|
||||
```python
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
|
||||
Now you can use the `pipeline` on your text prompt:
|
||||
Now you can pass a text prompt to the `pipeline` to generate an image, and then access the denoised image. By default, the image output is wrapped in a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
|
||||
|
||||
```python
|
||||
>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
|
||||
>>> image
|
||||
```
|
||||
|
||||
The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png"/>
|
||||
</div>
|
||||
|
||||
You can save the image by simply calling:
|
||||
Save the image by calling `save`:
|
||||
|
||||
```python
|
||||
>>> image.save("image_of_squirrel_painting.png")
|
||||
```
|
||||
|
||||
**Note**: You can also use the pipeline locally by downloading the weights via:
|
||||
### Local pipeline
|
||||
|
||||
You can also use the pipeline locally. The only difference is you need to download the weights first:
|
||||
|
||||
```
|
||||
git lfs install
|
||||
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
|
||||
```
|
||||
|
||||
and then loading the saved weights into the pipeline.
|
||||
Then load the saved weights into the pipeline:
|
||||
|
||||
```python
|
||||
>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
|
||||
```
|
||||
|
||||
Running the pipeline is then identical to the code above as it's the same model architecture.
|
||||
Now you can run the pipeline as you would in the section above.
|
||||
|
||||
```python
|
||||
>>> generator.to("cuda")
|
||||
>>> image = generator("An image of a squirrel in Picasso style").images[0]
|
||||
>>> image.save("image_of_squirrel_painting.png")
|
||||
```
|
||||
### Swapping schedulers
|
||||
|
||||
Diffusion systems can be used with multiple different [schedulers](./api/schedulers/overview) each with their
|
||||
pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
|
||||
use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`] scheduler,
|
||||
you could use it as follows:
|
||||
Different schedulers come with different denoising speeds and quality trade-offs. The best way to find out which one works best for you is to try them out! One of the main features of 🧨 Diffusers is to allow you to easily switch between schedulers. For example, to replace the default [`PNDMScheduler`] with the [`EulerDiscreteScheduler`], load it with the [`~diffusers.ConfigMixin.from_config`] method:
|
||||
|
||||
```python
|
||||
```py
|
||||
>>> from diffusers import EulerDiscreteScheduler
|
||||
|
||||
>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
|
||||
>>> # change scheduler to Euler
|
||||
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
```
|
||||
|
||||
For more in-detail information on how to change between schedulers, please refer to the [Using Schedulers](./using-diffusers/schedulers) guide.
|
||||
Try generating an image with the new scheduler and see if you notice a difference!
|
||||
|
||||
[Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model
|
||||
and can do much more than just generating images from text. We have dedicated a whole documentation page,
|
||||
just for Stable Diffusion [here](./conceptual/stable_diffusion).
|
||||
In the next section, you'll take a closer look at the components - the model and scheduler - that make up the [`DiffusionPipeline`] and learn how to use these components to generate an image of a cat.
|
||||
|
||||
If you want to know how to optimize Stable Diffusion to run on less memory, higher inference speeds, on specific hardware, such as Mac, or with [ONNX Runtime](https://onnxruntime.ai/), please have a look at our
|
||||
optimization pages:
|
||||
## Models
|
||||
|
||||
- [Optimized PyTorch on GPU](./optimization/fp16)
|
||||
- [Mac OS with PyTorch](./optimization/mps)
|
||||
- [ONNX](./optimization/onnx)
|
||||
- [OpenVINO](./optimization/open_vino)
|
||||
Most models take a noisy sample, and at each timestep it predicts the *noise residual* (other models learn to predict the previous sample directly or the velocity or [`v-prediction`](https://github.com/huggingface/diffusers/blob/5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef/src/diffusers/schedulers/scheduling_ddim.py#L110)), the difference between a less noisy image and the input image. You can mix and match models to create other diffusion systems.
|
||||
|
||||
If you want to fine-tune or train your diffusion model, please have a look at the [**training section**](./training/overview)
|
||||
Models are initiated with the [`~ModelMixin.from_pretrained`] method which also locally caches the model weights so it is faster the next time you load the model. For the quicktour, you'll load the [`UNet2DModel`], a basic unconditional image generation model with a checkpoint trained on cat images:
|
||||
|
||||
Finally, please be considerate when distributing generated images publicly 🤗.
|
||||
```py
|
||||
>>> from diffusers import UNet2DModel
|
||||
|
||||
>>> repo_id = "google/ddpm-cat-256"
|
||||
>>> model = UNet2DModel.from_pretrained(repo_id)
|
||||
```
|
||||
|
||||
To access the model parameters, call `model.config`:
|
||||
|
||||
```py
|
||||
>>> model.config
|
||||
```
|
||||
|
||||
The model configuration is a 🧊 frozen 🧊 dictionary, which means those parameters can't be changed after the model is created. This is intentional and ensures that the parameters used to define the model architecture at the start remain the same, while other parameters can still be adjusted during inference.
|
||||
|
||||
Some of the most important parameters are:
|
||||
|
||||
* `sample_size`: the height and width dimension of the input sample.
|
||||
* `in_channels`: the number of input channels of the input sample.
|
||||
* `down_block_types` and `up_block_types`: the type of down- and upsampling blocks used to create the UNet architecture.
|
||||
* `block_out_channels`: the number of output channels of the downsampling blocks; also used in reverse order for the number of input channels of the upsampling blocks.
|
||||
* `layers_per_block`: the number of ResNet blocks present in each UNet block.
|
||||
|
||||
To use the model for inference, create the image shape with random Gaussian noise. It should have a `batch` axis because the model can receive multiple random noises, a `channel` axis corresponding to the number of input channels, and a `sample_size` axis for the height and width of the image:
|
||||
|
||||
```py
|
||||
>>> import torch
|
||||
|
||||
>>> torch.manual_seed(0)
|
||||
|
||||
>>> noisy_sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
|
||||
>>> noisy_sample.shape
|
||||
torch.Size([1, 3, 256, 256])
|
||||
```
|
||||
|
||||
For inference, pass the noisy image to the model and a `timestep`. The `timestep` indicates how noisy the input image is, with more noise at the beginning and less at the end. This helps the model determine its position in the diffusion process, whether it is closer to the start or the end. Use the `sample` method to get the model output:
|
||||
|
||||
```py
|
||||
>>> with torch.no_grad():
|
||||
... noisy_residual = model(sample=noisy_sample, timestep=2).sample
|
||||
```
|
||||
|
||||
To generate actual examples though, you'll need a scheduler to guide the denoising process. In the next section, you'll learn how to couple a model with a scheduler.
|
||||
|
||||
## Schedulers
|
||||
|
||||
Schedulers manage going from a noisy sample to a less noisy sample given the model output - in this case, it is the `noisy_residual`.
|
||||
|
||||
<Tip>
|
||||
|
||||
🧨 Diffusers is a toolbox for building diffusion systems. While the [`DiffusionPipeline`] is a convenient way to get started with a pre-built diffusion system, you can also choose your own the model and scheduler components separately to build a custom diffusion system.
|
||||
|
||||
</Tip>
|
||||
|
||||
For the quicktour, you'll instantiate the [`DDPMScheduler`] with it's [`~diffusers.ConfigMixin.from_config`] method:
|
||||
|
||||
```py
|
||||
>>> from diffusers import DDPMScheduler
|
||||
|
||||
>>> scheduler = DDPMScheduler.from_config(repo_id)
|
||||
>>> scheduler
|
||||
DDPMScheduler {
|
||||
"_class_name": "DDPMScheduler",
|
||||
"_diffusers_version": "0.13.1",
|
||||
"beta_end": 0.02,
|
||||
"beta_schedule": "linear",
|
||||
"beta_start": 0.0001,
|
||||
"clip_sample": true,
|
||||
"clip_sample_range": 1.0,
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "epsilon",
|
||||
"trained_betas": null,
|
||||
"variance_type": "fixed_small"
|
||||
}
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 Notice how the scheduler is instantiated from a configuration. Unlike a model, a scheduler does not have trainable weights and is parameter-free!
|
||||
|
||||
</Tip>
|
||||
|
||||
Some of the most important parameters are:
|
||||
|
||||
* `num_train_timesteps`: the length of the denoising process or in other words, the number of timesteps required to process random Gaussian noise into a data sample.
|
||||
* `beta_schedule`: the type of noise schedule to use for inference and training.
|
||||
* `beta_start` and `beta_end`: the start and end noise values for the noise schedule.
|
||||
|
||||
To predict a slightly less noisy image, pass the following to the scheduler's [`~diffusers.DDPMScheduler.step`] method: model output, `timestep`, and current `sample`.
|
||||
|
||||
```py
|
||||
>>> less_noisy_sample = scheduler.step(model_output=noisy_residual, timestep=2, sample=noisy_sample).prev_sample
|
||||
>>> less_noisy_sample.shape
|
||||
```
|
||||
|
||||
The `less_noisy_sample` can be passed to the next `timestep` where it'll get even less noisier! Let's bring it all together now and visualize the entire denoising process.
|
||||
|
||||
First, create a function that postprocesses and displays the denoised image as a `PIL.Image`:
|
||||
|
||||
```py
|
||||
>>> import PIL.Image
|
||||
>>> import numpy as np
|
||||
|
||||
|
||||
>>> def display_sample(sample, i):
|
||||
... image_processed = sample.cpu().permute(0, 2, 3, 1)
|
||||
... image_processed = (image_processed + 1.0) * 127.5
|
||||
... image_processed = image_processed.numpy().astype(np.uint8)
|
||||
|
||||
... image_pil = PIL.Image.fromarray(image_processed[0])
|
||||
... display(f"Image at step {i}")
|
||||
... display(image_pil)
|
||||
```
|
||||
|
||||
To speed up the denoising process, move the input and model to a GPU:
|
||||
|
||||
```py
|
||||
>>> model.to("cuda")
|
||||
>>> noisy_sample = noisy_sample.to("cuda")
|
||||
```
|
||||
|
||||
Now create a denoising loop that predicts the residual of the less noisy sample, and computes the less noisy sample with the scheduler:
|
||||
|
||||
```py
|
||||
>>> import tqdm
|
||||
|
||||
>>> sample = noisy_sample
|
||||
|
||||
>>> for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
|
||||
... # 1. predict noise residual
|
||||
... with torch.no_grad():
|
||||
... residual = model(sample, t).sample
|
||||
|
||||
... # 2. compute less noisy image and set x_t -> x_t-1
|
||||
... sample = scheduler.step(residual, t, sample).prev_sample
|
||||
|
||||
... # 3. optionally look at image
|
||||
... if (i + 1) % 50 == 0:
|
||||
... display_sample(sample, i + 1)
|
||||
```
|
||||
|
||||
Sit back and watch as a cat is generated from nothing but noise! 😻
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/diffusion-quicktour.png"/>
|
||||
</div>
|
||||
|
||||
## Next steps
|
||||
|
||||
Hopefully you generated some cool images with 🧨 Diffusers in this quicktour! For your next steps, you can:
|
||||
|
||||
* Train or finetune a model to generate your own images in the [training](./tutorials/basic_training) tutorial.
|
||||
* See example official and community [training or finetuning scripts](https://github.com/huggingface/diffusers/tree/main/examples#-diffusers-examples) for a variety of use cases.
|
||||
* Learn more about loading, accessing, changing and comparing schedulers in the [Using different Schedulers](./using-diffusers/schedulers) guide.
|
||||
* Explore prompt engineering, speed and memory optimizations, and tips and tricks for generating higher quality images with the [Stable Diffusion](./stable_diffusion) guide.
|
||||
* Dive deeper into speeding up 🧨 Diffusers with guides on [optimized PyTorch on a GPU](./optimization/fp16), and inference guides for running [Stable Diffusion on Apple Silicon (M1/M2)](./optimization/mps) and [ONNX Runtime](./optimization/onnx).
|
||||
@@ -10,55 +10,67 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# DreamBooth fine-tuning example
|
||||
# DreamBooth
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like stable diffusion given just a few (3~5) images of a subject.
|
||||
[[open-in-colab]]
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3-5) images of a subject. It allows the model to generate contextualized images of the subject in different scenes, poses, and views.
|
||||
|
||||

|
||||
_Dreambooth examples from the [project's blog](https://dreambooth.github.io)._
|
||||
<small>Dreambooth examples from the <a href="https://dreambooth.github.io">project's blog.</a></small>
|
||||
|
||||
The [Dreambooth training script](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) shows how to implement this training procedure on a pre-trained Stable Diffusion model.
|
||||
This guide will show you how to finetune DreamBooth with the [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) model for various GPU sizes, and with Flax. All the training scripts for DreamBooth used in this guide can be found [here](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) if you're interested in digging deeper and seeing how things work.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Dreambooth fine-tuning is very sensitive to hyperparameters and easy to overfit. We recommend you take a look at our [in-depth analysis](https://huggingface.co/blog/dreambooth) with recommended settings for different subjects, and go from there.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Training locally
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies. We also recommend to install `diffusers` from the `main` github branch.
|
||||
Before running the scripts, make sure you install the library's training dependencies. We also recommend installing 🧨 Diffusers from the `main` GitHub branch:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/diffusers
|
||||
pip install -U -r diffusers/examples/dreambooth/requirements.txt
|
||||
```
|
||||
|
||||
xFormers is not part of the training requirements, but [we recommend you install it if you can](../optimization/xformers). It could make your training faster and less memory intensive.
|
||||
xFormers is not part of the training requirements, but we recommend you [install](../optimization/xformers) it if you can because it could make your training faster and less memory intensive.
|
||||
|
||||
After all dependencies have been set up you can configure a [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
After all the dependencies have been set up, initialize a [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
In this example we'll use model version `v1-4`, so please visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4) and carefully read the license before proceeding.
|
||||
To setup a default 🤗 Accelerate environment without choosing any configurations:
|
||||
|
||||
The command below will download and cache the model weights from the Hub because we use the model's Hub id `CompVis/stable-diffusion-v1-4`. You may also clone the repo locally and use the local path in your system where the checkout was saved.
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
### Dog toy example
|
||||
Or if your environment doesn't support an interactive shell like a notebook, you can use:
|
||||
|
||||
In this example we'll use [these images](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) to add a new concept to Stable Diffusion using the Dreambooth process. They will be our training data. Please, download them and place them somewhere in your system.
|
||||
```py
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
Then you can launch the training script using:
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
## Finetuning
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
DreamBooth finetuning is very sensitive to hyperparameters and easy to overfit. We recommend you take a look at our [in-depth analysis](https://huggingface.co/blog/dreambooth) with recommended settings for different subjects to help you choose the appropriate hyperparameters.
|
||||
|
||||
</Tip>
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
Let's try DreamBooth with a [few images of a dog](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ); download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path_to_training_images"
|
||||
export OUTPUT_DIR="path_to_saved_model"
|
||||
```
|
||||
|
||||
Then you can launch the training script (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py)) with the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_dreambooth.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
@@ -72,13 +84,44 @@ accelerate launch train_dreambooth.py \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=400
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
If you have access to TPUs or want to train even faster, you can try out the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_flax.py). The Flax training script doesn't support gradient checkpointing or gradient accumulation, so you'll need a GPU with at least 30GB of memory.
|
||||
|
||||
### Training with a prior-preserving loss
|
||||
Before running the script, make sure you have the requirements installed:
|
||||
|
||||
Prior preservation is used to avoid overfitting and language-drift. Please, refer to the paper to learn more about it if you are interested. For prior preservation, we use other images of the same class as part of the training process. The nice thing is that we can generate those images using the Stable Diffusion model itself! The training script will save the generated images to a local path we specify.
|
||||
```bash
|
||||
pip install -U -r requirements.txt
|
||||
```
|
||||
|
||||
According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior preservation. 200-300 works well for most cases.
|
||||
Now you can launch the training script with the following command:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
python train_dreambooth_flax.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--learning_rate=5e-6 \
|
||||
--max_train_steps=400
|
||||
```
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
## Finetuning with prior-preserving loss
|
||||
|
||||
Prior preservation is used to avoid overfitting and language-drift (check out the [paper](https://arxiv.org/abs/2208.12242) to learn more if you're interested). For prior preservation, you use other images of the same class as part of the training process. The nice thing is that you can generate those images using the Stable Diffusion model itself! The training script will save the generated images to a local path you specify.
|
||||
|
||||
The author's recommend generating `num_epochs * num_samples` images for prior preservation. In most cases, 200-300 images work well.
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path_to_training_images"
|
||||
@@ -102,32 +145,125 @@ accelerate launch train_dreambooth.py \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
### Saving checkpoints while training
|
||||
python train_dreambooth_flax.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--class_data_dir=$CLASS_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--with_prior_preservation --prior_loss_weight=1.0 \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--class_prompt="a photo of dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--learning_rate=5e-6 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
```
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
It's easy to overfit while training with Dreambooth, so sometimes it's useful to save regular checkpoints during the process. One of the intermediate checkpoints might work better than the final model! To use this feature you need to pass the following argument to the training script:
|
||||
## Finetuning the text encoder and UNet
|
||||
|
||||
The script also allows you to finetune the `text_encoder` along with the `unet`. In our experiments (check out the [Training Stable Diffusion with DreamBooth using 🧨 Diffusers](https://huggingface.co/blog/dreambooth) post for more details), this yields much better results, especially when generating images of faces.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Training the text encoder requires additional memory and it won't fit on a 16GB GPU. You'll need at least 24GB VRAM to use this option.
|
||||
|
||||
</Tip>
|
||||
|
||||
Pass the `--train_text_encoder` argument to the training script to enable finetuning the `text_encoder` and `unet`:
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path_to_training_images"
|
||||
export CLASS_DIR="path_to_class_images"
|
||||
export OUTPUT_DIR="path_to_saved_model"
|
||||
|
||||
accelerate launch train_dreambooth.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_text_encoder \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--class_data_dir=$CLASS_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--with_prior_preservation --prior_loss_weight=1.0 \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--class_prompt="a photo of dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--use_8bit_adam
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=2e-6 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
python train_dreambooth_flax.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_text_encoder \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--class_data_dir=$CLASS_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--with_prior_preservation --prior_loss_weight=1.0 \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--class_prompt="a photo of dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--learning_rate=2e-6 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
```
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
## Finetuning with LoRA
|
||||
|
||||
You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique for accelerating training large models, on DreamBooth. For more details, take a look at the [LoRA training](training/lora#dreambooth) guide.
|
||||
|
||||
## Saving checkpoints while training
|
||||
|
||||
It's easy to overfit while training with Dreambooth, so sometimes it's useful to save regular checkpoints during the training process. One of the intermediate checkpoints might actually work better than the final model! Pass the following argument to the training script to enable saving checkpoints:
|
||||
|
||||
```bash
|
||||
--checkpointing_steps=500
|
||||
```
|
||||
|
||||
This will save the full training state in subfolders of your `output_dir`. Subfolder names begin with the prefix `checkpoint-`, and then the number of steps performed so far; for example: `checkpoint-1500` would be a checkpoint saved after 1500 training steps.
|
||||
This saves the full training state in subfolders of your `output_dir`. Subfolder names begin with the prefix `checkpoint-`, followed by the number of steps performed so far; for example, `checkpoint-1500` would be a checkpoint saved after 1500 training steps.
|
||||
|
||||
#### Resuming training from a saved checkpoint
|
||||
### Resume training from a saved checkpoint
|
||||
|
||||
If you want to resume training from any of the saved checkpoints, you can pass the argument `--resume_from_checkpoint` and then indicate the name of the checkpoint you want to use. You can also use the special string `"latest"` to resume from the last checkpoint saved (i.e., the one with the largest number of steps). For example, the following would resume training from the checkpoint saved after 1500 steps:
|
||||
If you want to resume training from any of the saved checkpoints, you can pass the argument `--resume_from_checkpoint` to the script and specify the name of the checkpoint you want to use. You can also use the special string `"latest"` to resume from the last saved checkpoint (the one with the largest number of steps). For example, the following would resume training from the checkpoint saved after 1500 steps:
|
||||
|
||||
```bash
|
||||
--resume_from_checkpoint="checkpoint-1500"
|
||||
```
|
||||
|
||||
This would be a good opportunity to tweak some of your hyperparameters if you wish.
|
||||
This is a good opportunity to tweak some of your hyperparameters if you wish.
|
||||
|
||||
#### Performing inference using a saved checkpoint
|
||||
### Inference from a saved checkpoint
|
||||
|
||||
Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights, but also the state of the optimizer, data loaders and learning rate.
|
||||
Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights, but also the state of the optimizer, data loaders, and learning rate.
|
||||
|
||||
**Note**: If you have installed `"accelerate>=0.16.0"` you can use the following code to run
|
||||
If you have **`"accelerate>=0.16.0"`** installed, use the following code to run
|
||||
inference from an intermediate checkpoint.
|
||||
|
||||
```python
|
||||
@@ -150,7 +286,7 @@ pipeline.to("cuda")
|
||||
pipeline.save_pretrained("dreambooth-pipeline")
|
||||
```
|
||||
|
||||
If you have installed `"accelerate<0.16.0"` you need to first convert it to an inference pipeline. This is how you could do it:
|
||||
If you have **`"accelerate<0.16.0"`** installed, you need to convert it to an inference pipeline first:
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
@@ -179,15 +315,37 @@ pipeline = DiffusionPipeline.from_pretrained(
|
||||
pipeline.save_pretrained("dreambooth-pipeline")
|
||||
```
|
||||
|
||||
### Training on a 16GB GPU
|
||||
## Optimizations for different GPU sizes
|
||||
|
||||
With the help of gradient checkpointing and the 8-bit optimizer from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes), it's possible to train dreambooth on a 16GB GPU.
|
||||
Depending on your hardware, there are a few different ways to optimize DreamBooth on GPUs from 16GB to just 8GB!
|
||||
|
||||
### xFormers
|
||||
|
||||
[xFormers](https://github.com/facebookresearch/xformers) is a toolbox for optimizing Transformers, and it include a [memory-efficient attention](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops) mechanism that is used in 🧨 Diffusers. You'll need to [install xFormers](./optimization/xformers) and then add the following argument to your training script:
|
||||
|
||||
```bash
|
||||
--enable_xformers_memory_efficient_attention
|
||||
```
|
||||
|
||||
xFormers is not available in Flax.
|
||||
|
||||
### Set gradients to none
|
||||
|
||||
Another way you can lower your memory footprint is to [set the gradients](https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html) to `None` instead of zero. However, this may change certain behaviors, so if you run into any issues, try removing this argument. Add the following argument to your training script to set the gradients to `None`:
|
||||
|
||||
```bash
|
||||
--set_grads_to_none
|
||||
```
|
||||
|
||||
### 16GB GPU
|
||||
|
||||
With the help of gradient checkpointing and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) 8-bit optimizer, it's possible to train DreamBooth on a 16GB GPU. Make sure you have bitsandbytes installed:
|
||||
|
||||
```bash
|
||||
pip install bitsandbytes
|
||||
```
|
||||
|
||||
Then pass the `--use_8bit_adam` option to the training script.
|
||||
Then pass the `--use_8bit_adam` option to the training script:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
@@ -214,25 +372,18 @@ accelerate launch train_dreambooth.py \
|
||||
--max_train_steps=800
|
||||
```
|
||||
|
||||
### Fine-tune the text encoder in addition to the UNet
|
||||
### 12GB GPU
|
||||
|
||||
The script also allows to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that this gives much better results, especially on faces. Please, refer to [our blog](https://huggingface.co/blog/dreambooth) for more details.
|
||||
|
||||
To enable this option, pass the `--train_text_encoder` argument to the training script.
|
||||
|
||||
<Tip>
|
||||
Training the text encoder requires additional memory, so training won't fit on a 16GB GPU. You'll need at least 24GB VRAM to use this option.
|
||||
</Tip>
|
||||
To run DreamBooth on a 12GB GPU, you'll need to enable gradient checkpointing, the 8-bit optimizer, xFormers, and set the gradients to `None`:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path_to_training_images"
|
||||
export CLASS_DIR="path_to_class_images"
|
||||
export OUTPUT_DIR="path_to_saved_model"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_dreambooth.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_text_encoder \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--class_data_dir=$CLASS_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
@@ -241,8 +392,10 @@ accelerate launch train_dreambooth.py \
|
||||
--class_prompt="a photo of dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--use_8bit_adam
|
||||
--gradient_checkpointing \
|
||||
--gradient_accumulation_steps=1 --gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--enable_xformers_memory_efficient_attention \
|
||||
--set_grads_to_none \
|
||||
--learning_rate=2e-6 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
@@ -250,19 +403,25 @@ accelerate launch train_dreambooth.py \
|
||||
--max_train_steps=800
|
||||
```
|
||||
|
||||
### Training on a 8 GB GPU:
|
||||
### 8 GB GPU
|
||||
|
||||
Using [DeepSpeed](https://www.deepspeed.ai/) it's even possible to offload some
|
||||
tensors from VRAM to either CPU or NVME, allowing training to proceed with less GPU memory.
|
||||
For 8GB GPUs, you'll need the help of [DeepSpeed](https://www.deepspeed.ai/) to offload some
|
||||
tensors from the VRAM to either the CPU or NVME, enabling training with less GPU memory.
|
||||
|
||||
DeepSpeed needs to be enabled with `accelerate config`. During configuration,
|
||||
answer yes to "Do you want to use DeepSpeed?". Combining DeepSpeed stage 2, fp16
|
||||
mixed precision, and offloading both the model parameters and the optimizer state to CPU, it's
|
||||
possible to train on under 8 GB VRAM. The drawback is that this requires more system RAM (about 25 GB). See [the DeepSpeed documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more configuration options.
|
||||
Run the following command to configure your 🤗 Accelerate environment:
|
||||
|
||||
Changing the default Adam optimizer to DeepSpeed's special version of Adam
|
||||
`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup, but enabling
|
||||
it requires the system's CUDA toolchain version to be the same as the one installed with PyTorch. 8-bit optimizers don't seem to be compatible with DeepSpeed at the moment.
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
During configuration, confirm that you want to use DeepSpeed. Now it's possible to train on under 8GB VRAM by combining DeepSpeed stage 2, fp16 mixed precision, and offloading the model parameters and the optimizer state to the CPU. The drawback is that this requires more system RAM, about 25 GB. See [the DeepSpeed documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more configuration options.
|
||||
|
||||
You should also change the default Adam optimizer to DeepSpeed's optimized version of Adam
|
||||
[`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu) for a substantial speedup. Enabling `DeepSpeedCPUAdam` requires your system's CUDA toolchain version to be the same as the one installed with PyTorch.
|
||||
|
||||
8-bit optimizers don't seem to be compatible with DeepSpeed at the moment.
|
||||
|
||||
Launch training with the following command:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
@@ -292,11 +451,10 @@ accelerate launch train_dreambooth.py \
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a model, inference can be done using the `StableDiffusionPipeline`, by simply indicating the path where the model was saved. Make sure that your prompts include the special `identifier` used during training (`sks` in the previous examples).
|
||||
|
||||
**Note**: If you have installed `"accelerate>=0.16.0"` you can use the following code to run
|
||||
inference from an intermediate checkpoint.
|
||||
Once you have trained a model, specify the path to where the model is saved, and use it for inference in the [`StableDiffusionPipeline`]. Make sure your prompts include the special `identifier` used during training (`sks` in the previous examples).
|
||||
|
||||
If you have **`"accelerate>=0.16.0"`** installed, you can use the following code to run
|
||||
inference from an intermediate checkpoint:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
@@ -311,4 +469,4 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
|
||||
image.save("dog-bucket.png")
|
||||
```
|
||||
|
||||
You may also run inference from [any of the saved training checkpoints](#performing-inference-using-a-saved-checkpoint).
|
||||
You may also run inference from any of the [saved training checkpoints](#inference-from-a-saved-checkpoint).
|
||||
+155
-119
@@ -10,54 +10,151 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LoRA Support in Diffusers
|
||||
# Low-Rank Adaptation of Large Language Models (LoRA)
|
||||
|
||||
Diffusers supports LoRA for faster fine-tuning of Stable Diffusion, allowing greater memory efficiency and easier portability.
|
||||
[[open-in-colab]]
|
||||
|
||||
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in
|
||||
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
|
||||
<Tip warning={true}>
|
||||
|
||||
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition weight matrices (called **update matrices**)
|
||||
to existing weights and **only** training those newly added weights. This has a couple of advantages:
|
||||
|
||||
- Previous pretrained weights are kept frozen so that the model is not so prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
|
||||
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
|
||||
- LoRA matrices are generally added to the attention layers of the original model and they control to which extent the model is adapted toward new training images via a `scale` parameter.
|
||||
|
||||
**__Note that the usage of LoRA is not just limited to attention layers. In the original LoRA work, the authors found out that just amending
|
||||
the attention layers of a language model is sufficient to obtain good downstream performance with great efficiency. This is why, it's common
|
||||
to just add the LoRA weights to the attention layers of a model.__**
|
||||
|
||||
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
|
||||
|
||||
<Tip>
|
||||
|
||||
LoRA allows us to achieve greater memory efficiency since the pretrained weights are kept frozen and only the LoRA weights are trained, thereby
|
||||
allowing us to run fine-tuning on consumer GPUs like Tesla T4, RTX 3080 or even RTX 2080 Ti! One can get access to GPUs like T4 in the free
|
||||
tiers of Kaggle Kernels and Google Colab Notebooks.
|
||||
Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`].
|
||||
|
||||
</Tip>
|
||||
|
||||
## Getting started with LoRA for fine-tuning
|
||||
[Low-Rank Adaptation of Large Language Models (LoRA)](https://arxiv.org/abs/2106.09685) is a training method that accelerates the training of large models while consuming less memory. It adds pairs of rank-decomposition weight matrices (called **update matrices**) to existing weights, and **only** trains those newly added weights. This has a couple of advantages:
|
||||
|
||||
Stable Diffusion can be fine-tuned in different ways:
|
||||
- Previous pretrained weights are kept frozen so the model is not as prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
|
||||
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
|
||||
- LoRA matrices are generally added to the attention layers of the original model. 🧨 Diffusers provides the [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method to load the LoRA weights into a model's attention layers. You can control the extent to which the model is adapted toward new training images via a `scale` parameter.
|
||||
- The greater memory-efficiency allows you to run fine-tuning on consumer GPUs like the Tesla T4, RTX 3080 or even the RTX 2080 Ti! GPUs like the T4 are free and readily accessible in Kaggle or Google Colab notebooks.
|
||||
|
||||
* [Textual inversion](https://huggingface.co/docs/diffusers/main/en/training/text_inversion)
|
||||
* [DreamBooth](https://huggingface.co/docs/diffusers/main/en/training/dreambooth)
|
||||
* [Text2Image fine-tuning](https://huggingface.co/docs/diffusers/main/en/training/text2image)
|
||||
<Tip>
|
||||
|
||||
We provide two end-to-end examples that show how to run fine-tuning with LoRA:
|
||||
💡 LoRA is not only limited to attention layers. The authors found that amending
|
||||
the attention layers of a language model is sufficient to obtain good downstream performance with great efficiency. This is why it's common to just add the LoRA weights to the attention layers of a model. Check out the [Using LoRA for efficient Stable Diffusion fine-tuning](https://huggingface.co/blog/lora) blog for more information about how LoRA works!
|
||||
|
||||
* [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora)
|
||||
* [Text2Image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora)
|
||||
</Tip>
|
||||
|
||||
If you want to perform DreamBooth training with LoRA, for instance, you would run:
|
||||
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. 🧨 Diffusers now supports finetuning with LoRA for [text-to-image generation](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) and [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora). This guide will show you how to do both.
|
||||
|
||||
If you'd like to store or share your model with the community, login to your Hugging Face account (create [one](hf.co/join) if you don't have one already):
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
## Text-to-image
|
||||
|
||||
Finetuning a model like Stable Diffusion, which has billions of parameters, can be slow and difficult. With LoRA, it is much easier and faster to finetune a diffusion model. It can run on hardware with as little as 11GB of GPU RAM without resorting to tricks such as 8-bit optimizers.
|
||||
|
||||
### Training[[text-to-image-training]]
|
||||
|
||||
Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate your own Pokémon.
|
||||
|
||||
To start, make sure you have the `MODEL_NAME` and `DATASET_NAME` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables are optional and specify where to save the model to on the Hub:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="/sddata/finetune/lora/pokemon"
|
||||
export HUB_MODEL_ID="pokemon-lora"
|
||||
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
|
||||
```
|
||||
|
||||
There are some flags to be aware of before you start training:
|
||||
|
||||
* `--push_to_hub` stores the trained LoRA embeddings on the Hub.
|
||||
* `--report_to=wandb` reports and logs the training results to your Weights & Biases dashboard (as an example, take a look at this [report](https://wandb.ai/pcuenq/text2image-fine-tune/runs/b4k1w0tn?workspace=user-pcuenq)).
|
||||
* `--learning_rate=1e-04`, you can afford to use a higher learning rate than you normally would with LoRA.
|
||||
|
||||
Now you're ready to launch the training (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)):
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--dataloader_num_workers=8 \
|
||||
--resolution=512 --center_crop --random_flip \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--max_train_steps=15000 \
|
||||
--learning_rate=1e-04 \
|
||||
--max_grad_norm=1 \
|
||||
--lr_scheduler="cosine" --lr_warmup_steps=0 \
|
||||
--output_dir=${OUTPUT_DIR} \
|
||||
--push_to_hub \
|
||||
--hub_model_id=${HUB_MODEL_ID} \
|
||||
--report_to=wandb \
|
||||
--checkpointing_steps=500 \
|
||||
--validation_prompt="A pokemon with blue eyes." \
|
||||
--seed=1337
|
||||
```
|
||||
|
||||
### Inference[[text-to-image-inference]]
|
||||
|
||||
Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`] and then the [`DPMSolverMultistepScheduler`]:
|
||||
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
>>> model_base = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
>>> pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
|
||||
>>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
```
|
||||
|
||||
Load the LoRA weights from your finetuned model *on top of the base model weights*, and then move the pipeline to a GPU for faster inference. When you merge the LoRA weights with the frozen pretrained model weights, you can optionally adjust how much of the weights to merge with the `scale` parameter:
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 A `scale` value of `0` is the same as not using your LoRA weights and you're only using the base model weights, and a `scale` value of `1` means you're only using the fully finetuned LoRA weights. Values between `0` and `1` interpolates between the two weights.
|
||||
|
||||
</Tip>
|
||||
|
||||
```py
|
||||
>>> pipe.unet.load_attn_procs(model_path)
|
||||
>>> pipe.to("cuda")
|
||||
# use half the weights from the LoRA finetuned model and half the weights from the base model
|
||||
|
||||
>>> image = pipe(
|
||||
... "A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5, cross_attention_kwargs={"scale": 0.5}
|
||||
... ).images[0]
|
||||
# use the weights from the fully finetuned LoRA model
|
||||
|
||||
>>> image = pipe("A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5).images[0]
|
||||
>>> image.save("blue_pokemon.png")
|
||||
```
|
||||
|
||||
## DreamBooth
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type).
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 Take a look at the [Training Stable Diffusion with DreamBooth using 🧨 Diffusers](https://huggingface.co/blog/dreambooth) blog for an in-depth analysis of DreamBooth experiments and recommended settings.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Training[[dreambooth-training]]
|
||||
|
||||
Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) with DreamBooth and LoRA with some 🐶 [dog images](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ). Download and save these images to a directory.
|
||||
|
||||
To start, make sure you have the `MODEL_NAME` and `INSTANCE_DIR` (path to directory containing images) environment variables set. The `OUTPUT_DIR` variables is optional and specifies where to save the model to on the Hub:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
```
|
||||
|
||||
There are some flags to be aware of before you start training:
|
||||
|
||||
* `--push_to_hub` stores the trained LoRA embeddings on the Hub.
|
||||
* `--report_to=wandb` reports and logs the training results to your Weights & Biases dashboard (as an example, take a look at this [report](https://wandb.ai/pcuenq/text2image-fine-tune/runs/b4k1w0tn?workspace=user-pcuenq)).
|
||||
* `--learning_rate=1e-04`, you can afford to use a higher learning rate than you normally would with LoRA.
|
||||
|
||||
Now you're ready to launch the training (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py)):
|
||||
|
||||
```bash
|
||||
accelerate launch train_dreambooth_lora.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
@@ -78,101 +175,40 @@ accelerate launch train_dreambooth_lora.py \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
A similar process can be followed to fully fine-tune Stable Diffusion on a custom dataset using the
|
||||
`examples/text_to_image/train_text_to_image_lora.py` script.
|
||||
### Inference[[dreambooth-inference]]
|
||||
|
||||
Refer to the respective examples linked above to learn more.
|
||||
Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`]:
|
||||
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> model_base = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
>>> pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
Load the LoRA weights from your finetuned DreamBooth model *on top of the base model weights*, and then move the pipeline to a GPU for faster inference. When you merge the LoRA weights with the frozen pretrained model weights, you can optionally adjust how much of the weights to merge with the `scale` parameter:
|
||||
|
||||
<Tip>
|
||||
|
||||
When using LoRA we can use a much higher learning rate (typically 1e-4 as opposed to ~1e-6) compared to non-LoRA Dreambooth fine-tuning.
|
||||
💡 A `scale` value of `0` is the same as not using your LoRA weights and you're only using the base model weights, and a `scale` value of `1` means you're only using the fully finetuned LoRA weights. Values between `0` and `1` interpolates between the two weights.
|
||||
|
||||
</Tip>
|
||||
|
||||
But there is no free lunch. For the given dataset and expected generation quality, you'd still need to experiment with
|
||||
different hyperparameters. Here are some important ones:
|
||||
|
||||
* Training time
|
||||
* Learning rate
|
||||
* Number of training steps
|
||||
* Inference time
|
||||
* Number of steps
|
||||
* Scheduler type
|
||||
|
||||
Additionally, you can follow [this blog](https://huggingface.co/blog/dreambooth) that documents some of our experimental
|
||||
findings for performing DreamBooth training of Stable Diffusion.
|
||||
|
||||
When fine-tuning, the LoRA update matrices are only added to the attention layers. To enable this, we added new weight
|
||||
loading functionalities. Their details are available [here](https://huggingface.co/docs/diffusers/main/en/api/loaders).
|
||||
|
||||
## Inference
|
||||
|
||||
Assuming you used the `examples/text_to_image/train_text_to_image_lora.py` to fine-tune Stable Diffusion on the [Pokemon
|
||||
dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions), you can perform inference like so:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
|
||||
model_path = "sayakpaul/sd-model-finetuned-lora-t4"
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
|
||||
pipe.unet.load_attn_procs(model_path)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A pokemon with blue eyes."
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
|
||||
image.save("pokemon.png")
|
||||
```
|
||||
|
||||
Here are some example images you can expect:
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pokemon-collage.png"/>
|
||||
|
||||
[`sayakpaul/sd-model-finetuned-lora-t4`](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4) contains [LoRA fine-tuned update matrices](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin)
|
||||
which is only 3 MBs in size. During inference, the pre-trained Stable Diffusion checkpoints are loaded alongside these update
|
||||
matrices and then they are combined to run inference.
|
||||
|
||||
You can use the [`huggingface_hub`](https://github.com/huggingface/huggingface_hub) library to retrieve the base model
|
||||
from [`sayakpaul/sd-model-finetuned-lora-t4`](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4) like so:
|
||||
|
||||
```py
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
>>> pipe.unet.load_attn_procs(model_path)
|
||||
>>> pipe.to("cuda")
|
||||
# use half the weights from the LoRA finetuned model and half the weights from the base model
|
||||
|
||||
card = RepoCard.load("sayakpaul/sd-model-finetuned-lora-t4")
|
||||
base_model = card.data.to_dict()["base_model"]
|
||||
# 'CompVis/stable-diffusion-v1-4'
|
||||
```
|
||||
>>> image = pipe(
|
||||
... "A picture of a sks dog in a bucket.",
|
||||
... num_inference_steps=25,
|
||||
... guidance_scale=7.5,
|
||||
... cross_attention_kwargs={"scale": 0.5},
|
||||
... ).images[0]
|
||||
# use the weights from the fully finetuned LoRA model
|
||||
|
||||
And then you can use `pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)`.
|
||||
|
||||
This is especially useful when you don't want to hardcode the base model identifier during initializing the `StableDiffusionPipeline`.
|
||||
|
||||
Inference for DreamBooth training remains the same. Check
|
||||
[this section](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#inference-1) for more details.
|
||||
|
||||
### Merging LoRA with original model
|
||||
|
||||
When performing inference, you can merge the trained LoRA weights with the frozen pre-trained model weights, to interpolate between the original model's inference result (as if no fine-tuning had occurred) and the fully fine-tuned version.
|
||||
|
||||
You can adjust the merging ratio with a parameter called α (alpha) in the paper, or `scale` in our implementation. You can tweak it with the following code, that passes `scale` as `cross_attention_kwargs` in the pipeline call:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
|
||||
model_path = "sayakpaul/sd-model-finetuned-lora-t4"
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
|
||||
pipe.unet.load_attn_procs(model_path)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A pokemon with blue eyes."
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5, cross_attention_kwargs={"scale": 0.5}).images[0]
|
||||
image.save("pokemon.png")
|
||||
```
|
||||
|
||||
A value of `0` is the same as _not_ using the LoRA weights, whereas `1` means only the LoRA fine-tuned weights will be used. Values between 0 and 1 will interpolate between the two versions.
|
||||
|
||||
|
||||
## Known limitations
|
||||
|
||||
* Currently, we only support LoRA for the attention layers of [`UNet2DConditionModel`](https://huggingface.co/docs/diffusers/main/en/api/models#diffusers.UNet2DConditionModel).
|
||||
>>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
|
||||
>>> image.save("bucket-dog.png")
|
||||
```
|
||||
@@ -11,20 +11,15 @@ specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
|
||||
# Stable Diffusion text-to-image fine-tuning
|
||||
|
||||
The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) script shows how to fine-tune the stable diffusion model on your own dataset.
|
||||
# Text-to-image
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The text-to-image fine-tuning script is experimental. It's easy to overfit and run into issues like catastrophic forgetting. We recommend to explore different hyperparameters to get the best results on your dataset.
|
||||
The text-to-image fine-tuning script is experimental. It's easy to overfit and run into issues like catastrophic forgetting. We recommend you explore different hyperparameters to get the best results on your dataset.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## Running locally
|
||||
|
||||
### Installing the dependencies
|
||||
Text-to-image models like Stable Diffusion generate an image from a text prompt. This guide will show you how to finetune the [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) model on your own dataset with PyTorch and Flax. All the training scripts for text-to-image finetuning used in this guide can be found in this [repository](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) if you're interested in taking a closer look.
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
@@ -33,32 +28,51 @@ pip install git+https://github.com/huggingface/diffusers.git
|
||||
pip install -U -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
Run the following command to authenticate your token
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
If you have already cloned the repo, then you won't need to go through these steps. Instead, you can pass the path to your local checkout to the training script and it will be loaded from there.
|
||||
|
||||
### Hardware Requirements for Fine-tuning
|
||||
## Hardware requirements
|
||||
|
||||
Using `gradient_checkpointing` and `mixed_precision` it should be possible to fine tune the model on a single 24GB GPU. For higher `batch_size` and faster training it's better to use GPUs with more than 30GB of GPU memory. You can also use JAX / Flax for fine-tuning on TPUs or GPUs, see [below](#flax-jax-finetuning) for details.
|
||||
Using `gradient_checkpointing` and `mixed_precision`, it should be possible to finetune the model on a single 24GB GPU. For higher `batch_size`'s and faster training, it's better to use GPUs with more than 30GB of GPU memory. You can also use JAX/Flax for fine-tuning on TPUs or GPUs, which will be covered [below](#flax-jax-finetuning).
|
||||
|
||||
### Fine-tuning Example
|
||||
You can reduce your memory footprint even more by enabling memory efficient attention with xFormers. Make sure you have [xFormers installed](./optimization/xformers) and pass the `--enable_xformers_memory_efficient_attention` flag to the training script.
|
||||
|
||||
The following script will launch a fine-tuning run using [Justin Pinkneys' captioned Pokemon dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions), available in Hugging Face Hub.
|
||||
xFormers is not available for Flax.
|
||||
|
||||
## Upload model to Hub
|
||||
|
||||
Store your model on the Hub by adding the following argument to the training script:
|
||||
|
||||
```bash
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Save and load checkpoints
|
||||
|
||||
It is a good idea to regularly save checkpoints in case anything happens during training. To save a checkpoint, pass the following argument to the training script:
|
||||
|
||||
```bash
|
||||
--checkpointing_steps=500
|
||||
```
|
||||
|
||||
Every 500 steps, the full training state is saved in a subfolder in the `output_dir`. The checkpoint has the format `checkpoint-` followed by the number of steps trained so far. For example, `checkpoint-1500` is a checkpoint saved after 1500 training steps.
|
||||
|
||||
To load a checkpoint to resume training, pass the argument `--resume_from_checkpoint` to the training script and specify the checkpoint you want to resume from. For example, the following argument resumes training from the checkpoint saved after 1500 training steps:
|
||||
|
||||
```bash
|
||||
--resume_from_checkpoint="checkpoint-1500"
|
||||
```
|
||||
|
||||
## Fine-tuning
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
Launch the [PyTorch training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) for a fine-tuning run on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset like this:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
@@ -80,9 +94,9 @@ accelerate launch train_text_to_image.py \
|
||||
--output_dir="sd-pokemon-model"
|
||||
```
|
||||
|
||||
To run on your own training files you need to prepare the dataset according to the format required by `datasets`. You can upload your dataset to the Hub, or you can prepare a local folder with your files. [This documentation](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata) explains how to do it.
|
||||
To finetune on your own dataset, prepare the dataset according to the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub), or you can [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
|
||||
|
||||
You should modify the script if you wish to use custom loading logic. We have left pointers in the code in the appropriate places :)
|
||||
Modify the script if you want to use custom loading logic. We left pointers in the code in the appropriate places to help you. 🤗 The example script below shows how to finetune on a local dataset in `TRAIN_DIR` and where to save the model to in `OUTPUT_DIR`:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
@@ -104,25 +118,19 @@ accelerate launch train_text_to_image.py \
|
||||
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--output_dir=${OUTPUT_DIR}
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
With Flax, it's possible to train a Stable Diffusion model faster on TPUs and GPUs thanks to [@duongna211](https://github.com/duongna21). This is very efficient on TPU hardware but works great on GPUs too. The Flax training script doesn't support features like gradient checkpointing or gradient accumulation yet, so you'll need a GPU with at least 30GB of memory or a TPU v3.
|
||||
|
||||
Once training is finished the model will be saved to the `OUTPUT_DIR` specified in the command. To load the fine-tuned model for inference, just pass that path to `StableDiffusionPipeline`:
|
||||
Before running the script, make sure you have the requirements installed:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
model_path = "path_to_saved_model"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = pipe(prompt="yoda").images[0]
|
||||
image.save("yoda-pokemon.png")
|
||||
```bash
|
||||
pip install -U -r requirements_flax.txt
|
||||
```
|
||||
|
||||
### Flax / JAX fine-tuning
|
||||
Now you can launch the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) like this:
|
||||
|
||||
Thanks to [@duongna211](https://github.com/duongna21) it's possible to fine-tune Stable Diffusion using Flax! This is very efficient on TPU hardware but works great on GPUs too. You can use the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) like this:
|
||||
|
||||
```Python
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export dataset_name="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
@@ -136,3 +144,77 @@ python train_text_to_image_flax.py \
|
||||
--max_grad_norm=1 \
|
||||
--output_dir="sd-pokemon-model"
|
||||
```
|
||||
|
||||
To finetune on your own dataset, prepare the dataset according to the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub), or you can [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder).
|
||||
|
||||
Modify the script if you want to use custom loading logic. We left pointers in the code in the appropriate places to help you. 🤗 The example script below shows how to finetune on a local dataset in `TRAIN_DIR`:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export TRAIN_DIR="path_to_your_dataset"
|
||||
|
||||
python train_text_to_image_flax.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_data_dir=$TRAIN_DIR \
|
||||
--resolution=512 --center_crop --random_flip \
|
||||
--train_batch_size=1 \
|
||||
--mixed_precision="fp16" \
|
||||
--max_train_steps=15000 \
|
||||
--learning_rate=1e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--output_dir="sd-pokemon-model"
|
||||
```
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
## LoRA
|
||||
|
||||
You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique for accelerating training large models, for fine-tuning text-to-image models. For more details, take a look at the [LoRA training](lora#text-to-image) guide.
|
||||
|
||||
## Inference
|
||||
|
||||
Now you can load the fine-tuned model for inference by passing the model path or model name on the Hub to the [`StableDiffusionPipeline`]:
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
model_path = "path_to_saved_model"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = pipe(prompt="yoda").images[0]
|
||||
image.save("yoda-pokemon.png")
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
```python
|
||||
import jax
|
||||
import numpy as np
|
||||
from flax.jax_utils import replicate
|
||||
from flax.training.common_utils import shard
|
||||
from diffusers import FlaxStableDiffusionPipeline
|
||||
|
||||
model_path = "path_to_saved_model"
|
||||
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
|
||||
|
||||
prompt = "yoda pokemon"
|
||||
prng_seed = jax.random.PRNGKey(0)
|
||||
num_inference_steps = 50
|
||||
|
||||
num_samples = jax.device_count()
|
||||
prompt = num_samples * [prompt]
|
||||
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||
|
||||
# shard inputs and rng
|
||||
params = replicate(params)
|
||||
prng_seed = jax.random.split(prng_seed, jax.device_count())
|
||||
prompt_ids = shard(prompt_ids)
|
||||
|
||||
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
||||
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
||||
image.save("yoda-pokemon.png")
|
||||
```
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
@@ -14,74 +14,85 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Textual Inversion
|
||||
|
||||
Textual Inversion is a technique for capturing novel concepts from a small number of example images in a way that can later be used to control text-to-image pipelines. It does so by learning new 'words' in the embedding space of the pipeline's text encoder. These special words can then be used within text prompts to achieve very fine-grained control of the resulting images.
|
||||
[[open-in-colab]]
|
||||
|
||||
[Textual Inversion](https://arxiv.org/abs/2208.01618) is a technique for capturing novel concepts from a small number of example images. While the technique was originally demonstrated with a [latent diffusion model](https://github.com/CompVis/latent-diffusion), it has since been applied to other model variants like [Stable Diffusion](https://huggingface.co/docs/diffusers/main/en/conceptual/stable_diffusion). The learned concepts can be used to better control the images generated from text-to-image pipelines. It learns new "words" in the text encoder's embedding space, which are used within text prompts for personalized image generation.
|
||||
|
||||

|
||||
_By using just 3-5 images you can teach new concepts to a model such as Stable Diffusion for personalized image generation ([image source](https://github.com/rinongal/textual_inversion))._
|
||||
<small>By using just 3-5 images you can teach new concepts to a model such as Stable Diffusion for personalized image generation <a href="https://github.com/rinongal/textual_inversion">(image source)</a></small>
|
||||
|
||||
This technique was introduced in [An Image is Worth One Word: Personalizing Text-to-Image Generation using Textual Inversion](https://arxiv.org/abs/2208.01618). The paper demonstrated the concept using a [latent diffusion model](https://github.com/CompVis/latent-diffusion) but the idea has since been applied to other variants such as [Stable Diffusion](https://huggingface.co/docs/diffusers/main/en/conceptual/stable_diffusion).
|
||||
This guide will show you how to train a [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) model with Textual Inversion. All the training scripts for Textual Inversion used in this guide can be found [here](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) if you're interested in taking a closer look at how things work under the hood.
|
||||
|
||||
<Tip>
|
||||
|
||||
## How It Works
|
||||
There is a community-created collection of trained Textual Inversion models in the [Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library) which are readily available for inference. Over time, this'll hopefully grow into a useful resource as more concepts are added!
|
||||
|
||||

|
||||
_Architecture Overview from the [textual inversion blog post](https://textual-inversion.github.io/)_
|
||||
</Tip>
|
||||
|
||||
Before a text prompt can be used in a diffusion model, it must first be processed into a numerical representation. This typically involves tokenizing the text, converting each token to an embedding and then feeding those embeddings through a model (typically a transformer) whose output will be used as the conditioning for the diffusion model.
|
||||
|
||||
Textual inversion learns a new token embedding (v* in the diagram above). A prompt (that includes a token which will be mapped to this new embedding) is used in conjunction with a noised version of one or more training images as inputs to the generator model, which attempts to predict the denoised version of the image. The embedding is optimized based on how well the model does at this task - an embedding that better captures the object or style shown by the training images will give more useful information to the diffusion model and thus result in a lower denoising loss. After many steps (typically several thousand) with a variety of prompt and image variants the learned embedding should hopefully capture the essence of the new concept being taught.
|
||||
|
||||
## Usage
|
||||
|
||||
To train your own textual inversions, see the [example script here](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion).
|
||||
|
||||
There is also a notebook for training:
|
||||
[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
||||
|
||||
And one for inference:
|
||||
[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
|
||||
|
||||
In addition to using concepts you have trained yourself, there is a community-created collection of trained textual inversions in the new [Stable Diffusion public concepts library](https://huggingface.co/sd-concepts-library) which you can also use from the inference notebook above. Over time this will hopefully grow into a useful resource as more examples are added.
|
||||
|
||||
## Example: Running locally
|
||||
|
||||
The `textual_inversion.py` script [here](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion) shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies.
|
||||
Before you begin, make sure you install the library's training dependencies:
|
||||
|
||||
```bash
|
||||
pip install diffusers[training] accelerate transformers
|
||||
pip install diffusers accelerate transformers
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
After all the dependencies have been set up, initialize a [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
|
||||
### Cat toy example
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
Run the following command to authenticate your token
|
||||
To setup a default 🤗 Accelerate environment without choosing any configurations:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
If you have already cloned the repo, then you won't need to go through these steps.
|
||||
Or if your environment doesn't support an interactive shell like a notebook, you can use:
|
||||
|
||||
<br>
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
Now let's get our dataset.Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
And launch the training using
|
||||
Finally, you try and [install xFormers](https://huggingface.co/docs/diffusers/main/en/training/optimization/xformers) to reduce your memory footprint with xFormers memory-efficient attention. Once you have xFormers installed, add the `--enable_xformers_memory_efficient_attention` argument to the training script. xFormers is not supported for Flax.
|
||||
|
||||
## Upload model to Hub
|
||||
|
||||
If you want to store your model on the Hub, add the following argument to the training script:
|
||||
|
||||
```bash
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Save and load checkpoints
|
||||
|
||||
It is often a good idea to regularly save checkpoints of your model during training. This way, you can resume training from a saved checkpoint if your training is interrupted for any reason. To save a checkpoint, pass the following argument to the training script to save the full training state in a subfolder in `output_dir` every 500 steps:
|
||||
|
||||
```bash
|
||||
--checkpointing_steps=500
|
||||
```
|
||||
|
||||
To resume training from a saved checkpoint, pass the following argument to the training script and the specific checkpoint you'd like to resume from:
|
||||
|
||||
```bash
|
||||
--resume_from_checkpoint="checkpoint-1500"
|
||||
```
|
||||
|
||||
## Finetuning
|
||||
|
||||
For your training dataset, download these [images of a cat statue](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and store them in a directory.
|
||||
|
||||
Set the `MODEL_NAME` environment variable to the model repository id, and the `DATA_DIR` environment variable to the path of the directory containing the images. Now you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py):
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 A full training run takes ~1 hour on one V100 GPU. While you're waiting for the training to complete, feel free to check out [how Textual Inversion works](#how-it-works) in the section below if you're curious!
|
||||
|
||||
</Tip>
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export DATA_DIR="path-to-dir-containing-images"
|
||||
@@ -100,14 +111,56 @@ accelerate launch textual_inversion.py \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="textual_inversion_cat"
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
If you have access to TPUs, try out the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py) to train even faster (this'll also work for GPUs). With the same configuration settings, the Flax training script should be at least 70% faster than the PyTorch training script! ⚡️
|
||||
|
||||
A full training run takes ~1 hour on one V100 GPU.
|
||||
Before you begin, make sure you install the Flax specific dependencies:
|
||||
|
||||
```bash
|
||||
pip install -U -r requirements_flax.txt
|
||||
```
|
||||
|
||||
### Inference
|
||||
Then you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py):
|
||||
|
||||
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export DATA_DIR="path-to-dir-containing-images"
|
||||
|
||||
python textual_inversion_flax.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_data_dir=$DATA_DIR \
|
||||
--learnable_property="object" \
|
||||
--placeholder_token="<cat-toy>" --initializer_token="toy" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--max_train_steps=3000 \
|
||||
--learning_rate=5.0e-04 --scale_lr \
|
||||
--output_dir="textual_inversion_cat"
|
||||
```
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
### Intermediate logging
|
||||
|
||||
If you're interested in following along with your model training progress, you can save the generated images from the training process. Add the following arguments to the training script to enable intermediate logging:
|
||||
|
||||
- `validation_prompt`, the prompt used to generate samples (this is set to `None` by default and intermediate logging is disabled)
|
||||
- `num_validation_images`, the number of sample images to generate
|
||||
- `validation_steps`, the number of steps before generating `num_validation_images` from the `validation_prompt`
|
||||
|
||||
```bash
|
||||
--validation_prompt="A <cat-toy> backpack"
|
||||
--num_validation_images=4
|
||||
--validation_steps=100
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a model, you can use it for inference with the [`StableDiffusionPipeline]. Make sure you include the `placeholder_token` in your prompt, in this case, it is `<cat-toy>`.
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
@@ -120,3 +173,43 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
|
||||
|
||||
image.save("cat-backpack.png")
|
||||
```
|
||||
</pt>
|
||||
<jax>
|
||||
```python
|
||||
import jax
|
||||
import numpy as np
|
||||
from flax.jax_utils import replicate
|
||||
from flax.training.common_utils import shard
|
||||
from diffusers import FlaxStableDiffusionPipeline
|
||||
|
||||
model_path = "path-to-your-trained-model"
|
||||
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
|
||||
|
||||
prompt = "A <cat-toy> backpack"
|
||||
prng_seed = jax.random.PRNGKey(0)
|
||||
num_inference_steps = 50
|
||||
|
||||
num_samples = jax.device_count()
|
||||
prompt = num_samples * [prompt]
|
||||
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||
|
||||
# shard inputs and rng
|
||||
params = replicate(params)
|
||||
prng_seed = jax.random.split(prng_seed, jax.device_count())
|
||||
prompt_ids = shard(prompt_ids)
|
||||
|
||||
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
||||
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
||||
image.save("cat-backpack.png")
|
||||
```
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
## How it works
|
||||
|
||||

|
||||
<small>Architecture overview from the Textual Inversion <a href="https://textual-inversion.github.io/">blog post.</a></small>
|
||||
|
||||
Usually, text prompts are tokenized into an embedding before being passed to a model, which is often a transformer. Textual Inversion does something similar, but it learns a new token embedding, `v*`, from a special token `S*` in the diagram above. The model output is used to condition the diffusion model, which helps the diffusion model understand the prompt and new concepts from just a few example images.
|
||||
|
||||
To do this, Textual Inversion uses a generator model and noisy versions of the training images. The generator tries to predict less noisy versions of the images, and the token embedding `v*` is optimized based on how well the generator does. If the token embedding successfully captures the new concept, it gives more useful information to the diffusion model and helps create clearer images with less noise. This optimization process typically occurs after several thousand steps of exposure to a variety of prompt and image variants.
|
||||
@@ -0,0 +1,414 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Train a diffusion model
|
||||
|
||||
Unconditional image generation is a popular application of diffusion models that generates images that look like those in the dataset used for training. Typically, the best results are obtained from finetuning a pretrained model on a specific dataset. You can find many of these checkpoints on the [Hub](https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model), but if you can't find one you like, you can always train your own!
|
||||
|
||||
This tutorial will teach you how to train a [`UNet2DModel`] from scratch on a subset of the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset to generate your own 🦋 butterflies 🦋.
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 This training tutorial is based on the [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook. For additional details and context about diffusion models like how they work, check out the notebook!
|
||||
|
||||
</Tip>
|
||||
|
||||
Before you begin, make sure you have 🤗 Datasets installed to load and preprocess image datasets, and 🤗 Accelerate, to simplify training on any number of GPUs. The following command will also install [TensorBoard](https://www.tensorflow.org/tensorboard) to visualize training metrics (you can also use [Weights & Biases](https://docs.wandb.ai/) to track your training).
|
||||
|
||||
```bash
|
||||
!pip install diffusers[training]
|
||||
```
|
||||
|
||||
We encourage you to share your model with the community, and in order to do that, you'll need to login to your Hugging Face account (create one [here](https://hf.co/join) if you don't already have one!). You can login from a notebook and enter your token when prompted:
|
||||
|
||||
```py
|
||||
>>> from huggingface_hub import notebook_login
|
||||
|
||||
>>> notebook_login()
|
||||
```
|
||||
|
||||
Or login in from the terminal:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
Since the model checkpoints are quite large, install [Git-LFS](https://git-lfs.com/) to version these large files:
|
||||
|
||||
```bash
|
||||
!sudo apt -qq install git-lfs
|
||||
!git config --global credential.helper store
|
||||
```
|
||||
|
||||
## Training configuration
|
||||
|
||||
For convenience, create a `TrainingConfig` class containing the training hyperparameters (feel free to adjust them):
|
||||
|
||||
```py
|
||||
>>> from dataclasses import dataclass
|
||||
|
||||
|
||||
>>> @dataclass
|
||||
... class TrainingConfig:
|
||||
... image_size = 128 # the generated image resolution
|
||||
... train_batch_size = 16
|
||||
... eval_batch_size = 16 # how many images to sample during evaluation
|
||||
... num_epochs = 50
|
||||
... gradient_accumulation_steps = 1
|
||||
... learning_rate = 1e-4
|
||||
... lr_warmup_steps = 500
|
||||
... save_image_epochs = 10
|
||||
... save_model_epochs = 30
|
||||
... mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
|
||||
... output_dir = "ddpm-butterflies-128" # the model name locally and on the HF Hub
|
||||
|
||||
... push_to_hub = True # whether to upload the saved model to the HF Hub
|
||||
... hub_private_repo = False
|
||||
... overwrite_output_dir = True # overwrite the old model when re-running the notebook
|
||||
... seed = 0
|
||||
|
||||
|
||||
>>> config = TrainingConfig()
|
||||
```
|
||||
|
||||
## Load the dataset
|
||||
|
||||
You can easily load the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset with the 🤗 Datasets library:
|
||||
|
||||
```py
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> config.dataset_name = "huggan/smithsonian_butterflies_subset"
|
||||
>>> dataset = load_dataset(config.dataset_name, split="train")
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 You can find additional datasets from the [HugGan Community Event](https://huggingface.co/huggan) or you can use your own dataset by creating a local [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). Set `config.dataset_name` to the repository id of the dataset if it is from the HugGan Community Event, or `imagefolder` if you're using your own images.
|
||||
|
||||
</Tip>
|
||||
|
||||
🤗 Datasets uses the [`~datasets.Image`] feature to automatically decode the image data and load it as a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html) which we can visualize:
|
||||
|
||||
```py
|
||||
>>> import matplotlib.pyplot as plt
|
||||
|
||||
>>> fig, axs = plt.subplots(1, 4, figsize=(16, 4))
|
||||
>>> for i, image in enumerate(dataset[:4]["image"]):
|
||||
... axs[i].imshow(image)
|
||||
... axs[i].set_axis_off()
|
||||
>>> fig.show()
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_ds.png"/>
|
||||
</div>
|
||||
|
||||
The images are all different sizes though, so you'll need to preprocess them first:
|
||||
|
||||
* `Resize` changes the image size to the one defined in `config.image_size`.
|
||||
* `RandomHorizontalFlip` augments the dataset by randomly mirroring the images.
|
||||
* `Normalize` is important to rescale the pixel values into a [-1, 1] range, which is what the model expects.
|
||||
|
||||
```py
|
||||
>>> from torchvision import transforms
|
||||
|
||||
>>> preprocess = transforms.Compose(
|
||||
... [
|
||||
... transforms.Resize((config.image_size, config.image_size)),
|
||||
... transforms.RandomHorizontalFlip(),
|
||||
... transforms.ToTensor(),
|
||||
... transforms.Normalize([0.5], [0.5]),
|
||||
... ]
|
||||
... )
|
||||
```
|
||||
|
||||
Use 🤗 Datasets' [`~datasets.Dataset.set_transform`] method to apply the `preprocess` function on the fly during training:
|
||||
|
||||
```py
|
||||
>>> def transform(examples):
|
||||
... images = [preprocess(image.convert("RGB")) for image in examples["image"]]
|
||||
... return {"images": images}
|
||||
|
||||
|
||||
>>> dataset.set_transform(transform)
|
||||
```
|
||||
|
||||
Feel free to visualize the images again to confirm that they've been resized. Now you're ready to wrap the dataset in a [DataLoader](https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader) for training!
|
||||
|
||||
```py
|
||||
>>> import torch
|
||||
|
||||
>>> train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
|
||||
```
|
||||
|
||||
## Create a UNet2DModel
|
||||
|
||||
Pretrained models in 🧨 Diffusers are easily created from their model class with the parameters you want. For example, to create a [`UNet2DModel`]:
|
||||
|
||||
```py
|
||||
>>> from diffusers import UNet2DModel
|
||||
|
||||
>>> model = UNet2DModel(
|
||||
... sample_size=config.image_size, # the target image resolution
|
||||
... in_channels=3, # the number of input channels, 3 for RGB images
|
||||
... out_channels=3, # the number of output channels
|
||||
... layers_per_block=2, # how many ResNet layers to use per UNet block
|
||||
... block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block
|
||||
... down_block_types=(
|
||||
... "DownBlock2D", # a regular ResNet downsampling block
|
||||
... "DownBlock2D",
|
||||
... "DownBlock2D",
|
||||
... "DownBlock2D",
|
||||
... "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
|
||||
... "DownBlock2D",
|
||||
... ),
|
||||
... up_block_types=(
|
||||
... "UpBlock2D", # a regular ResNet upsampling block
|
||||
... "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
|
||||
... "UpBlock2D",
|
||||
... "UpBlock2D",
|
||||
... "UpBlock2D",
|
||||
... "UpBlock2D",
|
||||
... ),
|
||||
... )
|
||||
```
|
||||
|
||||
It is often a good idea to quickly check the sample image shape matches the model output shape:
|
||||
|
||||
```py
|
||||
>>> sample_image = dataset[0]["images"].unsqueeze(0)
|
||||
>>> print("Input shape:", sample_image.shape)
|
||||
Input shape: torch.Size([1, 3, 128, 128])
|
||||
|
||||
>>> print("Output shape:", model(sample_image, timestep=0).sample.shape)
|
||||
Output shape: torch.Size([1, 3, 128, 128])
|
||||
```
|
||||
|
||||
Great! Next, you'll need a scheduler to add some noise to the image.
|
||||
|
||||
## Create a scheduler
|
||||
|
||||
The scheduler behaves differently depending on whether you're using the model for training or inference. During inference, the scheduler generates image from the noise. During training, the scheduler takes a model output - or a sample - from a specific point in the diffusion process and applies noise to the image according to a *noise schedule* and an *update rule*.
|
||||
|
||||
Let's take a look at the [`DDPMScheduler`] and use the `add_noise` method to add some random noise to the `sample_image` from before:
|
||||
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import DDPMScheduler
|
||||
|
||||
>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
|
||||
>>> noise = torch.randn(sample_image.shape)
|
||||
>>> timesteps = torch.LongTensor([50])
|
||||
>>> noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)
|
||||
|
||||
>>> Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/noisy_butterfly.png"/>
|
||||
</div>
|
||||
|
||||
The training objective of the model is to predict the noise added to the image. The loss at this step can be calculated by:
|
||||
|
||||
```py
|
||||
>>> import torch.nn.functional as F
|
||||
|
||||
>>> noise_pred = model(noisy_image, timesteps).sample
|
||||
>>> loss = F.mse_loss(noise_pred, noise)
|
||||
```
|
||||
|
||||
## Train the model
|
||||
|
||||
By now, you have most of the pieces to start training the model and all that's left is putting everything together.
|
||||
|
||||
First, you'll need an optimizer and a learning rate scheduler:
|
||||
|
||||
```py
|
||||
>>> from diffusers.optimization import get_cosine_schedule_with_warmup
|
||||
|
||||
>>> optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
|
||||
>>> lr_scheduler = get_cosine_schedule_with_warmup(
|
||||
... optimizer=optimizer,
|
||||
... num_warmup_steps=config.lr_warmup_steps,
|
||||
... num_training_steps=(len(train_dataloader) * config.num_epochs),
|
||||
... )
|
||||
```
|
||||
|
||||
Then, you'll need a way to evaluate the model. For evaluation, you can use the [`DDPMPipeline`] to generate a batch of sample images and save it as a grid:
|
||||
|
||||
```py
|
||||
>>> from diffusers import DDPMPipeline
|
||||
>>> import math
|
||||
|
||||
|
||||
>>> def make_grid(images, rows, cols):
|
||||
... w, h = images[0].size
|
||||
... grid = Image.new("RGB", size=(cols * w, rows * h))
|
||||
... for i, image in enumerate(images):
|
||||
... grid.paste(image, box=(i % cols * w, i // cols * h))
|
||||
... return grid
|
||||
|
||||
|
||||
>>> def evaluate(config, epoch, pipeline):
|
||||
... # Sample some images from random noise (this is the backward diffusion process).
|
||||
... # The default pipeline output type is `List[PIL.Image]`
|
||||
... images = pipeline(
|
||||
... batch_size=config.eval_batch_size,
|
||||
... generator=torch.manual_seed(config.seed),
|
||||
... ).images
|
||||
|
||||
... # Make a grid out of the images
|
||||
... image_grid = make_grid(images, rows=4, cols=4)
|
||||
|
||||
... # Save the images
|
||||
... test_dir = os.path.join(config.output_dir, "samples")
|
||||
... os.makedirs(test_dir, exist_ok=True)
|
||||
... image_grid.save(f"{test_dir}/{epoch:04d}.png")
|
||||
```
|
||||
|
||||
Now you can wrap all these components together in a training loop with 🤗 Accelerate for easy TensorBoard logging, gradient accumulation, and mixed precision training. To upload the model to the Hub, write a function to get your repository name and information and then push it to the Hub.
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 The training loop below may look intimidating and long, but it'll be worth it later when you launch your training in just one line of code! If you can't wait and want to start generating images, feel free to copy and run the code below. You can always come back and examine the training loop more closely later, like when you're waiting for your model to finish training. 🤗
|
||||
|
||||
</Tip>
|
||||
|
||||
```py
|
||||
>>> from accelerate import Accelerator
|
||||
>>> from huggingface_hub import HfFolder, Repository, whoami
|
||||
>>> from tqdm.auto import tqdm
|
||||
>>> from pathlib import Path
|
||||
>>> import os
|
||||
|
||||
|
||||
>>> def get_full_repo_name(model_id: str, organization: str = None, token: str = None):
|
||||
... if token is None:
|
||||
... token = HfFolder.get_token()
|
||||
... if organization is None:
|
||||
... username = whoami(token)["name"]
|
||||
... return f"{username}/{model_id}"
|
||||
... else:
|
||||
... return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
|
||||
... # Initialize accelerator and tensorboard logging
|
||||
... accelerator = Accelerator(
|
||||
... mixed_precision=config.mixed_precision,
|
||||
... gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
... log_with="tensorboard",
|
||||
... logging_dir=os.path.join(config.output_dir, "logs"),
|
||||
... )
|
||||
... if accelerator.is_main_process:
|
||||
... if config.push_to_hub:
|
||||
... repo_name = get_full_repo_name(Path(config.output_dir).name)
|
||||
... repo = Repository(config.output_dir, clone_from=repo_name)
|
||||
... elif config.output_dir is not None:
|
||||
... os.makedirs(config.output_dir, exist_ok=True)
|
||||
... accelerator.init_trackers("train_example")
|
||||
|
||||
... # Prepare everything
|
||||
... # There is no specific order to remember, you just need to unpack the
|
||||
... # objects in the same order you gave them to the prepare method.
|
||||
... model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
... model, optimizer, train_dataloader, lr_scheduler
|
||||
... )
|
||||
|
||||
... global_step = 0
|
||||
|
||||
... # Now you train the model
|
||||
... for epoch in range(config.num_epochs):
|
||||
... progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
|
||||
... progress_bar.set_description(f"Epoch {epoch}")
|
||||
|
||||
... for step, batch in enumerate(train_dataloader):
|
||||
... clean_images = batch["images"]
|
||||
... # Sample noise to add to the images
|
||||
... noise = torch.randn(clean_images.shape).to(clean_images.device)
|
||||
... bs = clean_images.shape[0]
|
||||
|
||||
... # Sample a random timestep for each image
|
||||
... timesteps = torch.randint(
|
||||
... 0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
|
||||
... ).long()
|
||||
|
||||
... # Add noise to the clean images according to the noise magnitude at each timestep
|
||||
... # (this is the forward diffusion process)
|
||||
... noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
|
||||
|
||||
... with accelerator.accumulate(model):
|
||||
... # Predict the noise residual
|
||||
... noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
|
||||
... loss = F.mse_loss(noise_pred, noise)
|
||||
... accelerator.backward(loss)
|
||||
|
||||
... accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
||||
... optimizer.step()
|
||||
... lr_scheduler.step()
|
||||
... optimizer.zero_grad()
|
||||
|
||||
... progress_bar.update(1)
|
||||
... logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
|
||||
... progress_bar.set_postfix(**logs)
|
||||
... accelerator.log(logs, step=global_step)
|
||||
... global_step += 1
|
||||
|
||||
... # After each epoch you optionally sample some demo images with evaluate() and save the model
|
||||
... if accelerator.is_main_process:
|
||||
... pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
|
||||
|
||||
... if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
|
||||
... evaluate(config, epoch, pipeline)
|
||||
|
||||
... if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
|
||||
... if config.push_to_hub:
|
||||
... repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True)
|
||||
... else:
|
||||
... pipeline.save_pretrained(config.output_dir)
|
||||
```
|
||||
|
||||
Phew, that was quite a bit of code! But you're finally ready to launch the training with 🤗 Accelerate's [`~accelerate.notebook_launcher`] function. Pass the function the training loop, all the training arguments, and the number of processes (you can change this value to the number of GPUs available to you) to use for training:
|
||||
|
||||
```py
|
||||
>>> from accelerate import notebook_launcher
|
||||
|
||||
>>> args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
>>> notebook_launcher(train_loop, args, num_processes=1)
|
||||
```
|
||||
|
||||
Once training is complete, take a look at the final 🦋 images 🦋 generated by your diffusion model!
|
||||
|
||||
```py
|
||||
>>> import glob
|
||||
|
||||
>>> sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
|
||||
>>> Image.open(sample_images[-1])
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_final.png"/>
|
||||
</div>
|
||||
|
||||
## Next steps
|
||||
|
||||
Unconditional image generation is one example of a task that can be trained. You can explore other tasks and training techniques by visiting the [🧨 Diffusers Training Examples](./training/overview) page. Here are some examples of what you can learn:
|
||||
|
||||
* [Textual Inversion](./training/text_inversion), an algorithm that teaches a model a specific visual concept and integrates it into the generated image.
|
||||
* [DreamBooth](./training/dreambooth), a technique for generating personalized images of a subject given several input images of the subject.
|
||||
* [Guide](./training/text2image) to finetuning a Stable Diffusion model on your own dataset.
|
||||
* [Guide](./training/lora) to using LoRA, a memory-efficient technique for finetuning really large models faster.
|
||||
@@ -36,6 +36,7 @@ Unless otherwise mentioned, these are techniques that work with existing models
|
||||
8. [DreamBooth](#dreambooth)
|
||||
9. [Textual Inversion](#textual-inversion)
|
||||
10. [ControlNet](#controlnet)
|
||||
11. [Prompt Weighting](#prompt-weighting)
|
||||
|
||||
## Instruct Pix2Pix
|
||||
|
||||
@@ -158,3 +159,9 @@ depth maps, and semantic segmentations.
|
||||
|
||||
See [here](../api/pipelines/stable_diffusion/controlnet) for more information on how to use it.
|
||||
|
||||
## Prompt Weighting
|
||||
|
||||
Prompt weighting is a simple technique that puts more attention weight on certain parts of the text
|
||||
input.
|
||||
|
||||
For a more in-detail explanation and examples, see [here](../using-diffusers/weighted_prompts).
|
||||
|
||||
@@ -12,7 +12,17 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Text-Guided Image-to-Image Generation
|
||||
|
||||
The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images.
|
||||
[[open-in-colab]]
|
||||
|
||||
The [`StableDiffusionImg2ImgPipeline`] lets you pass a text prompt and an initial image to condition the generation of new images. This tutorial shows how to use it for text-guided image-to-image generation with Stable Diffusion model.
|
||||
|
||||
Before you begin, make sure you have all the necessary libraries installed:
|
||||
|
||||
```bash
|
||||
!pip install diffusers transformers ftfy accelerate
|
||||
```
|
||||
|
||||
Get started by creating a [`StableDiffusionImg2ImgPipeline`] with a pretrained Stable Diffusion model.
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -21,25 +31,83 @@ from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
```
|
||||
|
||||
# load the pipeline
|
||||
Load the pipeline
|
||||
|
||||
```python
|
||||
device = "cuda"
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
|
||||
device
|
||||
)
|
||||
```
|
||||
|
||||
# let's download an initial image
|
||||
Download an initial image and preprocess it so we can pass it to the pipeline.
|
||||
|
||||
```python
|
||||
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
init_image.thumbnail((768, 768))
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
|
||||
|
||||
images[0].save("fantasy_landscape.png")
|
||||
init_image
|
||||
```
|
||||
You can also run this example on colab [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
|
||||

|
||||
|
||||
Define the prompt and run the pipeline.
|
||||
|
||||
```python
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
`strength` is a value between 0.0 and 1.0, that controls the amount of noise that is added to the input image. Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input.
|
||||
|
||||
</Tip>
|
||||
|
||||
Let's generate two images with same pipeline and seed, but with different values for `strength`
|
||||
|
||||
```python
|
||||
generator = torch.Generator(device=device).manual_seed(1024)
|
||||
image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
|
||||
```
|
||||
|
||||
```python
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
```python
|
||||
image = pipe(prompt=prompt, image=init_image, strength=0.5, guidance_scale=7.5, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
As you can see, when using a lower value for `strength`, the generated image is more closer to the original `image`
|
||||
|
||||
Now let's use a different scheduler - [LMSDiscreteScheduler](https://huggingface.co/docs/diffusers/api/schedulers#diffusers.LMSDiscreteScheduler)
|
||||
|
||||
```python
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
|
||||
lms = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.scheduler = lms
|
||||
```
|
||||
|
||||
```python
|
||||
generator = torch.Generator(device=device).manual_seed(1024)
|
||||
image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5, generator=generator).images[0]
|
||||
```
|
||||
|
||||
```python
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Weighting prompts
|
||||
|
||||
Text-guided diffusion models generate images based on a given text prompt. The text prompt
|
||||
can include multiple concepts that the model should generate and it's often desirable to weight
|
||||
certain parts of the prompt more or less.
|
||||
|
||||
Diffusion models work by conditioning the cross attention layers of the diffusion model with contextualized text embeddings (see the [Stable Diffusion Guide for more information](../stable-diffusion)).
|
||||
Thus a simple way to emphasize (or de-emphasize) certain parts of the prompt is by increasing or reducing the scale of the text embedding vector that corresponds to the relevant part of the prompt.
|
||||
This is called "prompt-weighting" and has been a highly demanded feature by the community (see issue [here](https://github.com/huggingface/diffusers/issues/2431)).
|
||||
|
||||
## How to do prompt-weighting in Diffusers
|
||||
|
||||
We believe the role of `diffusers` is to be a toolbox that provides essential features that enable other projects, such as [InvokeAI](https://github.com/invoke-ai/InvokeAI) or [diffuzers](https://github.com/abhishekkrthakur/diffuzers), to build powerful UIs. In order to support arbitrary methods to manipulate prompts, `diffusers` exposes a [`prompt_embeds`](https://huggingface.co/docs/diffusers/v0.14.0/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) function argument to many pipelines such as [`StableDiffusionPipeline`], allowing to directly pass the "prompt-weighted"/scaled text embeddings to the pipeline.
|
||||
|
||||
The [compel library](https://github.com/damian0815/compel) provides an easy way to emphasize or de-emphasize portions of the prompt for you. We strongly recommend it instead of preparing the embeddings yourself.
|
||||
|
||||
Let's look at a simple example. Imagine you want to generate an image of `"a red cat playing with a ball"` as
|
||||
follows:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
prompt = "a red cat playing with a ball"
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
|
||||
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
This gives you:
|
||||
|
||||

|
||||
|
||||
As you can see, there is no "ball" in the image. Let's emphasize this part!
|
||||
|
||||
For this we should install the `compel` library:
|
||||
|
||||
```
|
||||
pip install compel
|
||||
```
|
||||
|
||||
and then create a `Compel` object:
|
||||
|
||||
```py
|
||||
from compel import Compel
|
||||
|
||||
compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
|
||||
```
|
||||
|
||||
Now we emphasize the part "ball" with the `"++"` syntax:
|
||||
|
||||
```py
|
||||
prompt = "a red cat playing with a ball++"
|
||||
```
|
||||
|
||||
and instead of passing this to the pipeline directly, we have to process it using `compel_proc`:
|
||||
|
||||
```py
|
||||
prompt_embeds = compel_proc(prompt)
|
||||
```
|
||||
|
||||
Now we can pass `prompt_embeds` directly to the pipeline:
|
||||
|
||||
```py
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
|
||||
images = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
We now get the following image which has a "ball"!
|
||||
|
||||

|
||||
|
||||
Similarly, we de-emphasize parts of the sentence by using the `--` suffix for words, feel free to give it
|
||||
a try!
|
||||
|
||||
If your favorite pipeline does not have a `prompt_embeds` input, please make sure to open an issue, the
|
||||
diffusers team tries to be as responsive as possible.
|
||||
|
||||
Also, please check out the documentation of the [compel](https://github.com/damian0815/compel) library for
|
||||
more information.
|
||||
@@ -28,6 +28,7 @@ Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Di
|
||||
MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
|
||||
| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - |[Ray Wang](https://wrong.wang) |
|
||||
| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
|
||||
|
||||
|
||||
@@ -989,3 +990,47 @@ The resulting images in order:-
|
||||

|
||||

|
||||

|
||||
|
||||
### UnCLIP Image Interpolation Pipeline
|
||||
|
||||
This Diffusion Pipeline takes two images or an image_embeddings tensor of size 2 and interpolates between their embeddings using spherical interpolation ( slerp ). The input images/image_embeddings are converted to image embeddings by the pipeline's image_encoder and the interpolation is done on the resulting image_embeddings over the number of steps specified. Defaults to 5 steps.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from PIL import Image
|
||||
|
||||
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
|
||||
dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"kakaobrain/karlo-v1-alpha-image-variations",
|
||||
torch_dtype=dtype,
|
||||
custom_pipeline="unclip_image_interpolation"
|
||||
)
|
||||
pipe.to(device)
|
||||
|
||||
images = [Image.open('./starry_night.jpg'), Image.open('./flowers.jpg')]
|
||||
#For best results keep the prompts close in length to each other. Of course, feel free to try out with differing lengths.
|
||||
generator = torch.Generator(device=device).manual_seed(42)
|
||||
|
||||
output = pipe(image = images ,steps = 6, generator = generator)
|
||||
|
||||
for i,image in enumerate(output.images):
|
||||
image.save('starry_to_flowers_%s.jpg' % i)
|
||||
```
|
||||
The original images:-
|
||||
|
||||

|
||||

|
||||
|
||||
The resulting images in order:-
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,894 @@
|
||||
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
PIL_INTERPOLATION,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import ControlNetModel, UniPCMultistepScheduler
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
|
||||
|
||||
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
|
||||
|
||||
>>> pipe_controlnet = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
>>> pipe_controlnet.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)
|
||||
>>> pipe_controlnet.enable_xformers_memory_efficient_attention()
|
||||
>>> pipe_controlnet.enable_model_cpu_offload()
|
||||
|
||||
# using image with edges for our canny controlnet
|
||||
>>> control_image = load_image(
|
||||
"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png")
|
||||
|
||||
|
||||
>>> result_img = pipe_controlnet(controlnet_conditioning_image=control_image,
|
||||
image=input_image,
|
||||
prompt="an android robot, cyberpank, digitl art masterpiece",
|
||||
num_inference_steps=20).images[0]
|
||||
|
||||
>>> result_img.show()
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def prepare_image(image):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Batch single image
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
image = image.to(dtype=torch.float32)
|
||||
else:
|
||||
# preprocess image
|
||||
if isinstance(image, (PIL.Image.Image, np.ndarray)):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
|
||||
image = [np.array(i.convert("RGB"))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
|
||||
image = np.concatenate([i[None, :] for i in image], axis=0)
|
||||
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def prepare_controlnet_conditioning_image(
|
||||
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_image, torch.Tensor):
|
||||
if isinstance(controlnet_conditioning_image, PIL.Image.Image):
|
||||
controlnet_conditioning_image = [controlnet_conditioning_image]
|
||||
|
||||
if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
|
||||
controlnet_conditioning_image = [
|
||||
np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]))[None, :]
|
||||
for i in controlnet_conditioning_image
|
||||
]
|
||||
controlnet_conditioning_image = np.concatenate(controlnet_conditioning_image, axis=0)
|
||||
controlnet_conditioning_image = np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
|
||||
controlnet_conditioning_image = controlnet_conditioning_image.transpose(0, 3, 1, 2)
|
||||
controlnet_conditioning_image = torch.from_numpy(controlnet_conditioning_image)
|
||||
elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
|
||||
controlnet_conditioning_image = torch.cat(controlnet_conditioning_image, dim=0)
|
||||
|
||||
image_batch_size = controlnet_conditioning_image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
|
||||
|
||||
return controlnet_conditioning_image
|
||||
|
||||
|
||||
class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
|
||||
"""
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: ControlNetModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
||||
steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
Note that offloading happens on a submodule basis. Memory savings are higher than with
|
||||
`enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
# the safety checker can offload the vae again
|
||||
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
|
||||
|
||||
# control net hook has be manually offloaded as it alternates with unet
|
||||
cpu_offload_with_hook(self.controlnet, device)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
return image, has_nsfw_concept
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
controlnet_conditioning_image,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
strength=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
|
||||
controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
|
||||
controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
|
||||
controlnet_conditioning_image[0], PIL.Image.Image
|
||||
)
|
||||
controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
|
||||
controlnet_conditioning_image[0], torch.Tensor
|
||||
)
|
||||
|
||||
if (
|
||||
not controlnet_cond_image_is_pil
|
||||
and not controlnet_cond_image_is_tensor
|
||||
and not controlnet_cond_image_is_pil_list
|
||||
and not controlnet_cond_image_is_tensor_list
|
||||
):
|
||||
raise TypeError(
|
||||
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
|
||||
)
|
||||
|
||||
if controlnet_cond_image_is_pil:
|
||||
controlnet_cond_image_batch_size = 1
|
||||
elif controlnet_cond_image_is_tensor:
|
||||
controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
|
||||
elif controlnet_cond_image_is_pil_list:
|
||||
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
|
||||
elif controlnet_cond_image_is_tensor_list:
|
||||
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
prompt_batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
prompt_batch_size = len(prompt)
|
||||
elif prompt_embeds is not None:
|
||||
prompt_batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
|
||||
raise ValueError(
|
||||
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
|
||||
)
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
if image.ndim != 3 and image.ndim != 4:
|
||||
raise ValueError("`image` must have 3 or 4 dimensions")
|
||||
|
||||
# if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
|
||||
# raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
|
||||
|
||||
if image.ndim == 3:
|
||||
image_batch_size = 1
|
||||
image_channels, image_height, image_width = image.shape
|
||||
elif image.ndim == 4:
|
||||
image_batch_size, image_channels, image_height, image_width = image.shape
|
||||
|
||||
if image_channels != 3:
|
||||
raise ValueError("`image` must have 3 channels")
|
||||
|
||||
if image.min() < -1 or image.max() > 1:
|
||||
raise ValueError("`image` should be in range [-1, 1]")
|
||||
|
||||
if self.vae.config.latent_channels != self.unet.config.in_channels:
|
||||
raise ValueError(
|
||||
f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
|
||||
f" latent channels: {self.vae.config.latent_channels},"
|
||||
f" Please verify the config of `pipeline.unet` and the `pipeline.vae`"
|
||||
)
|
||||
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
shape = init_latents.shape
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
latents = init_latents
|
||||
|
||||
return latents
|
||||
|
||||
def _default_height_width(self, height, width, image):
|
||||
if isinstance(image, list):
|
||||
image = image[0]
|
||||
|
||||
if height is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height = image.shape[3]
|
||||
|
||||
height = (height // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
if width is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, torch.Tensor):
|
||||
width = image.shape[2]
|
||||
|
||||
width = (width // 8) * 8 # round down to nearest multiple of 8
|
||||
|
||||
return height, width
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: Union[torch.Tensor, PIL.Image.Image] = None,
|
||||
controlnet_conditioning_image: Union[
|
||||
torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
|
||||
] = None,
|
||||
strength: float = 0.8,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.Tensor` or `PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
|
||||
be masked out with `mask_image` and repainted according to `prompt`.
|
||||
controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
|
||||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
||||
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
|
||||
also be accepted as an image. The control image is automatically resized to fit the output image.
|
||||
strength (`float`, *optional*):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
be maximum and the denoising process will run for the full number of iterations specified in
|
||||
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height, width = self._default_height_width(height, width, controlnet_conditioning_image)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
# mask_image,
|
||||
controlnet_conditioning_image,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
strength,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 4. Prepare mask, image, and controlnet_conditioning_image
|
||||
image = prepare_image(image)
|
||||
|
||||
# mask_image = prepare_mask_image(mask_image)
|
||||
|
||||
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
|
||||
controlnet_conditioning_image,
|
||||
width,
|
||||
height,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
self.controlnet.dtype,
|
||||
)
|
||||
|
||||
# masked_image = image * (mask_image < 0.5)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
latents = self.prepare_latents(
|
||||
image,
|
||||
latent_timestep,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
|
||||
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
controlnet_cond=controlnet_conditioning_image,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
down_block_res_samples = [
|
||||
down_block_res_sample * controlnet_conditioning_scale
|
||||
for down_block_res_sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample *= controlnet_conditioning_scale
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# If we do sequential model offloading, let's offload unet and controlnet
|
||||
# manually for max memory savings
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
elif output_type == "pil":
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 9. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
|
||||
# 10. Convert to PIL
|
||||
image = self.numpy_to_pil(image)
|
||||
else:
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 9. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,493 @@
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from transformers import (
|
||||
CLIPFeatureExtractor,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
DiffusionPipeline,
|
||||
ImagePipelineOutput,
|
||||
UnCLIPScheduler,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
)
|
||||
from diffusers.pipelines.unclip import UnCLIPTextProjModel
|
||||
from diffusers.utils import is_accelerate_available, logging, randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def slerp(val, low, high):
|
||||
"""
|
||||
Find the interpolation point between the 'low' and 'high' values for the given 'val'. See https://en.wikipedia.org/wiki/Slerp for more details on the topic.
|
||||
"""
|
||||
low_norm = low / torch.norm(low)
|
||||
high_norm = high / torch.norm(high)
|
||||
omega = torch.acos((low_norm * high_norm))
|
||||
so = torch.sin(omega)
|
||||
res = (torch.sin((1.0 - val) * omega) / so) * low + (torch.sin(val * omega) / so) * high
|
||||
return res
|
||||
|
||||
|
||||
class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline to generate variations from an input image using unCLIP
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
Frozen text-encoder.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
|
||||
image_encoder ([`CLIPVisionModelWithProjection`]):
|
||||
Frozen CLIP image-encoder. unCLIP Image Variation uses the vision portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
|
||||
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
text_proj ([`UnCLIPTextProjModel`]):
|
||||
Utility class to prepare and combine the embeddings before they are passed to the decoder.
|
||||
decoder ([`UNet2DConditionModel`]):
|
||||
The decoder to invert the image embedding into an image.
|
||||
super_res_first ([`UNet2DModel`]):
|
||||
Super resolution unet. Used in all but the last step of the super resolution diffusion process.
|
||||
super_res_last ([`UNet2DModel`]):
|
||||
Super resolution unet. Used in the last step of the super resolution diffusion process.
|
||||
decoder_scheduler ([`UnCLIPScheduler`]):
|
||||
Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
|
||||
super_res_scheduler ([`UnCLIPScheduler`]):
|
||||
Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
|
||||
|
||||
"""
|
||||
|
||||
decoder: UNet2DConditionModel
|
||||
text_proj: UnCLIPTextProjModel
|
||||
text_encoder: CLIPTextModelWithProjection
|
||||
tokenizer: CLIPTokenizer
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
image_encoder: CLIPVisionModelWithProjection
|
||||
super_res_first: UNet2DModel
|
||||
super_res_last: UNet2DModel
|
||||
|
||||
decoder_scheduler: UnCLIPScheduler
|
||||
super_res_scheduler: UnCLIPScheduler
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.__init__
|
||||
def __init__(
|
||||
self,
|
||||
decoder: UNet2DConditionModel,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_proj: UnCLIPTextProjModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
image_encoder: CLIPVisionModelWithProjection,
|
||||
super_res_first: UNet2DModel,
|
||||
super_res_last: UNet2DModel,
|
||||
decoder_scheduler: UnCLIPScheduler,
|
||||
super_res_scheduler: UnCLIPScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
decoder=decoder,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_proj=text_proj,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
super_res_first=super_res_first,
|
||||
super_res_last=super_res_last,
|
||||
decoder_scheduler=decoder_scheduler,
|
||||
super_res_scheduler=super_res_scheduler,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
text_mask = text_inputs.attention_mask.bool().to(device)
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(device))
|
||||
|
||||
prompt_embeds = text_encoder_output.text_embeds
|
||||
text_encoder_hidden_states = text_encoder_output.last_hidden_state
|
||||
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens = [""] * batch_size
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_text_mask = uncond_input.attention_mask.bool().to(device)
|
||||
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
|
||||
uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
seq_len = uncond_text_encoder_hidden_states.shape[1]
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
# done duplicates
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
|
||||
|
||||
text_mask = torch.cat([uncond_text_mask, text_mask])
|
||||
|
||||
return prompt_embeds, text_encoder_hidden_states, text_mask
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline._encode_image
|
||||
def _encode_image(self, image, device, num_images_per_prompt, image_embeddings: Optional[torch.Tensor] = None):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if image_embeddings is None:
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeddings = self.image_encoder(image).image_embeds
|
||||
|
||||
image_embeddings = image_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
return image_embeddings
|
||||
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip_image_variation.UnCLIPImageVariationPipeline.enable_sequential_cpu_offload
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
|
||||
models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
|
||||
when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
models = [
|
||||
self.decoder,
|
||||
self.text_proj,
|
||||
self.text_encoder,
|
||||
self.super_res_first,
|
||||
self.super_res_last,
|
||||
]
|
||||
for cpu_offloaded_model in models:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.decoder.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
image: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
|
||||
steps: int = 5,
|
||||
decoder_num_inference_steps: int = 25,
|
||||
super_res_num_inference_steps: int = 7,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
image_embeddings: Optional[torch.Tensor] = None,
|
||||
decoder_latents: Optional[torch.FloatTensor] = None,
|
||||
super_res_latents: Optional[torch.FloatTensor] = None,
|
||||
decoder_guidance_scale: float = 8.0,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`List[PIL.Image.Image]` or `torch.FloatTensor`):
|
||||
The images to use for the image interpolation. Only accepts a list of two PIL Images or If you provide a tensor, it needs to comply with the
|
||||
configuration of
|
||||
[this](https://huggingface.co/fusing/karlo-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
|
||||
`CLIPFeatureExtractor` while still having a shape of two in the 0th dimension. Can be left to `None` only when `image_embeddings` are passed.
|
||||
steps (`int`, *optional*, defaults to 5):
|
||||
The number of interpolation images to generate.
|
||||
decoder_num_inference_steps (`int`, *optional*, defaults to 25):
|
||||
The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
|
||||
image at the expense of slower inference.
|
||||
super_res_num_inference_steps (`int`, *optional*, defaults to 7):
|
||||
The number of denoising steps for super resolution. More denoising steps usually lead to a higher
|
||||
quality image at the expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
image_embeddings (`torch.Tensor`, *optional*):
|
||||
Pre-defined image embeddings that can be derived from the image encoder. Pre-defined image embeddings
|
||||
can be passed for tasks like image interpolations. `image` can the be left to `None`.
|
||||
decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
|
||||
Pre-generated noisy latents to be used as inputs for the decoder.
|
||||
super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
|
||||
Pre-generated noisy latents to be used as inputs for the decoder.
|
||||
decoder_guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
batch_size = steps
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if isinstance(image, List):
|
||||
if len(image) != 2:
|
||||
raise AssertionError(
|
||||
f"Expected 'image' List to be of size 2, but passed 'image' length is {len(image)}"
|
||||
)
|
||||
elif not (isinstance(image[0], PIL.Image.Image) and isinstance(image[0], PIL.Image.Image)):
|
||||
raise AssertionError(
|
||||
f"Expected 'image' List to contain PIL.Image.Image, but passed 'image' contents are {type(image[0])} and {type(image[1])}"
|
||||
)
|
||||
elif isinstance(image, torch.FloatTensor):
|
||||
if image.shape[0] != 2:
|
||||
raise AssertionError(
|
||||
f"Expected 'image' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image' size is {image.shape[0]}"
|
||||
)
|
||||
elif isinstance(image_embeddings, torch.Tensor):
|
||||
if image_embeddings.shape[0] != 2:
|
||||
raise AssertionError(
|
||||
f"Expected 'image_embeddings' to be torch.FloatTensor of shape 2 in 0th dimension, but passed 'image_embeddings' shape is {image_embeddings.shape[0]}"
|
||||
)
|
||||
else:
|
||||
raise AssertionError(
|
||||
f"Expected 'image' or 'image_embeddings' to be not None with types List[PIL.Image] or Torch.FloatTensor respectively. Received {type(image)} and {type(image_embeddings)} repsectively"
|
||||
)
|
||||
|
||||
original_image_embeddings = self._encode_image(
|
||||
image=image, device=device, num_images_per_prompt=1, image_embeddings=image_embeddings
|
||||
)
|
||||
|
||||
image_embeddings = []
|
||||
|
||||
for interp_step in torch.linspace(0, 1, steps):
|
||||
temp_image_embeddings = slerp(
|
||||
interp_step, original_image_embeddings[0], original_image_embeddings[1]
|
||||
).unsqueeze(0)
|
||||
image_embeddings.append(temp_image_embeddings)
|
||||
|
||||
image_embeddings = torch.cat(image_embeddings).to(device)
|
||||
|
||||
do_classifier_free_guidance = decoder_guidance_scale > 1.0
|
||||
|
||||
prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
|
||||
prompt=["" for i in range(steps)],
|
||||
device=device,
|
||||
num_images_per_prompt=1,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
|
||||
image_embeddings=image_embeddings,
|
||||
prompt_embeds=prompt_embeds,
|
||||
text_encoder_hidden_states=text_encoder_hidden_states,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
if device.type == "mps":
|
||||
# HACK: MPS: There is a panic when padding bool tensors,
|
||||
# so cast to int tensor for the pad and back to bool afterwards
|
||||
text_mask = text_mask.type(torch.int)
|
||||
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
|
||||
decoder_text_mask = decoder_text_mask.type(torch.bool)
|
||||
else:
|
||||
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True)
|
||||
|
||||
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
|
||||
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
text_encoder_hidden_states.dtype,
|
||||
device,
|
||||
generator,
|
||||
decoder_latents,
|
||||
self.decoder_scheduler,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
|
||||
|
||||
noise_pred = self.decoder(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=text_encoder_hidden_states,
|
||||
class_labels=additive_clip_time_embeddings,
|
||||
attention_mask=decoder_text_mask,
|
||||
).sample
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
|
||||
noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
|
||||
noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
if i + 1 == decoder_timesteps_tensor.shape[0]:
|
||||
prev_timestep = None
|
||||
else:
|
||||
prev_timestep = decoder_timesteps_tensor[i + 1]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
decoder_latents = self.decoder_scheduler.step(
|
||||
noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
|
||||
).prev_sample
|
||||
|
||||
decoder_latents = decoder_latents.clamp(-1, 1)
|
||||
|
||||
image_small = decoder_latents
|
||||
|
||||
# done decoder
|
||||
|
||||
# super res
|
||||
|
||||
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
|
||||
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
|
||||
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
image_small.dtype,
|
||||
device,
|
||||
generator,
|
||||
super_res_latents,
|
||||
self.super_res_scheduler,
|
||||
)
|
||||
|
||||
if device.type == "mps":
|
||||
# MPS does not support many interpolations
|
||||
image_upscaled = F.interpolate(image_small, size=[height, width])
|
||||
else:
|
||||
interpolate_antialias = {}
|
||||
if "antialias" in inspect.signature(F.interpolate).parameters:
|
||||
interpolate_antialias["antialias"] = True
|
||||
|
||||
image_upscaled = F.interpolate(
|
||||
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
|
||||
# no classifier free guidance
|
||||
|
||||
if i == super_res_timesteps_tensor.shape[0] - 1:
|
||||
unet = self.super_res_last
|
||||
else:
|
||||
unet = self.super_res_first
|
||||
|
||||
latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
|
||||
|
||||
noise_pred = unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
).sample
|
||||
|
||||
if i + 1 == super_res_timesteps_tensor.shape[0]:
|
||||
prev_timestep = None
|
||||
else:
|
||||
prev_timestep = super_res_timesteps_tensor[i + 1]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
super_res_latents = self.super_res_scheduler.step(
|
||||
noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
|
||||
).prev_sample
|
||||
|
||||
image = super_res_latents
|
||||
# done super res
|
||||
|
||||
# post processing
|
||||
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -24,6 +24,7 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import accelerate
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
@@ -40,18 +41,71 @@ from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
check_min_version("0.15.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline (note: unet and vae are loaded again in float32)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
tokenizer=tokenizer,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
@@ -306,6 +360,28 @@ def parse_args(input_args=None):
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A prompt that is used during validation to verify that the model is learning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_validation_images",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of images that should be generated during validation with `validation_prompt`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_steps",
|
||||
type=int,
|
||||
default=100,
|
||||
help=(
|
||||
"Run validation every X steps. Validation consists of running the prompt"
|
||||
" `args.validation_prompt` multiple times: `args.num_validation_images`"
|
||||
" and logging the images."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
@@ -508,6 +584,10 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
||||
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
||||
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
@@ -920,6 +1000,8 @@ def main(args):
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
||||
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
check_min_version("0.15.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -54,7 +54,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
check_min_version("0.15.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
## Multi Token Textual Inversion
|
||||
The author of this project is [Isamu Isozaki](https://github.com/isamu-isozaki) - please make sure to tag the author for issue and PRs as well as @patrickvonplaten.
|
||||
|
||||
We add multi token support to textual inversion. I added
|
||||
1. num_vec_per_token for the number of used to reference that token
|
||||
2. progressive_tokens for progressively training the token from 1 token to 2 token etc
|
||||
3. progressive_tokens_max_steps for the max number of steps until we start full training
|
||||
4. vector_shuffle to shuffle vectors
|
||||
|
||||
Feel free to add these options to your training! In practice num_vec_per_token around 10+vector shuffle works great!
|
||||
|
||||
## Textual Inversion fine-tuning example
|
||||
|
||||
[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
|
||||
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
## Running on Colab
|
||||
|
||||
Colab for training
|
||||
[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
||||
|
||||
Colab for inference
|
||||
[](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb)
|
||||
|
||||
## Running locally with PyTorch
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
|
||||
### Cat toy example
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
Run the following command to authenticate your token
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
If you have already cloned the repo, then you won't need to go through these steps.
|
||||
|
||||
<br>
|
||||
|
||||
Now let's get our dataset.Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
|
||||
|
||||
And launch the training using
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export DATA_DIR="path-to-dir-containing-images"
|
||||
|
||||
accelerate launch textual_inversion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_data_dir=$DATA_DIR \
|
||||
--learnable_property="object" \
|
||||
--placeholder_token="<cat-toy>" --initializer_token="toy" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--max_train_steps=3000 \
|
||||
--learning_rate=5.0e-04 --scale_lr \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="textual_inversion_cat"
|
||||
```
|
||||
|
||||
A full training run takes ~1 hour on one V100 GPU.
|
||||
|
||||
### Inference
|
||||
|
||||
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
model_id = "path-to-your-trained-model"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
prompt = "A <cat-toy> backpack"
|
||||
|
||||
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
|
||||
|
||||
image.save("cat-backpack.png")
|
||||
```
|
||||
|
||||
|
||||
## Training with Flax/JAX
|
||||
|
||||
For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
```bash
|
||||
pip install -U -r requirements_flax.txt
|
||||
```
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export DATA_DIR="path-to-dir-containing-images"
|
||||
|
||||
python textual_inversion_flax.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_data_dir=$DATA_DIR \
|
||||
--learnable_property="object" \
|
||||
--placeholder_token="<cat-toy>" --initializer_token="toy" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--max_train_steps=3000 \
|
||||
--learning_rate=5.0e-04 --scale_lr \
|
||||
--output_dir="textual_inversion_cat"
|
||||
```
|
||||
It should be at least 70% faster than the PyTorch script with the same configuration.
|
||||
|
||||
### Training with xformers:
|
||||
You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and padding the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
|
||||
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
The main idea for this code is to provide a way for users to not need to bother with the hassle of multiple tokens for a concept by typing
|
||||
a photo of <concept>_0 <concept>_1 ... and so on
|
||||
and instead just do
|
||||
a photo of <concept>
|
||||
which gets translated to the above. This needs to work for both inference and training.
|
||||
For inference,
|
||||
the tokenizer encodes the text. So, we would want logic for our tokenizer to replace the placeholder token with
|
||||
it's underlying vectors
|
||||
For training,
|
||||
we would want to abstract away some logic like
|
||||
1. Adding tokens
|
||||
2. Updating gradient mask
|
||||
3. Saving embeddings
|
||||
to our Util class here.
|
||||
so
|
||||
TODO:
|
||||
1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x
|
||||
2. have mechanism for adding tokens x
|
||||
3. have mech for saving emebeddings x
|
||||
4. get mask to update x
|
||||
5. Loading tokens from embedding x
|
||||
6. Integrate to training x
|
||||
7. Test
|
||||
"""
|
||||
import copy
|
||||
import random
|
||||
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
|
||||
class MultiTokenCLIPTokenizer(CLIPTokenizer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.token_map = {}
|
||||
|
||||
def try_adding_tokens(self, placeholder_token, *args, **kwargs):
|
||||
num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):
|
||||
output = []
|
||||
if num_vec_per_token == 1:
|
||||
self.try_adding_tokens(placeholder_token, *args, **kwargs)
|
||||
output.append(placeholder_token)
|
||||
else:
|
||||
output = []
|
||||
for i in range(num_vec_per_token):
|
||||
ith_token = placeholder_token + f"_{i}"
|
||||
self.try_adding_tokens(ith_token, *args, **kwargs)
|
||||
output.append(ith_token)
|
||||
# handle cases where there is a new placeholder token that contains the current placeholder token but is larger
|
||||
for token in self.token_map:
|
||||
if token in placeholder_token:
|
||||
raise ValueError(
|
||||
f"The tokenizer already has placeholder token {token} that can get confused with"
|
||||
f" {placeholder_token}keep placeholder tokens independent"
|
||||
)
|
||||
self.token_map[placeholder_token] = output
|
||||
|
||||
def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):
|
||||
"""
|
||||
Here, we replace the placeholder tokens in text recorded in token_map so that the text_encoder
|
||||
can encode them
|
||||
vector_shuffle was inspired by https://github.com/rinongal/textual_inversion/pull/119
|
||||
where shuffling tokens were found to force the model to learn the concepts more descriptively.
|
||||
"""
|
||||
if isinstance(text, list):
|
||||
output = []
|
||||
for i in range(len(text)):
|
||||
output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle))
|
||||
return output
|
||||
for placeholder_token in self.token_map:
|
||||
if placeholder_token in text:
|
||||
tokens = self.token_map[placeholder_token]
|
||||
tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
|
||||
if vector_shuffle:
|
||||
tokens = copy.copy(tokens)
|
||||
random.shuffle(tokens)
|
||||
text = text.replace(placeholder_token, " ".join(tokens))
|
||||
return text
|
||||
|
||||
def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
|
||||
return super().__call__(
|
||||
self.replace_placeholder_tokens_in_text(
|
||||
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
|
||||
),
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, **kwargs):
|
||||
return super().encode(
|
||||
self.replace_placeholder_tokens_in_text(
|
||||
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
|
||||
),
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -0,0 +1,6 @@
|
||||
accelerate
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
@@ -0,0 +1,8 @@
|
||||
transformers>=4.25.1
|
||||
flax
|
||||
optax
|
||||
torch
|
||||
torchvision
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
@@ -0,0 +1,941 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||
from multi_token_clip import MultiTokenCLIPTokenizer
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):
|
||||
"""
|
||||
Add tokens to the tokenizer and set the initial value of token embeddings
|
||||
"""
|
||||
tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
|
||||
if initializer_token:
|
||||
token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
|
||||
for i, placeholder_token_id in enumerate(placeholder_token_ids):
|
||||
token_embeds[placeholder_token_id] = token_embeds[token_ids[i * len(token_ids) // num_vec_per_token]]
|
||||
else:
|
||||
for i, placeholder_token_id in enumerate(placeholder_token_ids):
|
||||
token_embeds[placeholder_token_id] = torch.randn_like(token_embeds[placeholder_token_id])
|
||||
return placeholder_token
|
||||
|
||||
|
||||
def save_progress(tokenizer, text_encoder, accelerator, save_path):
|
||||
for placeholder_token in tokenizer.token_map:
|
||||
placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
|
||||
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_ids]
|
||||
if len(placeholder_token_ids) == 1:
|
||||
learned_embeds = learned_embeds[None]
|
||||
learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
|
||||
torch.save(learned_embeds_dict, save_path)
|
||||
|
||||
|
||||
def load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict):
|
||||
for placeholder_token in learned_embeds_dict:
|
||||
placeholder_embeds = learned_embeds_dict[placeholder_token]
|
||||
num_vec_per_token = placeholder_embeds.shape[0]
|
||||
placeholder_embeds = placeholder_embeds.to(dtype=text_encoder.dtype)
|
||||
add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=num_vec_per_token)
|
||||
placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
for i, placeholder_token_id in enumerate(placeholder_token_ids):
|
||||
token_embeds[placeholder_token_id] = placeholder_embeds[i]
|
||||
|
||||
|
||||
def load_multitoken_tokenizer_from_automatic(tokenizer, text_encoder, automatic_dict, placeholder_token):
|
||||
"""
|
||||
Automatic1111's tokens have format
|
||||
{'string_to_token': {'*': 265}, 'string_to_param': {'*': tensor([[ 0.0833, 0.0030, 0.0057, ..., -0.0264, -0.0616, -0.0529],
|
||||
[ 0.0058, -0.0190, -0.0584, ..., -0.0025, -0.0945, -0.0490],
|
||||
[ 0.0916, 0.0025, 0.0365, ..., -0.0685, -0.0124, 0.0728],
|
||||
[ 0.0812, -0.0199, -0.0100, ..., -0.0581, -0.0780, 0.0254]],
|
||||
requires_grad=True)}, 'name': 'FloralMarble-400', 'step': 399, 'sd_checkpoint': '4bdfc29c', 'sd_checkpoint_name': 'SD2.1-768'}
|
||||
"""
|
||||
learned_embeds_dict = {}
|
||||
learned_embeds_dict[placeholder_token] = automatic_dict["string_to_param"]["*"]
|
||||
load_multitoken_tokenizer(tokenizer, text_encoder, learned_embeds_dict)
|
||||
|
||||
|
||||
def get_mask(tokenizer, accelerator):
|
||||
# Get the mask of the weights that won't change
|
||||
mask = torch.ones(len(tokenizer)).to(accelerator.device, dtype=torch.bool)
|
||||
for placeholder_token in tokenizer.token_map:
|
||||
placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
|
||||
for i in range(len(placeholder_token_ids)):
|
||||
mask = mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i]).to(accelerator.device)
|
||||
return mask
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--progressive_tokens_max_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The number of steps until all tokens will be used.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progressive_tokens",
|
||||
action="store_true",
|
||||
help="Progressively train the tokens. For example, first train for 1 token, then 2 tokens and so on.",
|
||||
)
|
||||
parser.add_argument("--vector_shuffle", action="store_true", help="Shuffling tokens durint training")
|
||||
parser.add_argument(
|
||||
"--num_vec_per_token",
|
||||
type=int,
|
||||
default=1,
|
||||
help=(
|
||||
"The number of vectors used to represent the placeholder token. The higher the number, the better the"
|
||||
" result at the cost of editability. This can be fixed by prompt editing."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Save learned_embeds.bin every X updates steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--only_save_embeds",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Save only the embeddings for the new concept.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--placeholder_token",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A token to use as a placeholder for the concept.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
|
||||
)
|
||||
parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
|
||||
parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A prompt that is used during validation to verify that the model is learning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_validation_images",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of images that should be generated during validation with `validation_prompt`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_epochs",
|
||||
type=int,
|
||||
default=50,
|
||||
help=(
|
||||
"Run validation every X epochs. Validation consists of running the prompt"
|
||||
" `args.validation_prompt` multiple times: `args.num_validation_images`"
|
||||
" and logging the images."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
||||
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
||||
" for more docs"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
if args.train_data_dir is None:
|
||||
raise ValueError("You must specify a train data directory.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
class TextualInversionDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
learnable_property="object", # [object, style]
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5,
|
||||
set="train",
|
||||
placeholder_token="*",
|
||||
center_crop=False,
|
||||
vector_shuffle=False,
|
||||
progressive_tokens=False,
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.learnable_property = learnable_property
|
||||
self.size = size
|
||||
self.placeholder_token = placeholder_token
|
||||
self.center_crop = center_crop
|
||||
self.flip_p = flip_p
|
||||
self.vector_shuffle = vector_shuffle
|
||||
self.progressive_tokens = progressive_tokens
|
||||
self.prop_tokens_to_load = 0
|
||||
|
||||
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL_INTERPOLATION["linear"],
|
||||
"bilinear": PIL_INTERPOLATION["bilinear"],
|
||||
"bicubic": PIL_INTERPOLATION["bicubic"],
|
||||
"lanczos": PIL_INTERPOLATION["lanczos"],
|
||||
}[interpolation]
|
||||
|
||||
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
placeholder_string = self.placeholder_token
|
||||
text = random.choice(self.templates).format(placeholder_string)
|
||||
|
||||
example["input_ids"] = self.tokenizer.encode(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
vector_shuffle=self.vector_shuffle,
|
||||
prop_tokens_to_load=self.prop_tokens_to_load if self.progressive_tokens else 1.0,
|
||||
)[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
|
||||
image = self.flip_transform(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
logging_dir=logging_dir,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
import wandb
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
create_repo(repo_name, exist_ok=True, token=args.hub_token)
|
||||
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Load tokenizer
|
||||
if args.tokenizer_name:
|
||||
tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.tokenizer_name)
|
||||
elif args.pretrained_model_name_or_path:
|
||||
tokenizer = MultiTokenCLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
)
|
||||
if is_xformers_available():
|
||||
try:
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not enable memory efficient attention. Make sure xformers is installed"
|
||||
f" correctly and a GPU is available: {e}"
|
||||
)
|
||||
add_tokens(tokenizer, text_encoder, args.placeholder_token, args.num_vec_per_token, args.initializer_token)
|
||||
|
||||
# Freeze vae and unet
|
||||
vae.requires_grad_(False)
|
||||
unet.requires_grad_(False)
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder.text_model.encoder.requires_grad_(False)
|
||||
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
# Keep unet in train mode if we are using gradient checkpointing to save memory.
|
||||
# The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
|
||||
unet.train()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# Dataset and DataLoaders creation:
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=args.train_data_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=args.resolution,
|
||||
placeholder_token=args.placeholder_token,
|
||||
repeats=args.repeats,
|
||||
learnable_property=args.learnable_property,
|
||||
center_crop=args.center_crop,
|
||||
set="train",
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# For mixed precision training we cast the unet and vae weights to half-precision
|
||||
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move vae and unet to device and cast to weight_dtype
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("textual_inversion", config=vars(args))
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
# keep original embeddings as reference
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
if args.progressive_tokens:
|
||||
train_dataset.prop_tokens_to_load = float(global_step) / args.progressive_tokens_max_steps
|
||||
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
index_no_updates = get_mask(tokenizer, accelerator)
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
if global_step % args.save_steps == 0:
|
||||
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
|
||||
save_progress(tokenizer, text_encoder, accelerator, save_path)
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline (note: unet and vae are loaded again in float32)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = (
|
||||
None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
if args.push_to_hub and args.only_save_embeds:
|
||||
logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
|
||||
save_full_model = True
|
||||
else:
|
||||
save_full_model = not args.only_save_embeds
|
||||
if save_full_model:
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
# Save the newly trained embeddings
|
||||
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
|
||||
save_progress(tokenizer, text_encoder, accelerator, save_path)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,668 @@
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import PIL
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
from flax import jax_utils
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import shard
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, whoami
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
|
||||
|
||||
from diffusers import (
|
||||
FlaxAutoencoderKL,
|
||||
FlaxDDPMScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
FlaxStableDiffusionPipeline,
|
||||
FlaxUNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
|
||||
from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.Resampling.BILINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
"nearest": PIL.Image.Resampling.NEAREST,
|
||||
}
|
||||
else:
|
||||
PIL_INTERPOLATION = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.BILINEAR,
|
||||
"bicubic": PIL.Image.BICUBIC,
|
||||
"lanczos": PIL.Image.LANCZOS,
|
||||
"nearest": PIL.Image.NEAREST,
|
||||
}
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--placeholder_token",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A token to use as a placeholder for the concept.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
|
||||
)
|
||||
parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
|
||||
parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument(
|
||||
"--use_auth_token",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
|
||||
" private models)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
if args.train_data_dir is None:
|
||||
raise ValueError("You must specify a train data directory.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
class TextualInversionDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
learnable_property="object", # [object, style]
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
flip_p=0.5,
|
||||
set="train",
|
||||
placeholder_token="*",
|
||||
center_crop=False,
|
||||
):
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.learnable_property = learnable_property
|
||||
self.size = size
|
||||
self.placeholder_token = placeholder_token
|
||||
self.center_crop = center_crop
|
||||
self.flip_p = flip_p
|
||||
|
||||
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL_INTERPOLATION["linear"],
|
||||
"bilinear": PIL_INTERPOLATION["bilinear"],
|
||||
"bicubic": PIL_INTERPOLATION["bicubic"],
|
||||
"lanczos": PIL_INTERPOLATION["lanczos"],
|
||||
}[interpolation]
|
||||
|
||||
self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
|
||||
self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, i):
|
||||
example = {}
|
||||
image = Image.open(self.image_paths[i % self.num_images])
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
placeholder_string = self.placeholder_token
|
||||
text = random.choice(self.templates).format(placeholder_string)
|
||||
|
||||
example["input_ids"] = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0]
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
|
||||
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
|
||||
image = self.flip_transform(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
return example
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def resize_token_embeddings(model, new_num_tokens, initializer_token_id, placeholder_token_id, rng):
|
||||
if model.config.vocab_size == new_num_tokens or new_num_tokens is None:
|
||||
return
|
||||
model.config.vocab_size = new_num_tokens
|
||||
|
||||
params = model.params
|
||||
old_embeddings = params["text_model"]["embeddings"]["token_embedding"]["embedding"]
|
||||
old_num_tokens, emb_dim = old_embeddings.shape
|
||||
|
||||
initializer = jax.nn.initializers.normal()
|
||||
|
||||
new_embeddings = initializer(rng, (new_num_tokens, emb_dim))
|
||||
new_embeddings = new_embeddings.at[:old_num_tokens].set(old_embeddings)
|
||||
new_embeddings = new_embeddings.at[placeholder_token_id].set(new_embeddings[initializer_token_id])
|
||||
params["text_model"]["embeddings"]["token_embedding"]["embedding"] = new_embeddings
|
||||
|
||||
model.params = params
|
||||
return model
|
||||
|
||||
|
||||
def get_params_to_save(params):
|
||||
return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
if jax.process_index() == 0:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
create_repo(repo_name, exist_ok=True, token=args.hub_token)
|
||||
repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
# Setup logging, we only want one process per machine to log things on the screen.
|
||||
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
||||
if jax.process_index() == 0:
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
|
||||
# Load the tokenizer and add the placeholder token as a additional special token
|
||||
if args.tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
||||
elif args.pretrained_model_name_or_path:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
|
||||
if num_added_tokens == 0:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
raise ValueError("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
|
||||
# Create sampling rng
|
||||
rng = jax.random.PRNGKey(args.seed)
|
||||
rng, _ = jax.random.split(rng)
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder = resize_token_embeddings(
|
||||
text_encoder, len(tokenizer), initializer_token_id, placeholder_token_id, rng
|
||||
)
|
||||
original_token_embeds = text_encoder.params["text_model"]["embeddings"]["token_embedding"]["embedding"]
|
||||
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=args.train_data_dir,
|
||||
tokenizer=tokenizer,
|
||||
size=args.resolution,
|
||||
placeholder_token=args.placeholder_token,
|
||||
repeats=args.repeats,
|
||||
learnable_property=args.learnable_property,
|
||||
center_crop=args.center_crop,
|
||||
set="train",
|
||||
)
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
input_ids = torch.stack([example["input_ids"] for example in examples])
|
||||
|
||||
batch = {"pixel_values": pixel_values, "input_ids": input_ids}
|
||||
batch = {k: v.numpy() for k, v in batch.items()}
|
||||
|
||||
return batch
|
||||
|
||||
total_train_batch_size = args.train_batch_size * jax.local_device_count()
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=total_train_batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# Optimization
|
||||
if args.scale_lr:
|
||||
args.learning_rate = args.learning_rate * total_train_batch_size
|
||||
|
||||
constant_scheduler = optax.constant_schedule(args.learning_rate)
|
||||
|
||||
optimizer = optax.adamw(
|
||||
learning_rate=constant_scheduler,
|
||||
b1=args.adam_beta1,
|
||||
b2=args.adam_beta2,
|
||||
eps=args.adam_epsilon,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
)
|
||||
|
||||
def create_mask(params, label_fn):
|
||||
def _map(params, mask, label_fn):
|
||||
for k in params:
|
||||
if label_fn(k):
|
||||
mask[k] = "token_embedding"
|
||||
else:
|
||||
if isinstance(params[k], dict):
|
||||
mask[k] = {}
|
||||
_map(params[k], mask[k], label_fn)
|
||||
else:
|
||||
mask[k] = "zero"
|
||||
|
||||
mask = {}
|
||||
_map(params, mask, label_fn)
|
||||
return mask
|
||||
|
||||
def zero_grads():
|
||||
# from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
|
||||
def init_fn(_):
|
||||
return ()
|
||||
|
||||
def update_fn(updates, state, params=None):
|
||||
return jax.tree_util.tree_map(jnp.zeros_like, updates), ()
|
||||
|
||||
return optax.GradientTransformation(init_fn, update_fn)
|
||||
|
||||
# Zero out gradients of layers other than the token embedding layer
|
||||
tx = optax.multi_transform(
|
||||
{"token_embedding": optimizer, "zero": zero_grads()},
|
||||
create_mask(text_encoder.params, lambda s: s == "token_embedding"),
|
||||
)
|
||||
|
||||
state = train_state.TrainState.create(apply_fn=text_encoder.__call__, params=text_encoder.params, tx=tx)
|
||||
|
||||
noise_scheduler = FlaxDDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
|
||||
)
|
||||
noise_scheduler_state = noise_scheduler.create_state()
|
||||
|
||||
# Initialize our training
|
||||
train_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
# Define gradient train step fn
|
||||
def train_step(state, vae_params, unet_params, batch, train_rng):
|
||||
dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3)
|
||||
|
||||
def compute_loss(params):
|
||||
vae_outputs = vae.apply(
|
||||
{"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode
|
||||
)
|
||||
latents = vae_outputs.latent_dist.sample(sample_rng)
|
||||
# (NHWC) -> (NCHW)
|
||||
latents = jnp.transpose(latents, (0, 3, 1, 2))
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
noise_rng, timestep_rng = jax.random.split(sample_rng)
|
||||
noise = jax.random.normal(noise_rng, latents.shape)
|
||||
bsz = latents.shape[0]
|
||||
timesteps = jax.random.randint(
|
||||
timestep_rng,
|
||||
(bsz,),
|
||||
0,
|
||||
noise_scheduler.config.num_train_timesteps,
|
||||
)
|
||||
noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
|
||||
encoder_hidden_states = state.apply_fn(
|
||||
batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
|
||||
)[0]
|
||||
# Predict the noise residual and compute loss
|
||||
model_pred = unet.apply(
|
||||
{"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
|
||||
).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = (target - model_pred) ** 2
|
||||
loss = loss.mean()
|
||||
|
||||
return loss
|
||||
|
||||
grad_fn = jax.value_and_grad(compute_loss)
|
||||
loss, grad = grad_fn(state.params)
|
||||
grad = jax.lax.pmean(grad, "batch")
|
||||
new_state = state.apply_gradients(grads=grad)
|
||||
|
||||
# Keep the token embeddings fixed except the newly added embeddings for the concept,
|
||||
# as we only want to optimize the concept embeddings
|
||||
token_embeds = original_token_embeds.at[placeholder_token_id].set(
|
||||
new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"][placeholder_token_id]
|
||||
)
|
||||
new_state.params["text_model"]["embeddings"]["token_embedding"]["embedding"] = token_embeds
|
||||
|
||||
metrics = {"loss": loss}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
return new_state, metrics, new_train_rng
|
||||
|
||||
# Create parallel version of the train and eval step
|
||||
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
||||
|
||||
# Replicate the train state on each device
|
||||
state = jax_utils.replicate(state)
|
||||
vae_params = jax_utils.replicate(vae_params)
|
||||
unet_params = jax_utils.replicate(unet_params)
|
||||
|
||||
# Train!
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
|
||||
global_step = 0
|
||||
|
||||
epochs = tqdm(range(args.num_train_epochs), desc=f"Epoch ... (1/{args.num_train_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
# ======================== Training ================================
|
||||
|
||||
train_metrics = []
|
||||
|
||||
steps_per_epoch = len(train_dataset) // total_train_batch_size
|
||||
train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
|
||||
# train
|
||||
for batch in train_dataloader:
|
||||
batch = shard(batch)
|
||||
state, train_metric, train_rngs = p_train_step(state, vae_params, unet_params, batch, train_rngs)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
train_step_progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
train_metric = jax_utils.unreplicate(train_metric)
|
||||
|
||||
train_step_progress_bar.close()
|
||||
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
|
||||
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
if jax.process_index() == 0:
|
||||
scheduler = FlaxPNDMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||
)
|
||||
safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", from_pt=True
|
||||
)
|
||||
pipeline = FlaxStableDiffusionPipeline(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
)
|
||||
|
||||
pipeline.save_pretrained(
|
||||
args.output_dir,
|
||||
params={
|
||||
"text_encoder": get_params_to_save(state.params),
|
||||
"vae": get_params_to_save(vae_params),
|
||||
"unet": get_params_to_save(unet_params),
|
||||
"safety_checker": safety_checker.params,
|
||||
},
|
||||
)
|
||||
|
||||
# Also save the newly trained embeddings
|
||||
learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"]["embedding"][
|
||||
placeholder_token_id
|
||||
]
|
||||
learned_embeds_dict = {args.placeholder_token: learned_embeds}
|
||||
jnp.save(os.path.join(args.output_dir, "learned_embeds.npy"), learned_embeds_dict)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -127,7 +127,7 @@ on consumer GPUs like Tesla T4, Tesla V100.
|
||||
|
||||
### Training
|
||||
|
||||
First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokemons dataset](https://hf.colambdalabs/pokemon-blip-captions).
|
||||
First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion v1-4](https://hf.co/CompVis/stable-diffusion-v1-4) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
check_min_version("0.15.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
check_min_version("0.15.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -48,6 +48,13 @@ def parse_args():
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
@@ -386,15 +393,17 @@ def main():
|
||||
weight_dtype = jnp.bfloat16
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer"
|
||||
)
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
|
||||
args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype
|
||||
)
|
||||
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
|
||||
args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype
|
||||
)
|
||||
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype
|
||||
args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype
|
||||
)
|
||||
|
||||
# Optimization
|
||||
|
||||
@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
check_min_version("0.15.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
check_min_version("0.15.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
check_min_version("0.15.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -121,6 +121,12 @@ def parse_args():
|
||||
default=5000,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Save learned_embeds.bin every X updates steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
@@ -136,6 +142,13 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
@@ -420,9 +433,15 @@ def main():
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
|
||||
)
|
||||
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
||||
)
|
||||
|
||||
# Create sampling rng
|
||||
rng = jax.random.PRNGKey(args.seed)
|
||||
@@ -619,6 +638,14 @@ def main():
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
if global_step % args.save_steps == 0:
|
||||
learned_embeds = get_params_to_save(state.params)["text_model"]["embeddings"]["token_embedding"][
|
||||
"embedding"
|
||||
][placeholder_token_id]
|
||||
learned_embeds_dict = {args.placeholder_token: learned_embeds}
|
||||
jnp.save(
|
||||
os.path.join(args.output_dir, "learned_embeds-" + str(global_step) + ".npy"), learned_embeds_dict
|
||||
)
|
||||
|
||||
train_metric = jax_utils.unreplicate(train_metric)
|
||||
|
||||
|
||||
@@ -24,10 +24,11 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.14.0.dev0")
|
||||
check_min_version("0.15.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -259,6 +260,9 @@ def parse_args():
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -410,6 +414,19 @@ def main(args):
|
||||
model_config=model.config,
|
||||
)
|
||||
|
||||
if args.enable_xformers_memory_efficient_attention:
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
|
||||
xformers_version = version.parse(xformers.__version__)
|
||||
if xformers_version == version.parse("0.0.16"):
|
||||
logger.warn(
|
||||
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
||||
)
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
# Initialize the scheduler
|
||||
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
|
||||
if accepts_prediction_type:
|
||||
@@ -631,7 +648,7 @@ def main(args):
|
||||
if is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
tracker = accelerator.get_tracker("tensorboard", unwrap=True)
|
||||
else:
|
||||
tracker = accelerator.get_tracker()
|
||||
tracker = accelerator.get_tracker("tensorboard")
|
||||
tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
|
||||
elif args.logger == "wandb":
|
||||
# Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
""" Conversion script for the LoRA's safetensors checkpoints. """
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
|
||||
def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET, LORA_PREFIX_TEXT_ENCODER, alpha):
|
||||
# load base model
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
|
||||
|
||||
# load LoRA weight from .safetensors
|
||||
state_dict = load_file(checkpoint_path)
|
||||
|
||||
visited = []
|
||||
|
||||
# directly update weight in diffusers model
|
||||
for key in state_dict:
|
||||
# it is suggested to print out the key, it usually will be something like below
|
||||
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
||||
|
||||
# as we have set the alpha beforehand, so just skip
|
||||
if ".alpha" in key or key in visited:
|
||||
continue
|
||||
|
||||
if "text" in key:
|
||||
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
||||
curr_layer = pipeline.text_encoder
|
||||
else:
|
||||
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
|
||||
curr_layer = pipeline.unet
|
||||
|
||||
# find the target layer
|
||||
temp_name = layer_infos.pop(0)
|
||||
while len(layer_infos) > -1:
|
||||
try:
|
||||
curr_layer = curr_layer.__getattr__(temp_name)
|
||||
if len(layer_infos) > 0:
|
||||
temp_name = layer_infos.pop(0)
|
||||
elif len(layer_infos) == 0:
|
||||
break
|
||||
except Exception:
|
||||
if len(temp_name) > 0:
|
||||
temp_name += "_" + layer_infos.pop(0)
|
||||
else:
|
||||
temp_name = layer_infos.pop(0)
|
||||
|
||||
pair_keys = []
|
||||
if "lora_down" in key:
|
||||
pair_keys.append(key.replace("lora_down", "lora_up"))
|
||||
pair_keys.append(key)
|
||||
else:
|
||||
pair_keys.append(key)
|
||||
pair_keys.append(key.replace("lora_up", "lora_down"))
|
||||
|
||||
# update weight
|
||||
if len(state_dict[pair_keys[0]].shape) == 4:
|
||||
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
|
||||
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
||||
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
||||
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
|
||||
|
||||
# update visited list
|
||||
for item in pair_keys:
|
||||
visited.append(item)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
parser.add_argument(
|
||||
"--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_prefix_text_encoder",
|
||||
default="lora_te",
|
||||
type=str,
|
||||
help="The prefix of text encoder weight in safetensors",
|
||||
)
|
||||
parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
|
||||
parser.add_argument(
|
||||
"--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
|
||||
)
|
||||
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
base_model_path = args.base_model_path
|
||||
checkpoint_path = args.checkpoint_path
|
||||
dump_path = args.dump_path
|
||||
lora_prefix_unet = args.lora_prefix_unet
|
||||
lora_prefix_text_encoder = args.lora_prefix_text_encoder
|
||||
alpha = args.alpha
|
||||
|
||||
pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
|
||||
|
||||
pipe = pipe.to(args.device)
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
||||
@@ -0,0 +1,122 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.onnx import export
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
|
||||
is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
|
||||
|
||||
|
||||
def onnx_export(
|
||||
model,
|
||||
model_args: tuple,
|
||||
output_path: Path,
|
||||
ordered_input_names,
|
||||
output_names,
|
||||
dynamic_axes,
|
||||
opset,
|
||||
use_external_data_format=False,
|
||||
):
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
|
||||
# so we check the torch version for backwards compatibility
|
||||
if is_torch_less_than_1_11:
|
||||
export(
|
||||
model,
|
||||
model_args,
|
||||
f=output_path.as_posix(),
|
||||
input_names=ordered_input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=use_external_data_format,
|
||||
enable_onnx_checker=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
else:
|
||||
export(
|
||||
model,
|
||||
model_args,
|
||||
f=output_path.as_posix(),
|
||||
input_names=ordered_input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
do_constant_folding=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
if fp16 and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif fp16 and not torch.cuda.is_available():
|
||||
raise ValueError("`float16` model export is only supported on GPUs with CUDA")
|
||||
else:
|
||||
device = "cpu"
|
||||
output_path = Path(output_path)
|
||||
|
||||
# VAE DECODER
|
||||
vae_decoder = AutoencoderKL.from_pretrained(model_path + "/vae")
|
||||
vae_latent_channels = vae_decoder.config.latent_channels
|
||||
# forward only through the decoder part
|
||||
vae_decoder.forward = vae_decoder.decode
|
||||
onnx_export(
|
||||
vae_decoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_latent_channels, 25, 25).to(device=device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||
ordered_input_names=["latent_sample", "return_dict"],
|
||||
output_names=["sample"],
|
||||
dynamic_axes={
|
||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
del vae_decoder
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).",
|
||||
)
|
||||
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.")
|
||||
parser.add_argument(
|
||||
"--opset",
|
||||
default=14,
|
||||
type=int,
|
||||
help="The version of the ONNX operator set to use.",
|
||||
)
|
||||
parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args.output_path)
|
||||
convert_models(args.model_path, args.output_path, args.opset, args.fp16)
|
||||
print("SD: Done: ONNX")
|
||||
@@ -80,6 +80,7 @@ from setuptools import find_packages, setup
|
||||
_deps = [
|
||||
"Pillow", # keep the PIL.Image.Resampling deprecation away
|
||||
"accelerate>=0.11.0",
|
||||
"compel==0.1.8",
|
||||
"black~=23.1",
|
||||
"datasets",
|
||||
"filelock",
|
||||
@@ -182,6 +183,7 @@ extras["quality"] = deps_list("black", "isort", "ruff", "hf-doc-builder")
|
||||
extras["docs"] = deps_list("hf-doc-builder")
|
||||
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "Jinja2")
|
||||
extras["test"] = deps_list(
|
||||
"compel",
|
||||
"datasets",
|
||||
"Jinja2",
|
||||
"k-diffusion",
|
||||
@@ -219,7 +221,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.14.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.15.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.14.0.dev0"
|
||||
__version__ = "0.15.0.dev0"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .utils import (
|
||||
@@ -158,6 +158,7 @@ else:
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
OnnxStableDiffusionInpaintPipelineLegacy,
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
StableDiffusionOnnxPipeline,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
deps = {
|
||||
"Pillow": "Pillow",
|
||||
"accelerate": "accelerate>=0.11.0",
|
||||
"compel": "compel==0.1.8",
|
||||
"black": "black~=23.1",
|
||||
"datasets": "datasets",
|
||||
"filelock": "filelock",
|
||||
|
||||
+61
-18
@@ -19,13 +19,18 @@ import torch
|
||||
|
||||
from .models.cross_attention import LoRACrossAttnProcessor
|
||||
from .models.modeling_utils import _get_model_file
|
||||
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
|
||||
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
|
||||
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
|
||||
|
||||
class AttnProcsLayers(torch.nn.Module):
|
||||
@@ -136,28 +141,53 @@ class UNet2DConditionLoadersMixin:
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"):
|
||||
if weight_name is None:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
try:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except EnvironmentError:
|
||||
if weight_name == LORA_WEIGHT_NAME_SAFE:
|
||||
weight_name = None
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
@@ -195,8 +225,9 @@ class UNet2DConditionLoadersMixin:
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
is_main_process: bool = True,
|
||||
weights_name: str = LORA_WEIGHT_NAME,
|
||||
weights_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = False,
|
||||
):
|
||||
r"""
|
||||
Save an attention processor to a directory, so that it can be re-loaded using the
|
||||
@@ -219,7 +250,13 @@ class UNet2DConditionLoadersMixin:
|
||||
return
|
||||
|
||||
if save_function is None:
|
||||
save_function = torch.save
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
@@ -237,6 +274,12 @@ class UNet2DConditionLoadersMixin:
|
||||
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
|
||||
os.remove(full_filename)
|
||||
|
||||
if weights_name is None:
|
||||
if safe_serialization:
|
||||
weights_name = LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weights_name = LORA_WEIGHT_NAME
|
||||
|
||||
# Save the model
|
||||
save_function(state_dict, os.path.join(save_directory, weights_name))
|
||||
|
||||
|
||||
@@ -260,7 +260,7 @@ class CrossAttention(nn.Module):
|
||||
deprecate(
|
||||
"batch_size=None",
|
||||
"0.0.15",
|
||||
message=(
|
||||
(
|
||||
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
|
||||
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
|
||||
" `prepare_attention_mask` when preparing the attention_mask."
|
||||
|
||||
@@ -195,7 +195,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
) -> Union[UNet1DOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): `(batch_size, sample_size, num_channels)` noisy inputs tensor
|
||||
sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
|
||||
|
||||
@@ -70,8 +70,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
|
||||
class_embed_type (`str`, *optional*, defaults to None):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
|
||||
@@ -90,8 +90,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
|
||||
class_embed_type (`str`, *optional*, defaults to None):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, `"identity"`, or `"projection"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
|
||||
@@ -93,6 +93,7 @@ else:
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
OnnxStableDiffusionInpaintPipelineLegacy,
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
StableDiffusionOnnxPipeline,
|
||||
)
|
||||
|
||||
|
||||
@@ -214,6 +214,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -234,6 +238,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -220,6 +220,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -240,6 +244,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -49,6 +49,7 @@ from ..utils import (
|
||||
get_class_from_dynamic_module,
|
||||
http_user_agent,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
@@ -63,7 +64,11 @@ if is_transformers_available():
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
|
||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
|
||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
|
||||
|
||||
INDEX_FILE = "diffusion_pytorch_model.bin"
|
||||
@@ -176,7 +181,13 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:
|
||||
|
||||
def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
|
||||
filenames = set(sibling.rfilename for sibling in info.siblings)
|
||||
weight_names = [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME]
|
||||
weight_names = [
|
||||
WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
||||
@@ -331,15 +342,50 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
|
||||
if torch_device is None:
|
||||
return self
|
||||
|
||||
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
|
||||
def module_is_sequentially_offloaded(module):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
|
||||
return False
|
||||
|
||||
return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
|
||||
|
||||
def module_is_offloaded(module):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
|
||||
return False
|
||||
|
||||
return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
|
||||
|
||||
# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
|
||||
pipeline_is_sequentially_offloaded = any(
|
||||
module_is_sequentially_offloaded(module) for _, module in self.components.items()
|
||||
)
|
||||
if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda":
|
||||
raise ValueError(
|
||||
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
||||
)
|
||||
|
||||
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
||||
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
||||
if pipeline_is_offloaded and torch.device(torch_device).type == "cuda":
|
||||
logger.warning(
|
||||
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
|
||||
)
|
||||
|
||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
|
||||
if (
|
||||
module.dtype == torch.float16
|
||||
and str(torch_device) in ["cpu"]
|
||||
and not silence_dtype_warnings
|
||||
and not is_offloaded
|
||||
):
|
||||
logger.warning(
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
||||
@@ -604,7 +650,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
]
|
||||
|
||||
if from_flax:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", ".onnx"]
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
||||
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...schedulers import RePaintScheduler
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
|
||||
from ...utils import PIL_INTERPOLATION, logging, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
@@ -90,7 +90,6 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
@@ -124,9 +123,7 @@ class RePaintPipeline(DiffusionPipeline):
|
||||
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
message = "Please use `image` instead of `original_image`."
|
||||
original_image = deprecate("original_image", "0.15.0", message, take_from=kwargs)
|
||||
original_image = original_image or image
|
||||
original_image = image
|
||||
|
||||
original_image = _preprocess_image(original_image)
|
||||
original_image = original_image.to(device=self.device, dtype=self.unet.dtype)
|
||||
|
||||
@@ -104,6 +104,7 @@ else:
|
||||
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
|
||||
from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline
|
||||
from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy
|
||||
from .pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
|
||||
|
||||
if is_transformers_available() and is_flax_available():
|
||||
import flax
|
||||
|
||||
@@ -237,6 +237,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -258,6 +262,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -0,0 +1,290 @@
|
||||
from logging import getLogger
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ..onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
|
||||
from ..pipeline_utils import ImagePipelineOutput
|
||||
from . import StableDiffusionUpscalePipeline
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
NUM_LATENT_CHANNELS = 4
|
||||
NUM_UNET_INPUT_CHANNELS = 7
|
||||
|
||||
ORT_TO_PT_TYPE = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
if isinstance(image, torch.Tensor):
|
||||
return image
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
image = [image]
|
||||
|
||||
if isinstance(image[0], PIL.Image.Image):
|
||||
w, h = image[0].size
|
||||
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32
|
||||
|
||||
image = [np.array(i.resize((w, h)))[None, :] for i in image]
|
||||
image = np.concatenate(image, axis=0)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image.transpose(0, 3, 1, 2)
|
||||
image = 2.0 * image - 1.0
|
||||
image = torch.from_numpy(image)
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae: OnnxRuntimeModel,
|
||||
text_encoder: OnnxRuntimeModel,
|
||||
tokenizer: Any,
|
||||
unet: OnnxRuntimeModel,
|
||||
low_res_scheduler: DDPMScheduler,
|
||||
scheduler: Any,
|
||||
max_noise_level: int = 350,
|
||||
):
|
||||
super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]],
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
noise_level: int = 20,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
):
|
||||
# 1. Check inputs
|
||||
self.check_inputs(prompt, image, noise_level, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_embeddings = self._encode_prompt(
|
||||
prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
latents_dtype = ORT_TO_PT_TYPE[str(text_embeddings.dtype)]
|
||||
|
||||
# 4. Preprocess image
|
||||
image = preprocess(image)
|
||||
image = image.cpu()
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Add noise to image
|
||||
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
|
||||
noise = torch.randn(image.shape, generator=generator, device=device, dtype=latents_dtype)
|
||||
image = self.low_res_scheduler.add_noise(image, noise, noise_level)
|
||||
|
||||
batch_multiplier = 2 if do_classifier_free_guidance else 1
|
||||
image = np.concatenate([image] * batch_multiplier * num_images_per_prompt)
|
||||
noise_level = np.concatenate([noise_level] * image.shape[0])
|
||||
|
||||
# 6. Prepare latent variables
|
||||
height, width = image.shape[2:]
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
NUM_LATENT_CHANNELS,
|
||||
height,
|
||||
width,
|
||||
latents_dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Check that sizes of image and latents match
|
||||
num_channels_image = image.shape[1]
|
||||
if NUM_LATENT_CHANNELS + num_channels_image != NUM_UNET_INPUT_CHANNELS:
|
||||
raise ValueError(
|
||||
"Incorrect configuration settings! The config of `pipeline.unet` expects"
|
||||
f" {NUM_UNET_INPUT_CHANNELS} but received `num_channels_latents`: {NUM_LATENT_CHANNELS} +"
|
||||
f" `num_channels_image`: {num_channels_image} "
|
||||
f" = {NUM_LATENT_CHANNELS+num_channels_image}. Please verify the config of"
|
||||
" `pipeline.unet` or your `image` input."
|
||||
)
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
timestep_dtype = next(
|
||||
(input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
|
||||
)
|
||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
latent_model_input = np.concatenate([latent_model_input, image], axis=1)
|
||||
|
||||
# timestep to tensor
|
||||
timestep = np.array([t], dtype=timestep_dtype)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
class_labels=noise_level.astype(np.int64),
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
torch.from_numpy(noise_pred), t, latents, **extra_step_kwargs
|
||||
).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 10. Post-processing
|
||||
image = self.decode_latents(latents.float())
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / 0.08333 * latents
|
||||
image = self.vae(latent_sample=latents)[0]
|
||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
||||
image = image.transpose((0, 2, 3, 1))
|
||||
return image
|
||||
|
||||
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
# if hasattr(text_inputs, "attention_mask"):
|
||||
# attention_mask = text_inputs.attention_mask.to(device)
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
# no positional arguments to text_encoder
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=text_input_ids.int().to(device),
|
||||
# attention_mask=attention_mask,
|
||||
)
|
||||
text_embeddings = text_embeddings[0]
|
||||
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt)
|
||||
text_embeddings = text_embeddings.reshape(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# if hasattr(uncond_input, "attention_mask"):
|
||||
# attention_mask = uncond_input.attention_mask.to(device)
|
||||
# else:
|
||||
# attention_mask = None
|
||||
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.int().to(device),
|
||||
# attention_mask=attention_mask,
|
||||
)
|
||||
uncond_embeddings = uncond_embeddings[0]
|
||||
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
|
||||
uncond_embeddings = uncond_embeddings.reshape(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
|
||||
|
||||
return text_embeddings
|
||||
@@ -217,6 +217,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -237,6 +241,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
+35
-4
@@ -263,6 +263,10 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -513,8 +517,30 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if (indices is None) or (indices is not None and not isinstance(indices, List)):
|
||||
raise ValueError(f"`indices` has to be a list but is {type(indices)}")
|
||||
indices_is_list_ints = isinstance(indices, list) and isinstance(indices[0], int)
|
||||
indices_is_list_list_ints = (
|
||||
isinstance(indices, list) and isinstance(indices[0], list) and isinstance(indices[0][0], int)
|
||||
)
|
||||
|
||||
if not indices_is_list_ints and not indices_is_list_list_ints:
|
||||
raise TypeError("`indices` must be a list of ints or a list of a list of ints")
|
||||
|
||||
if indices_is_list_ints:
|
||||
indices_batch_size = 1
|
||||
elif indices_is_list_list_ints:
|
||||
indices_batch_size = len(indices)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
prompt_batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
prompt_batch_size = len(prompt)
|
||||
elif prompt_embeds is not None:
|
||||
prompt_batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if indices_batch_size != prompt_batch_size:
|
||||
raise ValueError(
|
||||
f"indices batch size must be same as prompt batch size. indices batch size: {indices_batch_size}, prompt batch size: {prompt_batch_size}"
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
@@ -671,7 +697,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
token_indices: List[int],
|
||||
token_indices: Union[List[int], List[List[int]]],
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
@@ -691,6 +717,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
|
||||
max_iter_to_alter: int = 25,
|
||||
thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
|
||||
scale_factor: int = 20,
|
||||
attn_res: int = 16,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -762,6 +789,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
|
||||
Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
|
||||
scale_factor (`int`, *optional*, default to 20):
|
||||
Scale factor that controls the step size of each Attend and Excite update.
|
||||
attn_res (`int`, *optional*, default to 16):
|
||||
The resolution of most semantic attention map.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -834,7 +863,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
self.attention_store = AttentionStore()
|
||||
self.attention_store = AttentionStore(attn_res=attn_res)
|
||||
self.register_attention_control()
|
||||
|
||||
# default config for step size from original repo
|
||||
@@ -847,7 +876,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
|
||||
|
||||
if isinstance(token_indices[0], int):
|
||||
token_indices = [token_indices]
|
||||
|
||||
indices = []
|
||||
|
||||
for ind in token_indices:
|
||||
indices = indices + [ind] * num_images_per_prompt
|
||||
|
||||
|
||||
@@ -14,14 +14,17 @@
|
||||
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from torch import device, nn
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...models.controlnet import ControlNetOutput
|
||||
from ...models.modeling_utils import get_parameter_device, get_parameter_dtype
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
PIL_INTERPOLATION,
|
||||
@@ -85,6 +88,60 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class MultiControlNet(nn.Module):
|
||||
def __init__(self, controlnets: List[ControlNetModel]):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
@property
|
||||
def device(self) -> device:
|
||||
"""
|
||||
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
||||
device).
|
||||
"""
|
||||
return get_parameter_device(self)
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""
|
||||
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||
"""
|
||||
return get_parameter_dtype(self)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.FloatTensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
num_images_per_net = controlnet_cond.shape[0] // len(self.nets)
|
||||
conds = controlnet_cond[None, :].reshape((num_images_per_net, -1) + controlnet_cond.shape[1:])
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = 0
|
||||
for cond, controlnet in zip(conds, self.nets):
|
||||
down, mid = self.controlnet(
|
||||
sample,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
cond,
|
||||
class_labels,
|
||||
timestep_cond,
|
||||
attention_mask,
|
||||
cross_attention_kwargs,
|
||||
return_dict,
|
||||
)
|
||||
down_block_res_samples += down
|
||||
mid_block_res_sample += mid
|
||||
|
||||
return down_block_res_samples, mid_block_res_sample
|
||||
|
||||
|
||||
class StableDiffusionControlNetPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
||||
@@ -146,6 +203,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = MultiControlNet(controlnet)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -517,7 +577,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
|
||||
elif isinstance(image[0], torch.Tensor):
|
||||
image = torch.cat(image, dim=0)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
num_controlnets = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
|
||||
image_batch_size = image.shape[0] // num_controlnets
|
||||
|
||||
if image_batch_size != image.shape[0] * num_controlnets:
|
||||
raise ValueError("TODO: Good error message here")
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
@@ -716,7 +780,12 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
image = torch.cat([image] * 2)
|
||||
num_control = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
|
||||
image = image[None, :].reshape(num_control, -1, *image.shape[1:])
|
||||
|
||||
# only repeat batch size, but not controlnet dim
|
||||
image = image.repeat(1, 2, 1, 1, 1)
|
||||
image = image.reshape((image.shape[:2].numel(),) + image.shape[2:])
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
@@ -225,6 +225,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -246,6 +250,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -272,6 +272,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -293,6 +297,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
+11
-3
@@ -42,7 +42,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
def preprocess_image(image):
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
|
||||
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
@@ -54,7 +54,7 @@ def preprocess_mask(mask, scale_factor=8):
|
||||
if not isinstance(mask, torch.FloatTensor):
|
||||
mask = mask.convert("L")
|
||||
w, h = mask.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
|
||||
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
@@ -76,7 +76,7 @@ def preprocess_mask(mask, scale_factor=8):
|
||||
# (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape
|
||||
mask = mask.mean(dim=1, keepdim=True)
|
||||
h, w = mask.shape[-2:]
|
||||
h, w = map(lambda x: x - x % 32, (h, w)) # resize to integer multiple of 32
|
||||
h, w = map(lambda x: x - x % 8, (h, w)) # resize to integer multiple of 8
|
||||
mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor))
|
||||
return mask
|
||||
|
||||
@@ -216,6 +216,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -237,6 +241,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
+45
-6
@@ -246,13 +246,19 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
# 0. Check inputs
|
||||
self.check_inputs(prompt, callback_steps)
|
||||
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
||||
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined.")
|
||||
|
||||
# 1. Define call parameters
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@@ -405,6 +411,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -426,6 +436,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
@@ -632,10 +646,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def check_inputs(self, prompt, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
def check_inputs(
|
||||
self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
@@ -644,6 +657,32 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
|
||||
@@ -137,6 +137,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -158,6 +162,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
@@ -158,6 +158,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
|
||||
@@ -372,6 +372,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@@ -1062,7 +1066,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
guidance_scale (`float`, *optional*, defaults to 1):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
|
||||
@@ -176,6 +176,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
|
||||
@@ -88,21 +88,22 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
|
||||
is_vae_scaling_factor_set_to_0_08333 = (
|
||||
hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
|
||||
)
|
||||
if not is_vae_scaling_factor_set_to_0_08333:
|
||||
deprecation_message = (
|
||||
"The configuration file of the vae does not contain `scaling_factor` or it is set to"
|
||||
f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
|
||||
" version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to 0.08333"
|
||||
" Please make sure to update the config accordingly, as not doing so might lead to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
|
||||
" very nice if you could open a Pull Request for the `vae/config.json` file"
|
||||
if hasattr(vae, "config"):
|
||||
# check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
|
||||
is_vae_scaling_factor_set_to_0_08333 = (
|
||||
hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
|
||||
)
|
||||
deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)
|
||||
vae.register_to_config(scaling_factor=0.08333)
|
||||
if not is_vae_scaling_factor_set_to_0_08333:
|
||||
deprecation_message = (
|
||||
"The configuration file of the vae does not contain `scaling_factor` or it is set to"
|
||||
f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
|
||||
" version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to"
|
||||
" 0.08333 Please make sure to update the config accordingly, as not doing so might lead to"
|
||||
" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging"
|
||||
" Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file"
|
||||
)
|
||||
deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)
|
||||
vae.register_to_config(scaling_factor=0.08333)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
|
||||
@@ -171,8 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
|
||||
class_embed_type (`str`, *optional*, defaults to None):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, `"identity"`, or `"projection"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
|
||||
@@ -98,7 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
option to clip predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, default `1.0`):
|
||||
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
set_alpha_to_one (`bool`, default `True`):
|
||||
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
||||
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
@@ -111,6 +113,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
||||
https://imagen.research.google/video/paper.pdf)
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
|
||||
stable-diffusion).
|
||||
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@@ -128,6 +139,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
clip_sample_range: float = 1.0,
|
||||
sample_max_value: float = 1.0,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -184,6 +199,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
@@ -286,9 +313,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
" `v_prediction`"
|
||||
)
|
||||
|
||||
# 4. Clip "predicted x_0"
|
||||
# 4. Clip or threshold "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
||||
pred_original_sample = pred_original_sample.clamp(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
|
||||
@@ -98,11 +98,22 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
option to clip predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, default `1.0`):
|
||||
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
prediction_type (`str`, default `epsilon`, optional):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
||||
https://imagen.research.google/video/paper.pdf)
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
|
||||
stable-diffusion).
|
||||
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@@ -119,7 +130,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
variance_type: str = "fixed_small",
|
||||
clip_sample: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
clip_sample_range: Optional[float] = 1.0,
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
clip_sample_range: float = 1.0,
|
||||
sample_max_value: float = 1.0,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -226,6 +240,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
@@ -283,12 +308,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
" `v_prediction` for the DDPMScheduler."
|
||||
)
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
# 3. Clip or threshold "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = torch.clamp(
|
||||
pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
|
||||
pred_original_sample = pred_original_sample.clamp(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||
|
||||
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
|
||||
|
||||
@@ -96,7 +96,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://arxiv.org/abs/2205.11487).
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid woks when `thresholding=True`
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True`
|
||||
algorithm_type (`str`, default `deis`):
|
||||
the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in
|
||||
the future
|
||||
@@ -194,6 +194,18 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
@@ -228,15 +240,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
||||
)
|
||||
dynamic_max_val = torch.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
x0_pred = x0_pred.type(orig_dtype)
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
|
||||
if self.config.algorithm_type == "deis":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
|
||||
@@ -204,6 +204,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
@@ -247,15 +259,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
||||
)
|
||||
dynamic_max_val = torch.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
x0_pred = x0_pred.type(orig_dtype)
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
|
||||
@@ -237,6 +237,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sample = None
|
||||
self.orders = self.get_order_list(num_inference_steps)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
@@ -277,18 +289,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dtype = x0_pred.dtype
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)).float(),
|
||||
self.config.dynamic_thresholding_ratio,
|
||||
dim=1,
|
||||
)
|
||||
dynamic_max_val = torch.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
x0_pred = x0_pred.to(dtype)
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
|
||||
@@ -109,7 +109,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional):
|
||||
final timestep value (overrides value given at Scheduler instantiation).
|
||||
|
||||
"""
|
||||
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
||||
@@ -129,8 +130,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
sigma_min (`float`, optional):
|
||||
initial noise scale value (overrides value given at Scheduler instantiation).
|
||||
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
||||
sigma_max (`float`, optional):
|
||||
final noise scale value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional):
|
||||
final timestep value (overrides value given at Scheduler instantiation).
|
||||
|
||||
"""
|
||||
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
|
||||
|
||||
@@ -116,7 +116,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional):
|
||||
final timestep value (overrides value given at Scheduler instantiation).
|
||||
|
||||
"""
|
||||
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
||||
@@ -143,8 +144,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
sigma_min (`float`, optional):
|
||||
initial noise scale value (overrides value given at Scheduler instantiation).
|
||||
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
||||
sigma_max (`float`, optional):
|
||||
final noise scale value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional):
|
||||
final timestep value (overrides value given at Scheduler instantiation).
|
||||
"""
|
||||
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
|
||||
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
|
||||
|
||||
@@ -210,6 +210,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.solver_p:
|
||||
self.solver_p.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
@@ -245,15 +257,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
||||
)
|
||||
dynamic_max_val = torch.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
x0_pred = x0_pred.type(orig_dtype)
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
return x0_pred
|
||||
else:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
|
||||
@@ -203,8 +203,6 @@ class EMAModel:
|
||||
else:
|
||||
s_param.copy_(param)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
||||
"""
|
||||
Copy current averaged parameters into given collection of parameters.
|
||||
|
||||
@@ -62,6 +62,21 @@ class OnnxStableDiffusionPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers", "onnx"])
|
||||
|
||||
|
||||
class OnnxStableDiffusionUpscalePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "onnx"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers", "onnx"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "onnx"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers", "onnx"])
|
||||
|
||||
|
||||
class StableDiffusionOnnxPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers", "onnx"]
|
||||
|
||||
|
||||
@@ -232,6 +232,14 @@ except importlib_metadata.PackageNotFoundError:
|
||||
_tensorboard_available = False
|
||||
|
||||
|
||||
_compel_available = importlib.util.find_spec("compel")
|
||||
try:
|
||||
_compel_version = importlib_metadata.version("compel")
|
||||
logger.debug(f"Successfully imported compel version {_compel_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_compel_available = False
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
@@ -296,6 +304,10 @@ def is_tensorboard_available():
|
||||
return _tensorboard_available
|
||||
|
||||
|
||||
def is_compel_available():
|
||||
return _compel_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
@@ -368,6 +380,12 @@ TENSORBOARD_IMPORT_ERROR = """
|
||||
install tensorboard`
|
||||
"""
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
COMPEL_IMPORT_ERROR = """
|
||||
{0} requires the compel library but it was not found in your environment. You can install it with pip: `pip install compel`
|
||||
"""
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
|
||||
@@ -382,6 +400,7 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
|
||||
("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
|
||||
("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
|
||||
("compel", (_compel_available, COMPEL_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import PIL.ImageOps
|
||||
import requests
|
||||
from packaging import version
|
||||
|
||||
from .import_utils import is_flax_available, is_onnx_available, is_torch_available
|
||||
from .import_utils import is_compel_available, is_flax_available, is_onnx_available, is_torch_available
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
@@ -175,6 +175,14 @@ def require_flax(test_case):
|
||||
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
|
||||
|
||||
|
||||
def require_compel(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
|
||||
the library is not installed.
|
||||
"""
|
||||
return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
|
||||
|
||||
|
||||
def require_onnxruntime(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -372,6 +373,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.to(torch_device)
|
||||
@@ -385,6 +387,101 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
# LoRA and no LoRA should NOT be the same
|
||||
assert (sample - old_sample).abs().max() > 1e-4
|
||||
|
||||
def test_lora_save_load_safetensors(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
old_sample = model(**inputs_dict).sample
|
||||
|
||||
lora_attn_procs = {}
|
||||
for name in model.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = model.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = model.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRACrossAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
|
||||
|
||||
# add 1 to weights to mock trained weights
|
||||
with torch.no_grad():
|
||||
lora_attn_procs[name].to_q_lora.up.weight += 1
|
||||
lora_attn_procs[name].to_k_lora.up.weight += 1
|
||||
lora_attn_procs[name].to_v_lora.up.weight += 1
|
||||
lora_attn_procs[name].to_out_lora.up.weight += 1
|
||||
|
||||
model.set_attn_processor(lora_attn_procs)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname, safe_serialization=True)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.to(torch_device)
|
||||
new_model.load_attn_procs(tmpdirname)
|
||||
|
||||
with torch.no_grad():
|
||||
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
|
||||
|
||||
assert (sample - new_sample).abs().max() < 1e-4
|
||||
|
||||
# LoRA and no LoRA should NOT be the same
|
||||
assert (sample - old_sample).abs().max() > 1e-4
|
||||
|
||||
def test_lora_save_load_safetensors_load_torch(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
lora_attn_procs = {}
|
||||
for name in model.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = model.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = model.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRACrossAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
|
||||
|
||||
model.set_attn_processor(lora_attn_procs)
|
||||
# Saving as torch, properly reloads with directly filename
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.to(torch_device)
|
||||
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
|
||||
|
||||
def test_lora_on_off(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -160,19 +160,6 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
assert out_1.shape == (1, 64, 64, 3)
|
||||
assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2
|
||||
|
||||
def test_paint_by_example_inpaint_with_num_images_per_prompt(self):
|
||||
device = "cpu"
|
||||
pipe = PaintByExamplePipeline(**self.get_dummy_components())
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=2).images
|
||||
|
||||
# check if the output is a list of 2 images
|
||||
assert len(images) == 2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -41,7 +41,7 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"negative_prompt_embeds",
|
||||
}
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import floats_tensor
|
||||
from diffusers.utils.testing_utils import (
|
||||
is_onnx_available,
|
||||
load_image,
|
||||
nightly,
|
||||
require_onnxruntime,
|
||||
require_torch_gpu,
|
||||
)
|
||||
|
||||
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
|
||||
|
||||
|
||||
if is_onnx_available():
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
class OnnxStableDiffusionUpscalePipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
|
||||
# TODO: is there an appropriate internal test set?
|
||||
hub_checkpoint = "ssube/stable-diffusion-x4-upscaler-onnx"
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
image = floats_tensor((1, 3, 128, 128), rng=random.Random(seed))
|
||||
generator = torch.manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"image": image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_pipeline_default_ddpm(self):
|
||||
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
# started as 128, should now be 512
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array(
|
||||
[0.6974782, 0.68902093, 0.70135885, 0.7583618, 0.7804545, 0.7854912, 0.78667426, 0.78743863, 0.78070223]
|
||||
)
|
||||
assert np.abs(image_slice - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_pndm(self):
|
||||
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config, skip_prk_steps=True)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array(
|
||||
[0.6898892, 0.59240556, 0.52499527, 0.58866215, 0.52258235, 0.52572715, 0.62414473, 0.6174387, 0.6214964]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_dpm_multistep(self):
|
||||
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array(
|
||||
[0.7659278, 0.76437664, 0.75579107, 0.7691116, 0.77666986, 0.7727672, 0.7758664, 0.7812226, 0.76942515]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_euler(self):
|
||||
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array(
|
||||
[0.6974782, 0.68902093, 0.70135885, 0.7583618, 0.7804545, 0.7854912, 0.78667426, 0.78743863, 0.78070223]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_euler_ancestral(self):
|
||||
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array(
|
||||
[0.77424496, 0.773601, 0.7645288, 0.7769598, 0.7772739, 0.7738688, 0.78187233, 0.77879584, 0.767043]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||
|
||||
|
||||
@nightly
|
||||
@require_onnxruntime
|
||||
@require_torch_gpu
|
||||
class OnnxStableDiffusionUpscalePipelineIntegrationTests(unittest.TestCase):
|
||||
@property
|
||||
def gpu_provider(self):
|
||||
return (
|
||||
"CUDAExecutionProvider",
|
||||
{
|
||||
"gpu_mem_limit": "15000000000", # 15GB
|
||||
"arena_extend_strategy": "kSameAsRequested",
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def gpu_options(self):
|
||||
options = ort.SessionOptions()
|
||||
options.enable_mem_pattern = False
|
||||
return options
|
||||
|
||||
def test_inference_default_ddpm(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((128, 128))
|
||||
# using the PNDM scheduler by default
|
||||
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
"ssube/stable-diffusion-x4-upscaler-onnx",
|
||||
provider=self.gpu_provider,
|
||||
sess_options=self.gpu_options,
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=10,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
images = output.images
|
||||
image_slice = images[0, 255:258, 383:386, -1]
|
||||
|
||||
assert images.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.4883, 0.4947, 0.4980, 0.4975, 0.4982, 0.4980, 0.5000, 0.5006, 0.4972])
|
||||
# TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
|
||||
|
||||
def test_inference_k_lms(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((128, 128))
|
||||
lms_scheduler = LMSDiscreteScheduler.from_pretrained(
|
||||
"ssube/stable-diffusion-x4-upscaler-onnx", subfolder="scheduler"
|
||||
)
|
||||
pipe = OnnxStableDiffusionUpscalePipeline.from_pretrained(
|
||||
"ssube/stable-diffusion-x4-upscaler-onnx",
|
||||
scheduler=lms_scheduler,
|
||||
provider=self.gpu_provider,
|
||||
sess_options=self.gpu_options,
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=20,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
images = output.images
|
||||
image_slice = images[0, 255:258, 383:386, -1]
|
||||
|
||||
assert images.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array(
|
||||
[0.50173753, 0.50223356, 0.502039, 0.50233036, 0.5023725, 0.5022601, 0.5018758, 0.50234085, 0.50241566]
|
||||
)
|
||||
# TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
|
||||
@@ -477,43 +477,6 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_num_images_per_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
|
||||
sd_pipe = StableDiffusionPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
|
||||
# test num_images_per_prompt=1 (default)
|
||||
images = sd_pipe(prompt, num_inference_steps=2, output_type="np").images
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt=1 (default) for batch of prompts
|
||||
batch_size = 2
|
||||
images = sd_pipe([prompt] * batch_size, num_inference_steps=2, output_type="np").images
|
||||
|
||||
assert images.shape == (batch_size, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for single prompt
|
||||
num_images_per_prompt = 2
|
||||
images = sd_pipe(
|
||||
prompt, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
|
||||
).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
images = sd_pipe(
|
||||
[prompt] * batch_size, num_inference_steps=2, output_type="np", num_images_per_prompt=num_images_per_prompt
|
||||
).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
def test_stable_diffusion_long_prompt(self):
|
||||
components = self.get_dummy_components()
|
||||
components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config)
|
||||
|
||||
@@ -143,42 +143,6 @@ class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unitte
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_img_variation_num_images_per_prompt(self):
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionImageVariationPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# test num_images_per_prompt=1 (default)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs).images
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt=1 (default) for batch of images
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["image"] = batch_size * [inputs["image"]]
|
||||
images = sd_pipe(**inputs).images
|
||||
|
||||
assert images.shape == (batch_size, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for single prompt
|
||||
num_images_per_prompt = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["image"] = batch_size * [inputs["image"]]
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -181,42 +181,6 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_img2img_num_images_per_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# test num_images_per_prompt=1 (default)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs).images
|
||||
|
||||
assert images.shape == (1, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt=1 (default) for batch of prompts
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"]] * batch_size
|
||||
images = sd_pipe(**inputs).images
|
||||
|
||||
assert images.shape == (batch_size, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt for single prompt
|
||||
num_images_per_prompt = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"]] * batch_size
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
@skip_mps
|
||||
def test_save_load_local(self):
|
||||
return super().test_save_load_local()
|
||||
|
||||
@@ -151,19 +151,6 @@ class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.Test
|
||||
assert out_pil.shape == (1, 64, 64, 3)
|
||||
assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2
|
||||
|
||||
def test_stable_diffusion_inpaint_with_num_images_per_prompt(self):
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionInpaintPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=2).images
|
||||
|
||||
# check if the output is a list of 2 images
|
||||
assert len(images) == 2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -191,42 +191,6 @@ class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unitt
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_pix2pix_num_images_per_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionInstructPix2PixPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# test num_images_per_prompt=1 (default)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs).images
|
||||
|
||||
assert images.shape == (1, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt=1 (default) for batch of prompts
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"]] * batch_size
|
||||
images = sd_pipe(**inputs).images
|
||||
|
||||
assert images.shape == (batch_size, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt for single prompt
|
||||
num_images_per_prompt = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"]] * batch_size
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -177,42 +177,6 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.Tes
|
||||
with self.assertRaises(ValueError):
|
||||
_ = sd_pipe(**inputs).images
|
||||
|
||||
def test_stable_diffusion_panorama_num_images_per_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPanoramaPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# test num_images_per_prompt=1 (default)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs).images
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt=1 (default) for batch of prompts
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"]] * batch_size
|
||||
images = sd_pipe(**inputs).images
|
||||
|
||||
assert images.shape == (batch_size, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for single prompt
|
||||
num_images_per_prompt = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"]] * batch_size
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -191,34 +191,6 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_pix2pix_zero_num_images_per_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# test num_images_per_prompt=1 (default)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs).images
|
||||
|
||||
assert images.shape == (1, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt=2 for a single prompt
|
||||
num_images_per_prompt = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"]] * batch_size
|
||||
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
|
||||
|
||||
# Non-determinism caused by the scheduler optimizing the latent inputs during inference
|
||||
@unittest.skip("non-deterministic pipeline")
|
||||
def test_inference_batch_single_identical(self):
|
||||
|
||||
@@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
test_save_load_optional_components = False
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"}
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
@@ -340,42 +340,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_stable_diffusion_depth2img_num_images_per_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
pipe = StableDiffusionDepth2ImgPipeline(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# test num_images_per_prompt=1 (default)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = pipe(**inputs).images
|
||||
|
||||
assert images.shape == (1, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt=1 (default) for batch of prompts
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"]] * batch_size
|
||||
images = pipe(**inputs).images
|
||||
|
||||
assert images.shape == (batch_size, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt for single prompt
|
||||
num_images_per_prompt = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"]] * batch_size
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
def test_stable_diffusion_depth2img_pil(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -361,59 +361,6 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_unclip_image_variation_input_num_images_per_prompt(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
|
||||
pipeline_inputs["image"] = [
|
||||
pipeline_inputs["image"],
|
||||
pipeline_inputs["image"],
|
||||
]
|
||||
|
||||
output = pipe(**pipeline_inputs, num_images_per_prompt=2)
|
||||
image = output.images
|
||||
|
||||
tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
|
||||
tuple_pipeline_inputs["image"] = [
|
||||
tuple_pipeline_inputs["image"],
|
||||
tuple_pipeline_inputs["image"],
|
||||
]
|
||||
|
||||
image_from_tuple = pipe(
|
||||
**tuple_pipeline_inputs,
|
||||
num_images_per_prompt=2,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (4, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.9980,
|
||||
0.9997,
|
||||
0.0023,
|
||||
0.0029,
|
||||
0.9997,
|
||||
0.9985,
|
||||
0.9997,
|
||||
0.0010,
|
||||
0.9995,
|
||||
]
|
||||
)
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_unclip_passed_image_embed(self):
|
||||
device = torch.device("cpu")
|
||||
|
||||
|
||||
+69
-1
@@ -49,11 +49,12 @@ from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
UniPCMultistepScheduler,
|
||||
logging,
|
||||
)
|
||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
|
||||
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
@@ -584,6 +585,42 @@ class PipelineFastTests(unittest.TestCase):
|
||||
assert image_img2img.shape == (1, 32, 32, 3)
|
||||
assert image_text2img.shape == (1, 64, 64, 3)
|
||||
|
||||
@require_torch_gpu
|
||||
def test_pipe_false_offload_warn(self):
|
||||
unet = self.dummy_cond_unet()
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
sd = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
|
||||
sd.enable_model_cpu_offload()
|
||||
|
||||
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
sd.to("cuda")
|
||||
|
||||
assert "It is strongly recommended against doing so" in str(cap_logger)
|
||||
|
||||
sd = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
|
||||
def test_set_scheduler(self):
|
||||
unet = self.dummy_cond_unet()
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
@@ -1022,6 +1059,37 @@ class PipelineSlowTests(unittest.TestCase):
|
||||
|
||||
assert np.abs(image_0 - image_1).sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
@require_compel
|
||||
def test_weighted_prompts_compel(self):
|
||||
from compel import Compel
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
|
||||
|
||||
prompt = "a red cat playing with a ball{}"
|
||||
|
||||
prompts = [prompt.format(s) for s in ["", "++", "--"]]
|
||||
|
||||
prompt_embeds = compel(prompts)
|
||||
|
||||
generator = [torch.Generator(device="cpu").manual_seed(33) for _ in range(prompt_embeds.shape[0])]
|
||||
|
||||
images = pipe(
|
||||
prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20, output_type="numpy"
|
||||
).images
|
||||
|
||||
for i, image in enumerate(images):
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
f"/compel/forest_{i}.npy"
|
||||
)
|
||||
|
||||
assert np.abs(image - expected_image).max() < 1e-3
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -450,7 +450,9 @@ class PipelineTesterMixin:
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
self._test_attention_slicing_forward_pass()
|
||||
|
||||
def _test_attention_slicing_forward_pass(self, test_max_difference=True, expected_max_diff=1e-3):
|
||||
def _test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
@@ -474,7 +476,8 @@ class PipelineTesterMixin:
|
||||
max_diff = np.abs(output_with_slicing - output_without_slicing).max()
|
||||
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
|
||||
|
||||
assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
|
||||
if test_mean_pixel_difference:
|
||||
assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
|
||||
@@ -550,6 +553,32 @@ class PipelineTesterMixin:
|
||||
_ = pipe(**inputs)
|
||||
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
|
||||
if "num_images_per_prompt" not in sig.parameters:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
|
||||
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
|
||||
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user