Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0dc0f98526 | |||
| 3045fb2763 | |||
| 7b0ba4820a | |||
| 8d5906a331 | |||
| 17470057d2 | |||
| a5b242d30d | |||
| a121e05feb | |||
| 3979aac996 | |||
| 7e6886f5e9 | |||
| a4c91be73b | |||
| 3becd368b1 | |||
| c8fdfe4572 | |||
| bba1c1de15 | |||
| 86ecd4b795 | |||
| bdeff4d64a | |||
| fc1883918f | |||
| f0c74e9a75 | |||
| 4bc157ffa9 | |||
| f2df39fa0e | |||
| 8ecdd3ef65 | |||
| cd8b7507c2 | |||
| 3b641eabe9 | |||
| 703307efcc | |||
| ed8fd38337 | |||
| ca783a0f1f | |||
| beb848e2b6 | |||
| cfc99adf0f | |||
| 807f69b328 | |||
| b811964a7b | |||
| 1bd4c9e93d | |||
| eb2ef31606 | |||
| 5c9dd0af95 | |||
| d0f258206d | |||
| 3eaead0c4a | |||
| 3bf5ce21ad | |||
| 3a9d7d9758 | |||
| e748b3c6e1 | |||
| 46c52f9b96 | |||
| d06e06940b | |||
| 0a73b4d3cd | |||
| e126a82cc5 |
@@ -21,22 +21,27 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch CPU tests on Ubuntu
|
||||
framework: pytorch
|
||||
- name: Fast PyTorch Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu
|
||||
- name: Fast Flax CPU tests on Ubuntu
|
||||
report: torch_cpu_pipelines
|
||||
- name: Fast PyTorch Models & Schedulers CPU tests
|
||||
framework: pytorch_models
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_models_schedulers
|
||||
- name: Fast Flax CPU tests
|
||||
framework: flax
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: Fast ONNXRuntime CPU tests on Ubuntu
|
||||
- name: Fast ONNXRuntime CPU tests
|
||||
framework: onnxruntime
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-onnxruntime-cpu
|
||||
report: onnx_cpu
|
||||
- name: PyTorch Example CPU tests on Ubuntu
|
||||
- name: PyTorch Example CPU tests
|
||||
framework: pytorch_examples
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
@@ -71,13 +76,21 @@ jobs:
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
tests/pipelines
|
||||
|
||||
- name: Run fast PyTorch Model Scheduler CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_models' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/models tests/schedulers tests/others
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
@@ -85,7 +98,7 @@ jobs:
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
tests
|
||||
|
||||
- name: Run fast ONNXRuntime CPU tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
- local: using-diffusers/schedulers
|
||||
title: Load and compare different schedulers
|
||||
- local: using-diffusers/custom_pipeline_overview
|
||||
title: Load and add custom pipelines
|
||||
title: Load community pipelines
|
||||
- local: using-diffusers/kerascv
|
||||
title: Load KerasCV Stable Diffusion checkpoints
|
||||
title: Loading & Hub
|
||||
@@ -47,9 +47,9 @@
|
||||
- local: using-diffusers/reproducibility
|
||||
title: Create reproducible pipelines
|
||||
- local: using-diffusers/custom_pipeline_examples
|
||||
title: Community Pipelines
|
||||
title: Community pipelines
|
||||
- local: using-diffusers/contribute_pipeline
|
||||
title: How to contribute a Pipeline
|
||||
title: How to contribute a community pipeline
|
||||
- local: using-diffusers/using_safetensors
|
||||
title: Using safetensors
|
||||
- local: using-diffusers/stable_diffusion_jax_how_to
|
||||
@@ -74,6 +74,8 @@
|
||||
title: ControlNet
|
||||
- local: training/instructpix2pix
|
||||
title: InstructPix2Pix Training
|
||||
- local: training/custom_diffusion
|
||||
title: Custom Diffusion
|
||||
title: Training
|
||||
- sections:
|
||||
- local: using-diffusers/rl
|
||||
|
||||
@@ -36,3 +36,7 @@ API to load such adapter neural networks via the [`loaders.py` module](https://g
|
||||
### LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.LoraLoaderMixin
|
||||
|
||||
### FromCkptMixin
|
||||
|
||||
[[autodoc]] loaders.FromCkptMixin
|
||||
|
||||
@@ -242,6 +242,42 @@ image.save("./multi_controlnet_output.png")
|
||||
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/controlnet/multi_controlnet_output.png" width=600/>
|
||||
|
||||
### Guess Mode
|
||||
|
||||
Guess Mode is [a ControlNet feature that was implemented](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode) after the publication of [the paper](https://arxiv.org/abs/2302.05543). The description states:
|
||||
|
||||
>In this mode, the ControlNet encoder will try best to recognize the content of the input control map, like depth map, edge map, scribbles, etc, even if you remove all prompts.
|
||||
|
||||
#### The core implementation:
|
||||
|
||||
It adjusts the scale of the output residuals from ControlNet by a fixed ratio depending on the block depth. The shallowest DownBlock corresponds to `0.1`. As the blocks get deeper, the scale increases exponentially, and the scale for the output of the MidBlock becomes `1.0`.
|
||||
|
||||
Since the core implementation is just this, **it does not have any impact on prompt conditioning**. While it is common to use it without specifying any prompts, it is also possible to provide prompts if desired.
|
||||
|
||||
#### Usage:
|
||||
|
||||
Just specify `guess_mode=True` in the pipe() function. A `guidance_scale` between 3.0 and 5.0 is [recommended](https://github.com/lllyasviel/ControlNet#guess-mode--non-prompt-mode).
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
import torch
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", controlnet=controlnet).to(
|
||||
"cuda"
|
||||
)
|
||||
image = pipe("", image=canny_image, guess_mode=True, guidance_scale=3.0).images[0]
|
||||
image.save("guess_mode_generated.png")
|
||||
```
|
||||
|
||||
#### Output image comparison:
|
||||
Canny Control Example
|
||||
|
||||
|no guess_mode with prompt|guess_mode without prompt|
|
||||
|---|---|
|
||||
|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0.png"><img width="128" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0.png"/></a>|<a href="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0_gm.png"><img width="128" src="https://huggingface.co/takuma104/controlnet_dev/resolve/main/gen_compare_guess_mode/output_images/diffusers/output_bird_canny_0_gm.png"/></a>|
|
||||
|
||||
|
||||
|
||||
## Available checkpoints
|
||||
|
||||
ControlNet requires a *control image* in addition to the text-to-image *prompt*.
|
||||
@@ -272,6 +308,7 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h
|
||||
- disable_vae_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- load_textual_inversion
|
||||
|
||||
## FlaxStableDiffusionControlNetPipeline
|
||||
[[autodoc]] FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
@@ -30,4 +30,7 @@ Available Checkpoints are:
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- load_textual_inversion
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
@@ -30,7 +30,11 @@ proposed by Chenlin Meng, Yutong He, Yang Song, Jiaming Song, Jiajun Wu, Jun-Yan
|
||||
- disable_attention_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- load_textual_inversion
|
||||
- from_ckpt
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
[[autodoc]] FlaxStableDiffusionImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
- __call__
|
||||
|
||||
@@ -31,7 +31,10 @@ Available checkpoints are:
|
||||
- disable_attention_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- load_textual_inversion
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
[[autodoc]] FlaxStableDiffusionInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
- __call__
|
||||
|
||||
@@ -68,3 +68,6 @@ images[0].save("snowy_mountains.png")
|
||||
[[autodoc]] StableDiffusionInstructPix2PixPipeline
|
||||
- __call__
|
||||
- all
|
||||
- load_textual_inversion
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
@@ -39,6 +39,10 @@ Available Checkpoints are:
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- enable_vae_tiling
|
||||
- disable_vae_tiling
|
||||
- load_textual_inversion
|
||||
- from_ckpt
|
||||
- load_lora_weights
|
||||
- save_lora_weights
|
||||
|
||||
[[autodoc]] FlaxStableDiffusionPipeline
|
||||
- all
|
||||
|
||||
@@ -113,6 +113,29 @@ accelerate launch train_controlnet.py \
|
||||
--gradient_accumulation_steps=4
|
||||
```
|
||||
|
||||
## Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="path to save model"
|
||||
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_controlnet.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=fusing/fill50k \
|
||||
--resolution=512 \
|
||||
--learning_rate=1e-5 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--train_batch_size=4 \
|
||||
--mixed_precision="fp16" \
|
||||
--tracker_project_name="controlnet-demo" \
|
||||
--report_to=wandb
|
||||
```
|
||||
|
||||
## Example results
|
||||
|
||||
#### After 300 steps with batch size 8
|
||||
|
||||
@@ -0,0 +1,289 @@
|
||||
<!--Copyright 2023 Custom Diffusion authors The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Custom Diffusion training example
|
||||
|
||||
[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
|
||||
The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install clip-retrieval
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
```
|
||||
### Cat example 😺
|
||||
|
||||
Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
|
||||
|
||||
We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
|
||||
The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
|
||||
```
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="./data/cat"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>"
|
||||
```
|
||||
|
||||
**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
|
||||
|
||||
To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps:
|
||||
|
||||
* Install `wandb`: `pip install wandb`.
|
||||
* Authorize: `wandb login`.
|
||||
* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments:
|
||||
* `num_validation_images`
|
||||
* `validation_steps`
|
||||
|
||||
Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>" \
|
||||
--validation_prompt="<new1> cat sitting in a bucket" \
|
||||
--report_to="wandb"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details.
|
||||
|
||||
If you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat).
|
||||
|
||||
### Training on multiple concepts 🐱🪵
|
||||
|
||||
Provide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py).
|
||||
|
||||
To collect the real images run this command for each concept in the json file.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
|
||||
```
|
||||
|
||||
And then we're ready to start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--concepts_list=./concept_list.json \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--num_class_images=200 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>+<new2>"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details.
|
||||
|
||||
### Training on human faces
|
||||
|
||||
For fine-tuning on human faces we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn` with at least 15-20 images.
|
||||
|
||||
To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
|
||||
```
|
||||
|
||||
Then start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="path-to-images"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_person/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="person" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> person" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=5e-6 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=1000 \
|
||||
--scale_lr --hflip --noaug \
|
||||
--freeze_model crossattn \
|
||||
--modifier_token "<new1>" \
|
||||
--enable_xformers_memory_efficient_attention
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a model using the above command, you can run inference using the below command. Make sure to include the `modifier token` (e.g. \<new1\> in above example) in your prompt.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
|
||||
pipe.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
It's possible to directly load these parameters from a Hub repository:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
Here is an example of performing inference with multiple concepts:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
|
||||
|
||||
image = pipe(
|
||||
"the <new1> cat sculpture in the style of a <new2> wooden pot",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("multi-subject.png")
|
||||
```
|
||||
|
||||
Here, `cat` and `wooden pot` refer to the multiple concepts.
|
||||
|
||||
### Inference from a training checkpoint
|
||||
|
||||
You can also perform inference from one of the complete checkpoint saved during the training process, if you used the `--checkpointing_steps` argument.
|
||||
|
||||
TODO.
|
||||
|
||||
## Set grads to none
|
||||
|
||||
To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
|
||||
|
||||
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
|
||||
|
||||
## Experimental results
|
||||
|
||||
You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail.
|
||||
@@ -60,7 +60,18 @@ DreamBooth finetuning is very sensitive to hyperparameters and easy to overfit.
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
Let's try DreamBooth with a [few images of a dog](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ); download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
|
||||
Let's try DreamBooth with a
|
||||
[few images of a dog](https://huggingface.co/datasets/diffusers/dog-example);
|
||||
download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
|
||||
|
||||
```python
|
||||
local_dir = "./path_to_training_images"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
|
||||
@@ -126,6 +126,27 @@ accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
|
||||
|
||||
***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
|
||||
|
||||
## Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
|
||||
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
|
||||
--dataset_name=sayakpaul/instructpix2pix-1000-samples \
|
||||
--use_ema \
|
||||
--enable_xformers_memory_efficient_attention \
|
||||
--resolution=512 --random_flip \
|
||||
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
|
||||
--max_train_steps=15000 \
|
||||
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
|
||||
--learning_rate=5e-05 --lr_warmup_steps=0 \
|
||||
--conditioning_dropout_prob=0.05 \
|
||||
--mixed_precision=fp16 \
|
||||
--seed=42
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once training is complete, we can perform inference:
|
||||
|
||||
@@ -16,7 +16,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`].
|
||||
Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`]. We also
|
||||
support LoRA fine-tuning of the text encoder for DreamBooth in a limited capacity. For more details on how we support
|
||||
LoRA fine-tuning of the text encoder, refer to the discussion on [this PR](https://github.com/huggingface/diffusers/pull/2918).
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -175,6 +177,11 @@ accelerate launch train_dreambooth_lora.py \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
It's also possible to additionally fine-tune the text encoder with LoRA. This, in most cases, leads
|
||||
to better results with a slight increase in the compute. To allow fine-tuning the text encoder with LoRA,
|
||||
specify the `--train_text_encoder` while launching the `train_dreambooth_lora.py` script.
|
||||
|
||||
|
||||
### Inference[[dreambooth-inference]]
|
||||
|
||||
Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`]:
|
||||
|
||||
@@ -39,6 +39,8 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
|
||||
- [Dreambooth](./dreambooth)
|
||||
- [LoRA Support](./lora)
|
||||
- [ControlNet](./controlnet)
|
||||
- [InstructPix2Pix](./instructpix2pix)
|
||||
- [Custom Diffusion](./custom_diffusion)
|
||||
|
||||
If possible, please [install xFormers](../optimization/xformers) for memory efficient attention. This could help make your training faster and less memory intensive.
|
||||
|
||||
@@ -50,6 +52,8 @@ If possible, please [install xFormers](../optimization/xformers) for memory effi
|
||||
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
|
||||
| [**Training with LoRA**](./lora) | ✅ | - | - |
|
||||
| [**ControlNet**](./controlnet) | ✅ | ✅ | - |
|
||||
| [**InstructPix2Pix**](./instructpix2pix) | ✅ | ✅ | - |
|
||||
| [**Custom Diffusion**](./custom_diffusion) | ✅ | ✅ | - |
|
||||
|
||||
## Community
|
||||
|
||||
|
||||
@@ -106,6 +106,31 @@ accelerate launch train_text_to_image.py \
|
||||
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--output_dir=${OUTPUT_DIR}
|
||||
```
|
||||
|
||||
#### Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export dataset_name="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$dataset_name \
|
||||
--use_ema \
|
||||
--resolution=512 --center_crop --random_flip \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--max_train_steps=15000 \
|
||||
--learning_rate=1e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--output_dir="sd-pokemon-model"
|
||||
```
|
||||
|
||||
</pt>
|
||||
<jax>
|
||||
With Flax, it's possible to train a Stable Diffusion model faster on TPUs and GPUs thanks to [@duongna211](https://github.com/duongna21). This is very efficient on TPU hardware but works great on GPUs too. The Flax training script doesn't support features like gradient checkpointing or gradient accumulation yet, so you'll need a GPU with at least 30GB of memory or a TPU v3.
|
||||
|
||||
@@ -122,6 +122,26 @@ accelerate launch train_unconditional.py \
|
||||
<img src="https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png"/>
|
||||
</div>
|
||||
|
||||
### Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
|
||||
--dataset_name="huggan/pokemon" \
|
||||
--resolution=64 --center_crop --random_flip \
|
||||
--output_dir="ddpm-ema-pokemon-64" \
|
||||
--train_batch_size=16 \
|
||||
--num_epochs=100 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--use_ema \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_warmup_steps=500 \
|
||||
--mixed_precision="fp16" \
|
||||
--logger="wandb"
|
||||
```
|
||||
|
||||
## Finetuning with your own data
|
||||
|
||||
There are two ways to finetune a model on your own dataset:
|
||||
|
||||
@@ -10,17 +10,21 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# How to build a community pipeline
|
||||
# How to contribute a community pipeline
|
||||
|
||||
*Note*: this page was built from the GitHub Issue on Community Pipelines [#841](https://github.com/huggingface/diffusers/issues/841).
|
||||
<Tip>
|
||||
|
||||
Let's make an example!
|
||||
Say you want to define a pipeline that just does a single forward pass to a U-Net and then calls a scheduler only once (Note, this doesn't make any sense from a scientific point of view, but only represents an example of how things work under the hood).
|
||||
💡 Take a look at GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841) for more context about why we're adding community pipelines to help everyone easily share their work without being slowed down.
|
||||
|
||||
Cool! So you open your favorite IDE and start creating your pipeline 💻.
|
||||
First, what model weights and configurations do we need?
|
||||
We have a U-Net and a scheduler, so our pipeline should take a U-Net and a scheduler as an argument.
|
||||
Also, as stated above, you'd like to be able to load weights and the scheduler config for Hub and share your code with others, so we'll inherit from `DiffusionPipeline`:
|
||||
</Tip>
|
||||
|
||||
Community pipelines allow you to add any additional features you'd like on top of the [`DiffusionPipeline`]. The main benefit of building on top of the `DiffusionPipeline` is anyone can load and use your pipeline by only adding one more argument, making it super easy for the community to access.
|
||||
|
||||
This guide will show you how to create a community pipeline and explain how they work. To keep things simple, you'll create a "one-step" pipeline where the `UNet` does a single forward pass and calls the scheduler once.
|
||||
|
||||
## Initialize the pipeline
|
||||
|
||||
You should start by creating a `one_step_unet.py` file for your community pipeline. In this file, create a pipeline class that inherits from the [`DiffusionPipeline`] to be able to load model weights and the scheduler configuration from the Hub. The one-step pipeline needs a `UNet` and a scheduler, so you'll need to add these as arguments to the `__init__` function:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -32,50 +36,52 @@ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
super().__init__()
|
||||
```
|
||||
|
||||
Now, we must save the `unet` and `scheduler` in a config file so that you can save your pipeline with `save_pretrained`.
|
||||
Therefore, make sure you add every component that is save-able to the `register_modules` function:
|
||||
To ensure your pipeline and its components (`unet` and `scheduler`) can be saved with [`~DiffusionPipeline.save_pretrained`], add them to the `register_modules` function:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
+ self.register_modules(unet=unet, scheduler=scheduler)
|
||||
```
|
||||
|
||||
Cool, the init is done! 🔥 Now, let's go into the forward pass, which we recommend defining as `__call__` . Here you're given all the creative freedom there is. For our amazing "one-step" pipeline, we simply create a random image and call the unet once and the scheduler once:
|
||||
Cool, the `__init__` step is done and you can move to the forward pass now! 🔥
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
## Define the forward pass
|
||||
|
||||
In the forward pass, which we recommend defining as `__call__`, you have complete creative freedom to add whatever feature you'd like. For our amazing one-step pipeline, create a random image and only call the `unet` and `scheduler` once by setting `timestep=1`:
|
||||
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
|
||||
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
def __call__(self):
|
||||
image = torch.randn(
|
||||
(1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
|
||||
)
|
||||
timestep = 1
|
||||
+ def __call__(self):
|
||||
+ image = torch.randn(
|
||||
+ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
|
||||
+ )
|
||||
+ timestep = 1
|
||||
|
||||
model_output = self.unet(image, timestep).sample
|
||||
scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
|
||||
+ model_output = self.unet(image, timestep).sample
|
||||
+ scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
|
||||
|
||||
return scheduler_output
|
||||
+ return scheduler_output
|
||||
```
|
||||
|
||||
Cool, that's it! 🚀 You can now run this pipeline by passing a `unet` and a `scheduler` to the init:
|
||||
That's it! 🚀 You can now run this pipeline by passing a `unet` and `scheduler` to it:
|
||||
|
||||
```python
|
||||
from diffusers import DDPMScheduler, Unet2DModel
|
||||
from diffusers import DDPMScheduler, UNet2DModel
|
||||
|
||||
scheduler = DDPMScheduler()
|
||||
unet = UNet2DModel()
|
||||
@@ -85,7 +91,7 @@ pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)
|
||||
output = pipeline()
|
||||
```
|
||||
|
||||
But what's even better is that you can load pre-existing weights into the pipeline if they match exactly your pipeline structure. This is e.g. the case for [https://huggingface.co/google/ddpm-cifar10-32](https://huggingface.co/google/ddpm-cifar10-32) so that we can do the following:
|
||||
But what's even better is you can load pre-existing weights into the pipeline if the pipeline structure is identical. For example, you can load the [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32) weights into the one-step pipeline:
|
||||
|
||||
```python
|
||||
pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32")
|
||||
@@ -93,33 +99,11 @@ pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-
|
||||
output = pipeline()
|
||||
```
|
||||
|
||||
We want to share this amazing pipeline with the community, so we would open a PR request to add the following code under `one_step_unet.py` to [https://github.com/huggingface/diffusers/tree/main/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) .
|
||||
## Share your pipeline
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
Open a Pull Request on the 🧨 Diffusers [repository](https://github.com/huggingface/diffusers) to add your awesome pipeline in `one_step_unet.py` to the [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) subfolder.
|
||||
|
||||
|
||||
class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
def __call__(self):
|
||||
image = torch.randn(
|
||||
(1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
|
||||
)
|
||||
timestep = 1
|
||||
|
||||
model_output = self.unet(image, timestep).sample
|
||||
scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
|
||||
|
||||
return scheduler_output
|
||||
```
|
||||
|
||||
Our amazing pipeline got merged here: [#840](https://github.com/huggingface/diffusers/pull/840).
|
||||
Now everybody that has `diffusers >= 0.4.0` installed can use our pipeline magically 🪄 as follows:
|
||||
Once it is merged, anyone with `diffusers >= 0.4.0` installed can use this pipeline magically 🪄 by specifying it in the `custom_pipeline` argument:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -128,28 +112,59 @@ pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeli
|
||||
pipe()
|
||||
```
|
||||
|
||||
Another way to upload your custom_pipeline, besides sending a PR, is uploading the code that contains it to the Hugging Face Hub, [as exemplified here](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview#loading-custom-pipelines-from-the-hub).
|
||||
Another way to share your community pipeline is to upload the `one_step_unet.py` file directly to your preferred [model repository](https://huggingface.co/docs/hub/models-uploading) on the Hub. Instead of specifying the `one_step_unet.py` file, pass the model repository id to the `custom_pipeline` argument:
|
||||
|
||||
**Try it out now - it works!**
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
In general, you will want to create much more sophisticated pipelines, so we recommend looking at existing pipelines here: [https://github.com/huggingface/diffusers/tree/main/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community).
|
||||
pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="stevhliu/one_step_unet")
|
||||
```
|
||||
|
||||
IMPORTANT:
|
||||
You can use whatever package you want in your community pipeline file - as long as the user has it installed, everything will work fine. Make sure you have one and only one pipeline class that inherits from `DiffusionPipeline` as this will be automatically detected.
|
||||
Take a look at the following table to compare the two sharing workflows to help you decide the best option for you:
|
||||
|
||||
| | GitHub community pipeline | HF Hub community pipeline |
|
||||
|----------------|------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|
|
||||
| usage | same | same |
|
||||
| review process | open a Pull Request on GitHub and undergo a review process from the Diffusers team before merging; may be slower | upload directly to a Hub repository without any review; this is the fastest workflow |
|
||||
| visibility | included in the official Diffusers repository and documentation | included on your HF Hub profile and relies on your own usage/promotion to gain visibility |
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 You can use whatever package you want in your community pipeline file - as long as the user has it installed, everything will work fine. Make sure you have one and only one pipeline class that inherits from `DiffusionPipeline` because this is automatically detected.
|
||||
|
||||
</Tip>
|
||||
|
||||
## How do community pipelines work?
|
||||
A community pipeline is a class that has to inherit from ['DiffusionPipeline']:
|
||||
and that has been added to `examples/community` [files](https://github.com/huggingface/diffusers/tree/main/examples/community).
|
||||
The community can load the pipeline code via the custom_pipeline argument from DiffusionPipeline. See docs [here](https://huggingface.co/docs/diffusers/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.custom_pipeline):
|
||||
|
||||
This means:
|
||||
The model weights and configs of the pipeline should be loaded from the `pretrained_model_name_or_path` [argument](https://huggingface.co/docs/diffusers/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path):
|
||||
whereas the code that powers the community pipeline is defined in a file added in [`examples/community`](https://github.com/huggingface/diffusers/tree/main/examples/community).
|
||||
A community pipeline is a class that inherits from [`DiffusionPipeline`] which means:
|
||||
|
||||
Now, it might very well be that only some of your pipeline components weights can be downloaded from an official repo.
|
||||
The other components should then be passed directly to init as is the case for the ClIP guidance notebook [here](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb#scrollTo=z9Kglma6hjki).
|
||||
- It can be loaded with the [`custom_pipeline`] argument.
|
||||
- The model weights and scheduler configuration are loaded from [`pretrained_model_name_or_path`].
|
||||
- The code that implements a feature in the community pipeline is defined in a `pipeline.py` file.
|
||||
|
||||
The magic behind all of this is that we load the code directly from GitHub. You can check it out in more detail if you follow the functionality defined here:
|
||||
Sometimes you can't load all the pipeline components weights from an official repository. In this case, the other components should be passed directly to the pipeline:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
|
||||
|
||||
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id)
|
||||
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
custom_pipeline="clip_guided_stable_diffusion",
|
||||
clip_model=clip_model,
|
||||
feature_extractor=feature_extractor,
|
||||
scheduler=scheduler,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
```
|
||||
|
||||
The magic behind community pipelines is contained in the following code. It allows the community pipeline to be loaded from GitHub or the Hub, and it'll be available to all 🧨 Diffusers packages.
|
||||
|
||||
```python
|
||||
# 2. Load the pipeline class, if using custom module then load it from the hub
|
||||
@@ -164,6 +179,3 @@ else:
|
||||
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
|
||||
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
||||
```
|
||||
|
||||
This is why a community pipeline merged to GitHub will be directly available to all `diffusers` packages.
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Custom Pipelines
|
||||
# Community pipelines
|
||||
|
||||
> **For more information about community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).**
|
||||
|
||||
|
||||
@@ -10,19 +10,21 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Loading and Adding Custom Pipelines
|
||||
# Load community pipelines
|
||||
|
||||
Diffusers allows you to conveniently load any custom pipeline from the Hugging Face Hub as well as any [official community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community)
|
||||
via the [`DiffusionPipeline`] class.
|
||||
Community pipelines are any [`DiffusionPipeline`] class that are different from the original implementation as specified in their paper (for example, the [`StableDiffusionControlNetPipeline`] corresponds to the [Text-to-Image Generation with ControlNet Conditioning](https://arxiv.org/abs/2302.05543) paper). They provide additional functionality or extend the original implementation of a pipeline.
|
||||
|
||||
## Loading custom pipelines from the Hub
|
||||
There are many cool community pipelines like [Speech to Image](https://github.com/huggingface/diffusers/tree/main/examples/community#speech-to-image) or [Composable Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#composable-stable-diffusion), and you can find all the official community pipelines [here](https://github.com/huggingface/diffusers/tree/main/examples/community).
|
||||
|
||||
Custom pipelines can be easily loaded from any model repository on the Hub that defines a diffusion pipeline in a `pipeline.py` file.
|
||||
Let's load a dummy pipeline from [hf-internal-testing/diffusers-dummy-pipeline](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline).
|
||||
To load any community pipeline on the Hub, pass the repository id of the community pipeline to the `custom_pipeline` argument and the model repository where you'd like to load the pipeline weights and components from. For example, the example below loads a dummy pipeline from [`hf-internal-testing/diffusers-dummy-pipeline`](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py) and the pipeline weights and components from [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32):
|
||||
|
||||
All you need to do is pass the custom pipeline repo id with the `custom_pipeline` argument alongside the repo from where you wish to load the pipeline modules.
|
||||
<Tip warning={true}>
|
||||
|
||||
```python
|
||||
🔒 By loading a community pipeline from the Hugging Face Hub, you are trusting that the code you are loading is safe. Make sure to inspect the code online before loading and running it automatically!
|
||||
|
||||
</Tip>
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
@@ -30,25 +32,9 @@ pipeline = DiffusionPipeline.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
This will load the custom pipeline as defined in the [model repository](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py).
|
||||
Loading an official community pipeline is similar, but you can mix loading weights from an official repository id and pass pipeline components directly. The example below loads the community [CLIP Guided Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#clip-guided-stable-diffusion) pipeline, and you can pass the CLIP model components directly to it:
|
||||
|
||||
<Tip warning={true} >
|
||||
|
||||
By loading a custom pipeline from the Hugging Face Hub, you are trusting that the code you are loading
|
||||
is safe 🔒. Make sure to check out the code online before loading & running it automatically.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Loading official community pipelines
|
||||
|
||||
Community pipelines are summarized in the [community examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community).
|
||||
|
||||
Similarly, you need to pass both the *repo id* from where you wish to load the weights as well as the `custom_pipeline` argument. Here the `custom_pipeline` argument should consist simply of the filename of the community pipeline excluding the `.py` suffix, *e.g.* `clip_guided_stable_diffusion`.
|
||||
|
||||
Since community pipelines are often more complex, one can mix loading weights from an official *repo id*
|
||||
and passing pipeline modules directly.
|
||||
|
||||
```python
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import CLIPImageProcessor, CLIPModel
|
||||
|
||||
@@ -65,59 +51,4 @@ pipeline = DiffusionPipeline.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
## Adding custom pipelines to the Hub
|
||||
|
||||
To add a custom pipeline to the Hub, all you need to do is to define a pipeline class that inherits
|
||||
from [`DiffusionPipeline`] in a `pipeline.py` file.
|
||||
Make sure that the whole pipeline is encapsulated within a single class and that the `pipeline.py` file
|
||||
has only one such class.
|
||||
|
||||
Let's quickly define an example pipeline.
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
|
||||
class MyPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size: int = 1, num_inference_steps: int = 50):
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size)
|
||||
)
|
||||
|
||||
image = image.to(self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
# 1. predict noise model_output
|
||||
model_output = self.unet(image, t).sample
|
||||
|
||||
# 2. predict previous mean of image x_t-1 and add variance depending on eta
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
# do x_t -> x_t-1
|
||||
image = self.scheduler.step(model_output, t, image, eta).prev_sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
return image
|
||||
```
|
||||
|
||||
Now you can upload this short file under the name `pipeline.py` in your preferred [model repository](https://huggingface.co/docs/hub/models-uploading). For Stable Diffusion pipelines, you may also [join the community organisation for shared pipelines](https://huggingface.co/organizations/sd-diffusers-pipelines-library/share/BUPyDUuHcciGTOKaExlqtfFcyCZsVFdrjr) to upload yours.
|
||||
Finally, we can load the custom pipeline by passing the model repository name, *e.g.* `sd-diffusers-pipelines-library/my_custom_pipeline` alongside the model repository from where we want to load the `unet` and `scheduler` components.
|
||||
|
||||
```python
|
||||
my_pipeline = DiffusionPipeline.from_pretrained(
|
||||
"google/ddpm-cifar10-32", custom_pipeline="patrickvonplaten/my_custom_pipeline"
|
||||
)
|
||||
```
|
||||
For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them and if you're interested in adding a community pipeline check out the [How to contribute a community pipeline](contribute_pipeline) guide!
|
||||
@@ -31,7 +31,7 @@ MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt
|
||||
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
|
||||
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - |[Aengus (Duc-Anh)](https://github.com/aengusng8) |
|
||||
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
|
||||
|
||||
| TensorRT Stable Diffusion Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - |[Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
@@ -1130,3 +1130,34 @@ Init Image
|
||||
Output Image
|
||||
|
||||

|
||||
|
||||
### TensorRT Text2Image Stable Diffusion Pipeline
|
||||
|
||||
The TensorRT Pipeline can be used to accelerate the Text2Image Stable Diffusion Inference run.
|
||||
|
||||
NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DDIMScheduler
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
|
||||
|
||||
# Use the DDIMScheduler scheduler here instead
|
||||
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
|
||||
subfolder="scheduler")
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
|
||||
custom_pipeline="stable_diffusion_tensorrt_txt2img",
|
||||
revision='fp16',
|
||||
torch_dtype=torch.float16,
|
||||
scheduler=scheduler,)
|
||||
|
||||
# re-use cached folder to save ONNX models and TensorRT Engines
|
||||
pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision='fp16',)
|
||||
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "a beautiful photograph of Mt. Fuji during cherry blossom"
|
||||
image = pipe(prompt).images[0]
|
||||
image.save('tensorrt_mt_fuji.png')
|
||||
```
|
||||
|
||||
@@ -0,0 +1,926 @@
|
||||
#
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from copy import copy
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnx_graphsurgeon as gs
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from onnx import shape_inference
|
||||
from polygraphy import cuda
|
||||
from polygraphy.backend.common import bytes_from_path
|
||||
from polygraphy.backend.onnx.loader import fold_constants
|
||||
from polygraphy.backend.trt import (
|
||||
CreateConfig,
|
||||
Profile,
|
||||
engine_from_bytes,
|
||||
engine_from_network,
|
||||
network_from_onnx_path,
|
||||
save_engine,
|
||||
)
|
||||
from polygraphy.backend.trt import util as trt_util
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import (
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPipelineOutput,
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.schedulers import DDIMScheduler
|
||||
from diffusers.utils import DIFFUSERS_CACHE, logging
|
||||
|
||||
|
||||
"""
|
||||
Installation instructions
|
||||
python3 -m pip install --upgrade tensorrt
|
||||
python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
|
||||
python3 -m pip install onnxruntime
|
||||
"""
|
||||
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Map of numpy dtype -> torch dtype
|
||||
numpy_to_torch_dtype_dict = {
|
||||
np.uint8: torch.uint8,
|
||||
np.int8: torch.int8,
|
||||
np.int16: torch.int16,
|
||||
np.int32: torch.int32,
|
||||
np.int64: torch.int64,
|
||||
np.float16: torch.float16,
|
||||
np.float32: torch.float32,
|
||||
np.float64: torch.float64,
|
||||
np.complex64: torch.complex64,
|
||||
np.complex128: torch.complex128,
|
||||
}
|
||||
if np.version.full_version >= "1.24.0":
|
||||
numpy_to_torch_dtype_dict[np.bool_] = torch.bool
|
||||
else:
|
||||
numpy_to_torch_dtype_dict[np.bool] = torch.bool
|
||||
|
||||
# Map of torch dtype -> numpy dtype
|
||||
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
|
||||
|
||||
|
||||
def device_view(t):
|
||||
return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype])
|
||||
|
||||
|
||||
class Engine:
|
||||
def __init__(self, engine_path):
|
||||
self.engine_path = engine_path
|
||||
self.engine = None
|
||||
self.context = None
|
||||
self.buffers = OrderedDict()
|
||||
self.tensors = OrderedDict()
|
||||
|
||||
def __del__(self):
|
||||
[buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)]
|
||||
del self.engine
|
||||
del self.context
|
||||
del self.buffers
|
||||
del self.tensors
|
||||
|
||||
def build(
|
||||
self,
|
||||
onnx_path,
|
||||
fp16,
|
||||
input_profile=None,
|
||||
enable_preview=False,
|
||||
enable_all_tactics=False,
|
||||
timing_cache=None,
|
||||
workspace_size=0,
|
||||
):
|
||||
logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
|
||||
p = Profile()
|
||||
if input_profile:
|
||||
for name, dims in input_profile.items():
|
||||
assert len(dims) == 3
|
||||
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
|
||||
|
||||
config_kwargs = {}
|
||||
|
||||
config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805]
|
||||
if enable_preview:
|
||||
# Faster dynamic shapes made optional since it increases engine build time.
|
||||
config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805)
|
||||
if workspace_size > 0:
|
||||
config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size}
|
||||
if not enable_all_tactics:
|
||||
config_kwargs["tactic_sources"] = []
|
||||
|
||||
engine = engine_from_network(
|
||||
network_from_onnx_path(onnx_path),
|
||||
config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs),
|
||||
save_timing_cache=timing_cache,
|
||||
)
|
||||
save_engine(engine, path=self.engine_path)
|
||||
|
||||
def load(self):
|
||||
logger.warning(f"Loading TensorRT engine: {self.engine_path}")
|
||||
self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
|
||||
|
||||
def activate(self):
|
||||
self.context = self.engine.create_execution_context()
|
||||
|
||||
def allocate_buffers(self, shape_dict=None, device="cuda"):
|
||||
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
|
||||
binding = self.engine[idx]
|
||||
if shape_dict and binding in shape_dict:
|
||||
shape = shape_dict[binding]
|
||||
else:
|
||||
shape = self.engine.get_binding_shape(binding)
|
||||
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
|
||||
if self.engine.binding_is_input(binding):
|
||||
self.context.set_binding_shape(idx, shape)
|
||||
tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
|
||||
self.tensors[binding] = tensor
|
||||
self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
|
||||
|
||||
def infer(self, feed_dict, stream):
|
||||
start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
|
||||
# shallow copy of ordered dict
|
||||
device_buffers = copy(self.buffers)
|
||||
for name, buf in feed_dict.items():
|
||||
assert isinstance(buf, cuda.DeviceView)
|
||||
device_buffers[name] = buf
|
||||
bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
|
||||
noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
|
||||
if not noerror:
|
||||
raise ValueError("ERROR: inference failed.")
|
||||
|
||||
return self.tensors
|
||||
|
||||
|
||||
class Optimizer:
|
||||
def __init__(self, onnx_graph):
|
||||
self.graph = gs.import_onnx(onnx_graph)
|
||||
|
||||
def cleanup(self, return_onnx=False):
|
||||
self.graph.cleanup().toposort()
|
||||
if return_onnx:
|
||||
return gs.export_onnx(self.graph)
|
||||
|
||||
def select_outputs(self, keep, names=None):
|
||||
self.graph.outputs = [self.graph.outputs[o] for o in keep]
|
||||
if names:
|
||||
for i, name in enumerate(names):
|
||||
self.graph.outputs[i].name = name
|
||||
|
||||
def fold_constants(self, return_onnx=False):
|
||||
onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
|
||||
self.graph = gs.import_onnx(onnx_graph)
|
||||
if return_onnx:
|
||||
return onnx_graph
|
||||
|
||||
def infer_shapes(self, return_onnx=False):
|
||||
onnx_graph = gs.export_onnx(self.graph)
|
||||
if onnx_graph.ByteSize() > 2147483648:
|
||||
raise TypeError("ERROR: model size exceeds supported 2GB limit")
|
||||
else:
|
||||
onnx_graph = shape_inference.infer_shapes(onnx_graph)
|
||||
|
||||
self.graph = gs.import_onnx(onnx_graph)
|
||||
if return_onnx:
|
||||
return onnx_graph
|
||||
|
||||
|
||||
class BaseModel:
|
||||
def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77):
|
||||
self.model = model
|
||||
self.name = "SD Model"
|
||||
self.fp16 = fp16
|
||||
self.device = device
|
||||
|
||||
self.min_batch = 1
|
||||
self.max_batch = max_batch_size
|
||||
self.min_image_shape = 256 # min image resolution: 256x256
|
||||
self.max_image_shape = 1024 # max image resolution: 1024x1024
|
||||
self.min_latent_shape = self.min_image_shape // 8
|
||||
self.max_latent_shape = self.max_image_shape // 8
|
||||
|
||||
self.embedding_dim = embedding_dim
|
||||
self.text_maxlen = text_maxlen
|
||||
|
||||
def get_model(self):
|
||||
return self.model
|
||||
|
||||
def get_input_names(self):
|
||||
pass
|
||||
|
||||
def get_output_names(self):
|
||||
pass
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
return None
|
||||
|
||||
def get_sample_input(self, batch_size, image_height, image_width):
|
||||
pass
|
||||
|
||||
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
|
||||
return None
|
||||
|
||||
def get_shape_dict(self, batch_size, image_height, image_width):
|
||||
return None
|
||||
|
||||
def optimize(self, onnx_graph):
|
||||
opt = Optimizer(onnx_graph)
|
||||
opt.cleanup()
|
||||
opt.fold_constants()
|
||||
opt.infer_shapes()
|
||||
onnx_opt_graph = opt.cleanup(return_onnx=True)
|
||||
return onnx_opt_graph
|
||||
|
||||
def check_dims(self, batch_size, image_height, image_width):
|
||||
assert batch_size >= self.min_batch and batch_size <= self.max_batch
|
||||
assert image_height % 8 == 0 or image_width % 8 == 0
|
||||
latent_height = image_height // 8
|
||||
latent_width = image_width // 8
|
||||
assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
|
||||
assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
|
||||
return (latent_height, latent_width)
|
||||
|
||||
def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape):
|
||||
min_batch = batch_size if static_batch else self.min_batch
|
||||
max_batch = batch_size if static_batch else self.max_batch
|
||||
latent_height = image_height // 8
|
||||
latent_width = image_width // 8
|
||||
min_image_height = image_height if static_shape else self.min_image_shape
|
||||
max_image_height = image_height if static_shape else self.max_image_shape
|
||||
min_image_width = image_width if static_shape else self.min_image_shape
|
||||
max_image_width = image_width if static_shape else self.max_image_shape
|
||||
min_latent_height = latent_height if static_shape else self.min_latent_shape
|
||||
max_latent_height = latent_height if static_shape else self.max_latent_shape
|
||||
min_latent_width = latent_width if static_shape else self.min_latent_shape
|
||||
max_latent_width = latent_width if static_shape else self.max_latent_shape
|
||||
return (
|
||||
min_batch,
|
||||
max_batch,
|
||||
min_image_height,
|
||||
max_image_height,
|
||||
min_image_width,
|
||||
max_image_width,
|
||||
min_latent_height,
|
||||
max_latent_height,
|
||||
min_latent_width,
|
||||
max_latent_width,
|
||||
)
|
||||
|
||||
|
||||
def getOnnxPath(model_name, onnx_dir, opt=True):
|
||||
return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx")
|
||||
|
||||
|
||||
def getEnginePath(model_name, engine_dir):
|
||||
return os.path.join(engine_dir, model_name + ".plan")
|
||||
|
||||
|
||||
def build_engines(
|
||||
models: dict,
|
||||
engine_dir,
|
||||
onnx_dir,
|
||||
onnx_opset,
|
||||
opt_image_height,
|
||||
opt_image_width,
|
||||
opt_batch_size=1,
|
||||
force_engine_rebuild=False,
|
||||
static_batch=False,
|
||||
static_shape=True,
|
||||
enable_preview=False,
|
||||
enable_all_tactics=False,
|
||||
timing_cache=None,
|
||||
max_workspace_size=0,
|
||||
):
|
||||
built_engines = {}
|
||||
if not os.path.isdir(onnx_dir):
|
||||
os.makedirs(onnx_dir)
|
||||
if not os.path.isdir(engine_dir):
|
||||
os.makedirs(engine_dir)
|
||||
|
||||
# Export models to ONNX
|
||||
for model_name, model_obj in models.items():
|
||||
engine_path = getEnginePath(model_name, engine_dir)
|
||||
if force_engine_rebuild or not os.path.exists(engine_path):
|
||||
logger.warning("Building Engines...")
|
||||
logger.warning("Engine build can take a while to complete")
|
||||
onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
|
||||
onnx_opt_path = getOnnxPath(model_name, onnx_dir)
|
||||
if force_engine_rebuild or not os.path.exists(onnx_opt_path):
|
||||
if force_engine_rebuild or not os.path.exists(onnx_path):
|
||||
logger.warning(f"Exporting model: {onnx_path}")
|
||||
model = model_obj.get_model()
|
||||
with torch.inference_mode(), torch.autocast("cuda"):
|
||||
inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
inputs,
|
||||
onnx_path,
|
||||
export_params=True,
|
||||
opset_version=onnx_opset,
|
||||
do_constant_folding=True,
|
||||
input_names=model_obj.get_input_names(),
|
||||
output_names=model_obj.get_output_names(),
|
||||
dynamic_axes=model_obj.get_dynamic_axes(),
|
||||
)
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
else:
|
||||
logger.warning(f"Found cached model: {onnx_path}")
|
||||
|
||||
# Optimize onnx
|
||||
if force_engine_rebuild or not os.path.exists(onnx_opt_path):
|
||||
logger.warning(f"Generating optimizing model: {onnx_opt_path}")
|
||||
onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path))
|
||||
onnx.save(onnx_opt_graph, onnx_opt_path)
|
||||
else:
|
||||
logger.warning(f"Found cached optimized model: {onnx_opt_path} ")
|
||||
|
||||
# Build TensorRT engines
|
||||
for model_name, model_obj in models.items():
|
||||
engine_path = getEnginePath(model_name, engine_dir)
|
||||
engine = Engine(engine_path)
|
||||
onnx_path = getOnnxPath(model_name, onnx_dir, opt=False)
|
||||
onnx_opt_path = getOnnxPath(model_name, onnx_dir)
|
||||
|
||||
if force_engine_rebuild or not os.path.exists(engine.engine_path):
|
||||
engine.build(
|
||||
onnx_opt_path,
|
||||
fp16=True,
|
||||
input_profile=model_obj.get_input_profile(
|
||||
opt_batch_size,
|
||||
opt_image_height,
|
||||
opt_image_width,
|
||||
static_batch=static_batch,
|
||||
static_shape=static_shape,
|
||||
),
|
||||
enable_preview=enable_preview,
|
||||
timing_cache=timing_cache,
|
||||
workspace_size=max_workspace_size,
|
||||
)
|
||||
built_engines[model_name] = engine
|
||||
|
||||
# Load and activate TensorRT engines
|
||||
for model_name, model_obj in models.items():
|
||||
engine = built_engines[model_name]
|
||||
engine.load()
|
||||
engine.activate()
|
||||
|
||||
return built_engines
|
||||
|
||||
|
||||
def runEngine(engine, feed_dict, stream):
|
||||
return engine.infer(feed_dict, stream)
|
||||
|
||||
|
||||
class CLIP(BaseModel):
|
||||
def __init__(self, model, device, max_batch_size, embedding_dim):
|
||||
super(CLIP, self).__init__(
|
||||
model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
|
||||
)
|
||||
self.name = "CLIP"
|
||||
|
||||
def get_input_names(self):
|
||||
return ["input_ids"]
|
||||
|
||||
def get_output_names(self):
|
||||
return ["text_embeddings", "pooler_output"]
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}}
|
||||
|
||||
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
|
||||
self.check_dims(batch_size, image_height, image_width)
|
||||
min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
|
||||
batch_size, image_height, image_width, static_batch, static_shape
|
||||
)
|
||||
return {
|
||||
"input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
|
||||
}
|
||||
|
||||
def get_shape_dict(self, batch_size, image_height, image_width):
|
||||
self.check_dims(batch_size, image_height, image_width)
|
||||
return {
|
||||
"input_ids": (batch_size, self.text_maxlen),
|
||||
"text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
|
||||
}
|
||||
|
||||
def get_sample_input(self, batch_size, image_height, image_width):
|
||||
self.check_dims(batch_size, image_height, image_width)
|
||||
return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device)
|
||||
|
||||
def optimize(self, onnx_graph):
|
||||
opt = Optimizer(onnx_graph)
|
||||
opt.select_outputs([0]) # delete graph output#1
|
||||
opt.cleanup()
|
||||
opt.fold_constants()
|
||||
opt.infer_shapes()
|
||||
opt.select_outputs([0], names=["text_embeddings"]) # rename network output
|
||||
opt_onnx_graph = opt.cleanup(return_onnx=True)
|
||||
return opt_onnx_graph
|
||||
|
||||
|
||||
def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False):
|
||||
return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
|
||||
|
||||
|
||||
class UNet(BaseModel):
|
||||
def __init__(
|
||||
self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4
|
||||
):
|
||||
super(UNet, self).__init__(
|
||||
model=model,
|
||||
fp16=fp16,
|
||||
device=device,
|
||||
max_batch_size=max_batch_size,
|
||||
embedding_dim=embedding_dim,
|
||||
text_maxlen=text_maxlen,
|
||||
)
|
||||
self.unet_dim = unet_dim
|
||||
self.name = "UNet"
|
||||
|
||||
def get_input_names(self):
|
||||
return ["sample", "timestep", "encoder_hidden_states"]
|
||||
|
||||
def get_output_names(self):
|
||||
return ["latent"]
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
return {
|
||||
"sample": {0: "2B", 2: "H", 3: "W"},
|
||||
"encoder_hidden_states": {0: "2B"},
|
||||
"latent": {0: "2B", 2: "H", 3: "W"},
|
||||
}
|
||||
|
||||
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
(
|
||||
min_batch,
|
||||
max_batch,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
min_latent_height,
|
||||
max_latent_height,
|
||||
min_latent_width,
|
||||
max_latent_width,
|
||||
) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
|
||||
return {
|
||||
"sample": [
|
||||
(2 * min_batch, self.unet_dim, min_latent_height, min_latent_width),
|
||||
(2 * batch_size, self.unet_dim, latent_height, latent_width),
|
||||
(2 * max_batch, self.unet_dim, max_latent_height, max_latent_width),
|
||||
],
|
||||
"encoder_hidden_states": [
|
||||
(2 * min_batch, self.text_maxlen, self.embedding_dim),
|
||||
(2 * batch_size, self.text_maxlen, self.embedding_dim),
|
||||
(2 * max_batch, self.text_maxlen, self.embedding_dim),
|
||||
],
|
||||
}
|
||||
|
||||
def get_shape_dict(self, batch_size, image_height, image_width):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
return {
|
||||
"sample": (2 * batch_size, self.unet_dim, latent_height, latent_width),
|
||||
"encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim),
|
||||
"latent": (2 * batch_size, 4, latent_height, latent_width),
|
||||
}
|
||||
|
||||
def get_sample_input(self, batch_size, image_height, image_width):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
dtype = torch.float16 if self.fp16 else torch.float32
|
||||
return (
|
||||
torch.randn(
|
||||
2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device
|
||||
),
|
||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
|
||||
)
|
||||
|
||||
|
||||
def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False):
|
||||
return UNet(
|
||||
model,
|
||||
fp16=True,
|
||||
device=device,
|
||||
max_batch_size=max_batch_size,
|
||||
embedding_dim=embedding_dim,
|
||||
unet_dim=(9 if inpaint else 4),
|
||||
)
|
||||
|
||||
|
||||
class VAE(BaseModel):
|
||||
def __init__(self, model, device, max_batch_size, embedding_dim):
|
||||
super(VAE, self).__init__(
|
||||
model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim
|
||||
)
|
||||
self.name = "VAE decoder"
|
||||
|
||||
def get_input_names(self):
|
||||
return ["latent"]
|
||||
|
||||
def get_output_names(self):
|
||||
return ["images"]
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}
|
||||
|
||||
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
(
|
||||
min_batch,
|
||||
max_batch,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
min_latent_height,
|
||||
max_latent_height,
|
||||
min_latent_width,
|
||||
max_latent_width,
|
||||
) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape)
|
||||
return {
|
||||
"latent": [
|
||||
(min_batch, 4, min_latent_height, min_latent_width),
|
||||
(batch_size, 4, latent_height, latent_width),
|
||||
(max_batch, 4, max_latent_height, max_latent_width),
|
||||
]
|
||||
}
|
||||
|
||||
def get_shape_dict(self, batch_size, image_height, image_width):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
return {
|
||||
"latent": (batch_size, 4, latent_height, latent_width),
|
||||
"images": (batch_size, 3, image_height, image_width),
|
||||
}
|
||||
|
||||
def get_sample_input(self, batch_size, image_height, image_width):
|
||||
latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
|
||||
return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device)
|
||||
|
||||
|
||||
def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False):
|
||||
return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim)
|
||||
|
||||
|
||||
class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using TensorRT accelerated Stable Diffusion.
|
||||
|
||||
This model inherits from [`StableDiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
stages=["clip", "unet", "vae"],
|
||||
image_height: int = 768,
|
||||
image_width: int = 768,
|
||||
max_batch_size: int = 16,
|
||||
# ONNX export parameters
|
||||
onnx_opset: int = 17,
|
||||
onnx_dir: str = "onnx",
|
||||
# TensorRT engine build parameters
|
||||
engine_dir: str = "engine",
|
||||
force_engine_rebuild: bool = False,
|
||||
timing_cache: str = "timing_cache",
|
||||
):
|
||||
super().__init__(
|
||||
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
|
||||
)
|
||||
|
||||
self.vae.forward = self.vae.decode
|
||||
|
||||
self.stages = stages
|
||||
self.image_height, self.image_width = image_height, image_width
|
||||
self.inpaint = False
|
||||
self.onnx_opset = onnx_opset
|
||||
self.onnx_dir = onnx_dir
|
||||
self.engine_dir = engine_dir
|
||||
self.force_engine_rebuild = force_engine_rebuild
|
||||
self.timing_cache = timing_cache
|
||||
self.build_static_batch = False
|
||||
self.build_dynamic_shape = False
|
||||
self.build_preview_features = False
|
||||
|
||||
self.max_batch_size = max_batch_size
|
||||
# TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation.
|
||||
if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512:
|
||||
self.max_batch_size = 4
|
||||
|
||||
self.stream = None # loaded in loadResources()
|
||||
self.models = {} # loaded in __loadModels()
|
||||
self.engine = {} # loaded in build_engines()
|
||||
|
||||
def __loadModels(self):
|
||||
# Load pipeline models
|
||||
self.embedding_dim = self.text_encoder.config.hidden_size
|
||||
models_args = {
|
||||
"device": self.torch_device,
|
||||
"max_batch_size": self.max_batch_size,
|
||||
"embedding_dim": self.embedding_dim,
|
||||
"inpaint": self.inpaint,
|
||||
}
|
||||
if "clip" in self.stages:
|
||||
self.models["clip"] = make_CLIP(self.text_encoder, **models_args)
|
||||
if "unet" in self.stages:
|
||||
self.models["unet"] = make_UNet(self.unet, **models_args)
|
||||
if "vae" in self.stages:
|
||||
self.models["vae"] = make_VAE(self.vae, **models_args)
|
||||
|
||||
@classmethod
|
||||
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
cls.cached_folder = (
|
||||
pretrained_model_name_or_path
|
||||
if os.path.isdir(pretrained_model_name_or_path)
|
||||
else snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
)
|
||||
)
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
|
||||
super().to(torch_device, silence_dtype_warnings)
|
||||
|
||||
self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir)
|
||||
self.engine_dir = os.path.join(self.cached_folder, self.engine_dir)
|
||||
self.timing_cache = os.path.join(self.cached_folder, self.timing_cache)
|
||||
|
||||
# set device
|
||||
self.torch_device = self._execution_device
|
||||
logger.warning(f"Running inference on device: {self.torch_device}")
|
||||
|
||||
# load models
|
||||
self.__loadModels()
|
||||
|
||||
# build engines
|
||||
self.engine = build_engines(
|
||||
self.models,
|
||||
self.engine_dir,
|
||||
self.onnx_dir,
|
||||
self.onnx_opset,
|
||||
opt_image_height=self.image_height,
|
||||
opt_image_width=self.image_width,
|
||||
force_engine_rebuild=self.force_engine_rebuild,
|
||||
static_batch=self.build_static_batch,
|
||||
static_shape=not self.build_dynamic_shape,
|
||||
enable_preview=self.build_preview_features,
|
||||
timing_cache=self.timing_cache,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def __encode_prompt(self, prompt, negative_prompt):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
"""
|
||||
# Tokenize prompt
|
||||
text_input_ids = (
|
||||
self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
.input_ids.type(torch.int32)
|
||||
.to(self.torch_device)
|
||||
)
|
||||
|
||||
text_input_ids_inp = device_view(text_input_ids)
|
||||
# NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
|
||||
text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[
|
||||
"text_embeddings"
|
||||
].clone()
|
||||
|
||||
# Tokenize negative prompt
|
||||
uncond_input_ids = (
|
||||
self.tokenizer(
|
||||
negative_prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
.input_ids.type(torch.int32)
|
||||
.to(self.torch_device)
|
||||
)
|
||||
uncond_input_ids_inp = device_view(uncond_input_ids)
|
||||
uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[
|
||||
"text_embeddings"
|
||||
]
|
||||
|
||||
# Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16)
|
||||
|
||||
return text_embeddings
|
||||
|
||||
def __denoise_latent(
|
||||
self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None
|
||||
):
|
||||
if not isinstance(timesteps, torch.Tensor):
|
||||
timesteps = self.scheduler.timesteps
|
||||
for step_index, timestep in enumerate(timesteps):
|
||||
# Expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
|
||||
if isinstance(mask, torch.Tensor):
|
||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
||||
|
||||
# Predict the noise residual
|
||||
timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep
|
||||
|
||||
sample_inp = device_view(latent_model_input)
|
||||
timestep_inp = device_view(timestep_float)
|
||||
embeddings_inp = device_view(text_embeddings)
|
||||
noise_pred = runEngine(
|
||||
self.engine["unet"],
|
||||
{"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp},
|
||||
self.stream,
|
||||
)["latent"]
|
||||
|
||||
# Perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample
|
||||
|
||||
latents = 1.0 / 0.18215 * latents
|
||||
return latents
|
||||
|
||||
def __decode_latent(self, latents):
|
||||
images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"]
|
||||
images = (images / 2 + 0.5).clamp(0, 1)
|
||||
return images.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
def __loadResources(self, image_height, image_width, batch_size):
|
||||
self.stream = cuda.Stream()
|
||||
|
||||
# Allocate buffers for TensorRT engine bindings
|
||||
for model_name, obj in self.models.items():
|
||||
self.engine[model_name].allocate_buffers(
|
||||
shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
|
||||
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
|
||||
"""
|
||||
self.generator = generator
|
||||
self.denoising_steps = num_inference_steps
|
||||
self.guidance_scale = guidance_scale
|
||||
|
||||
# Pre-compute latent input scales and linear multistep coefficients
|
||||
self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
|
||||
|
||||
# Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
prompt = [prompt]
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}")
|
||||
|
||||
if negative_prompt is None:
|
||||
negative_prompt = [""] * batch_size
|
||||
|
||||
if negative_prompt is not None and isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt]
|
||||
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
|
||||
if batch_size > self.max_batch_size:
|
||||
raise ValueError(
|
||||
f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4"
|
||||
)
|
||||
|
||||
# load resources
|
||||
self.__loadResources(self.image_height, self.image_width, batch_size)
|
||||
|
||||
with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER):
|
||||
# CLIP text encoder
|
||||
text_embeddings = self.__encode_prompt(prompt, negative_prompt)
|
||||
|
||||
# Pre-initialize latents
|
||||
num_channels_latents = self.unet.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
self.image_height,
|
||||
self.image_width,
|
||||
torch.float32,
|
||||
self.torch_device,
|
||||
generator,
|
||||
)
|
||||
|
||||
# UNet denoiser
|
||||
latents = self.__denoise_latent(latents, text_embeddings)
|
||||
|
||||
# VAE decode latent
|
||||
images = self.__decode_latent(latents)
|
||||
|
||||
images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype)
|
||||
images = self.numpy_to_pil(images)
|
||||
return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -372,9 +372,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
|
||||
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
|
||||
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
num_channels_latents = self.decoder.config.in_channels
|
||||
height = self.decoder.config.sample_size
|
||||
width = self.decoder.config.sample_size
|
||||
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
@@ -425,9 +425,9 @@ class UnCLIPImageInterpolationPipeline(DiffusionPipeline):
|
||||
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
|
||||
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
|
||||
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
channels = self.super_res_first.config.in_channels // 2
|
||||
height = self.super_res_first.config.sample_size
|
||||
width = self.super_res_first.config.sample_size
|
||||
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
|
||||
@@ -452,9 +452,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
|
||||
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
|
||||
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
num_channels_latents = self.decoder.config.in_channels
|
||||
height = self.decoder.config.sample_size
|
||||
width = self.decoder.config.sample_size
|
||||
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
@@ -505,9 +505,9 @@ class UnCLIPTextInterpolationPipeline(DiffusionPipeline):
|
||||
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
|
||||
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
|
||||
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
channels = self.super_res_first.config.in_channels // 2
|
||||
height = self.super_res_first.config.sample_size
|
||||
width = self.super_res_first.config.sample_size
|
||||
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
|
||||
@@ -96,6 +96,29 @@ accelerate launch train_controlnet.py \
|
||||
--gradient_accumulation_steps=4
|
||||
```
|
||||
|
||||
## Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="path to save model"
|
||||
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_controlnet.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=fusing/fill50k \
|
||||
--resolution=512 \
|
||||
--learning_rate=1e-5 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--train_batch_size=4 \
|
||||
--mixed_precision="fp16" \
|
||||
--tracker_project_name="controlnet-demo" \
|
||||
--report_to=wandb
|
||||
```
|
||||
|
||||
## Example results
|
||||
|
||||
#### After 300 steps with batch size 8
|
||||
@@ -284,9 +307,9 @@ TPU_TYPE=v4-8
|
||||
VM_NAME=hg_flax
|
||||
|
||||
gcloud alpha compute tpus tpu-vm create $VM_NAME \
|
||||
--zone $ZONE \
|
||||
--accelerator-type $TPU_TYPE \
|
||||
--version tpu-vm-v4-base
|
||||
--zone $ZONE \
|
||||
--accelerator-type $TPU_TYPE \
|
||||
--version tpu-vm-v4-base
|
||||
|
||||
gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
|
||||
```
|
||||
@@ -326,6 +349,7 @@ If you want to use Weights and Biases logging, you should also install `wandb` n
|
||||
pip install wandb
|
||||
```
|
||||
|
||||
|
||||
Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress
|
||||
|
||||
```
|
||||
@@ -343,8 +367,8 @@ Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment v
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="control_out"
|
||||
export HUB_MODEL_ID="fill-circle-controlnet"
|
||||
export OUTPUT_DIR="runs/fill-circle-{timestamp}"
|
||||
export HUB_MODEL_ID="controlnet-fill-circle"
|
||||
```
|
||||
|
||||
And finally start the training
|
||||
@@ -363,32 +387,36 @@ python3 train_controlnet_flax.py \
|
||||
--revision="non-ema" \
|
||||
--from_pt \
|
||||
--report_to="wandb" \
|
||||
--max_train_steps=10000 \
|
||||
--tracker_project_name=$HUB_MODEL_ID \
|
||||
--num_train_epochs=11 \
|
||||
--push_to_hub \
|
||||
--hub_model_id=$HUB_MODEL_ID
|
||||
```
|
||||
|
||||
Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
|
||||
|
||||
Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:
|
||||
Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)):
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
|
||||
export OUTPUT_DIR="runs/uncanny-faces-{timestamp}"
|
||||
export HUB_MODEL_ID="controlnet-uncanny-faces"
|
||||
|
||||
python3 train_controlnet_flax.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=multimodalart/facesyntheticsspigacaptioned \
|
||||
--streaming \
|
||||
--conditioning_image_column=spiga_seg \
|
||||
--image_column=image \
|
||||
--caption_column=image_caption \
|
||||
--resolution=512 \
|
||||
--max_train_samples 50 \
|
||||
--max_train_steps 5 \
|
||||
--learning_rate=1e-5 \
|
||||
--validation_steps=2 \
|
||||
--train_batch_size=1 \
|
||||
--revision="flax" \
|
||||
--report_to="wandb"
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name=multimodalart/facesyntheticsspigacaptioned \
|
||||
--streaming \
|
||||
--conditioning_image_column=spiga_seg \
|
||||
--image_column=image \
|
||||
--caption_column=image_caption \
|
||||
--resolution=512 \
|
||||
--max_train_samples 100000 \
|
||||
--learning_rate=1e-5 \
|
||||
--train_batch_size=1 \
|
||||
--revision="flax" \
|
||||
--report_to="wandb" \
|
||||
--tracker_project_name=$HUB_MODEL_ID
|
||||
```
|
||||
|
||||
Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options:
|
||||
@@ -400,16 +428,35 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream
|
||||
When work with a larger dataset, you may need to run training process for a long time and it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:
|
||||
|
||||
```bash
|
||||
--checkpointing_steps=500
|
||||
--checkpointing_steps=500
|
||||
```
|
||||
This will save the trained model in subfolders of your output_dir. Subfolder names is the number of steps performed so far; for example: a checkpoint saved after 500 training steps would be saved in a subfolder named 500
|
||||
|
||||
You can then start your training from this saved checkpoint with
|
||||
|
||||
```bash
|
||||
--controlnet_model_name_or_path="./control_out/500"
|
||||
--controlnet_model_name_or_path="./control_out/500"
|
||||
```
|
||||
|
||||
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
|
||||
|
||||
We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
|
||||
We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
|
||||
|
||||
You can **profile your code** with:
|
||||
|
||||
```bash
|
||||
--profile_steps==5
|
||||
```
|
||||
|
||||
Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin:
|
||||
|
||||
```bash
|
||||
pip install tensorflow tensorboard-plugin-profile
|
||||
tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
|
||||
```
|
||||
|
||||
The profile can then be inspected at http://localhost:6006/#profile
|
||||
|
||||
Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
|
||||
|
||||
Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).
|
||||
|
||||
@@ -55,7 +55,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -525,6 +525,11 @@ def parse_args(input_args=None):
|
||||
" or the same number of `--validation_prompt`s and `--validation_image`s"
|
||||
)
|
||||
|
||||
if args.resolution % 8 != 0:
|
||||
raise ValueError(
|
||||
"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@@ -607,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator):
|
||||
image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
@@ -615,6 +621,7 @@ def make_train_dataset(args, tokenizer, accelerator):
|
||||
conditioning_image_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import jax
|
||||
@@ -58,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -75,20 +76,11 @@ def image_grid(imgs, rows, cols):
|
||||
return grid
|
||||
|
||||
|
||||
def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype):
|
||||
logger.info("Running validation... ")
|
||||
def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
|
||||
logger.info("Running validation...")
|
||||
|
||||
pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
tokenizer=tokenizer,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
dtype=weight_dtype,
|
||||
revision=args.revision,
|
||||
from_pt=args.from_pt,
|
||||
)
|
||||
params = jax_utils.replicate(params)
|
||||
params["controlnet"] = controlnet_params
|
||||
pipeline_params = pipeline_params.copy()
|
||||
pipeline_params["controlnet"] = controlnet_params
|
||||
|
||||
num_samples = jax.device_count()
|
||||
prng_seed = jax.random.split(rng, jax.device_count())
|
||||
@@ -120,7 +112,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
|
||||
images = pipeline(
|
||||
prompt_ids=prompt_ids,
|
||||
image=processed_image,
|
||||
params=params,
|
||||
params=pipeline_params,
|
||||
prng_seed=prng_seed,
|
||||
num_inference_steps=50,
|
||||
jit=True,
|
||||
@@ -175,6 +167,7 @@ tags:
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- controlnet
|
||||
- jax-diffusers-event
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
@@ -220,6 +213,28 @@ def parse_args():
|
||||
default=None,
|
||||
help="Revision of controlnet model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile_steps",
|
||||
type=int,
|
||||
default=0,
|
||||
help="How many training steps to profile in the beginning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile_validation",
|
||||
action="store_true",
|
||||
help="Whether to profile the (last) validation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile_memory",
|
||||
action="store_true",
|
||||
help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ccache",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Enables compilation cache.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--controlnet_from_pt",
|
||||
action="store_true",
|
||||
@@ -234,8 +249,9 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="controlnet-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
default="runs/{timestamp}",
|
||||
help="The output directory where the model predictions and checkpoints will be written. "
|
||||
"Can contain placeholders: {timestamp}.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
@@ -317,15 +333,6 @@ def parse_args():
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_steps",
|
||||
type=int,
|
||||
@@ -459,6 +466,8 @@ def parse_args():
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
@@ -783,6 +792,17 @@ def main():
|
||||
]:
|
||||
controlnet_params[key] = unet_params[key]
|
||||
|
||||
pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
tokenizer=tokenizer,
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
dtype=weight_dtype,
|
||||
revision=args.revision,
|
||||
from_pt=args.from_pt,
|
||||
)
|
||||
pipeline_params = jax_utils.replicate(pipeline_params)
|
||||
|
||||
# Optimization
|
||||
if args.scale_lr:
|
||||
args.learning_rate = args.learning_rate * total_train_batch_size
|
||||
@@ -952,6 +972,11 @@ def main():
|
||||
metrics = {"loss": loss}
|
||||
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
||||
|
||||
def l2(xs):
|
||||
return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
|
||||
|
||||
metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))
|
||||
|
||||
return new_state, metrics, new_train_rng
|
||||
|
||||
# Create parallel version of the train step
|
||||
@@ -983,32 +1008,38 @@ def main():
|
||||
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
|
||||
logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
|
||||
|
||||
if jax.process_index() == 0:
|
||||
if jax.process_index() == 0 and args.report_to == "wandb":
|
||||
wandb.define_metric("*", step_metric="train/step")
|
||||
wandb.define_metric("train/step", step_metric="walltime")
|
||||
wandb.config.update(
|
||||
{
|
||||
"num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
|
||||
"total_train_batch_size": total_train_batch_size,
|
||||
"total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
|
||||
"num_devices": jax.device_count(),
|
||||
"controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
|
||||
}
|
||||
)
|
||||
|
||||
global_step = 0
|
||||
global_step = step0 = 0
|
||||
epochs = tqdm(
|
||||
range(args.num_train_epochs),
|
||||
desc="Epoch ... ",
|
||||
position=0,
|
||||
disable=jax.process_index() > 0,
|
||||
)
|
||||
if args.profile_memory:
|
||||
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
|
||||
t00 = t0 = time.monotonic()
|
||||
for epoch in epochs:
|
||||
# ======================== Training ================================
|
||||
|
||||
train_metrics = []
|
||||
train_metric = None
|
||||
|
||||
steps_per_epoch = (
|
||||
args.max_train_samples // total_train_batch_size
|
||||
if args.streaming
|
||||
if args.streaming or args.max_train_samples
|
||||
else len(train_dataset) // total_train_batch_size
|
||||
)
|
||||
train_step_progress_bar = tqdm(
|
||||
@@ -1020,10 +1051,18 @@ def main():
|
||||
)
|
||||
# train
|
||||
for batch in train_dataloader:
|
||||
if args.profile_steps and global_step == 1:
|
||||
train_metric["loss"].block_until_ready()
|
||||
jax.profiler.start_trace(args.output_dir)
|
||||
if args.profile_steps and global_step == 1 + args.profile_steps:
|
||||
train_metric["loss"].block_until_ready()
|
||||
jax.profiler.stop_trace()
|
||||
|
||||
batch = shard(batch)
|
||||
state, train_metric, train_rngs = p_train_step(
|
||||
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
|
||||
)
|
||||
with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
|
||||
state, train_metric, train_rngs = p_train_step(
|
||||
state, unet_params, text_encoder_params, vae_params, batch, train_rngs
|
||||
)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
train_step_progress_bar.update(1)
|
||||
@@ -1037,17 +1076,25 @@ def main():
|
||||
and global_step % args.validation_steps == 0
|
||||
and jax.process_index() == 0
|
||||
):
|
||||
_ = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
|
||||
_ = log_validation(
|
||||
pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
|
||||
)
|
||||
|
||||
if global_step % args.logging_steps == 0 and jax.process_index() == 0:
|
||||
if args.report_to == "wandb":
|
||||
train_metrics = jax_utils.unreplicate(train_metrics)
|
||||
train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
|
||||
wandb.log(
|
||||
{
|
||||
"walltime": time.monotonic() - t00,
|
||||
"train/step": global_step,
|
||||
"train/epoch": epoch,
|
||||
"train/loss": jax_utils.unreplicate(train_metric)["loss"],
|
||||
"train/epoch": global_step / dataset_length,
|
||||
"train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
|
||||
**{f"train/{k}": v for k, v in train_metrics.items()},
|
||||
}
|
||||
)
|
||||
t0, step0 = time.monotonic(), global_step
|
||||
train_metrics = []
|
||||
if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
|
||||
controlnet.save_pretrained(
|
||||
f"{args.output_dir}/{global_step}",
|
||||
@@ -1058,10 +1105,16 @@ def main():
|
||||
train_step_progress_bar.close()
|
||||
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
|
||||
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
# Final validation & store model.
|
||||
if jax.process_index() == 0:
|
||||
if args.validation_prompt is not None:
|
||||
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
|
||||
if args.profile_validation:
|
||||
jax.profiler.start_trace(args.output_dir)
|
||||
image_logs = log_validation(
|
||||
pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype
|
||||
)
|
||||
if args.profile_validation:
|
||||
jax.profiler.stop_trace()
|
||||
else:
|
||||
image_logs = None
|
||||
|
||||
@@ -1084,6 +1137,10 @@ def main():
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
if args.profile_memory:
|
||||
jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
|
||||
logger.info("Finished training.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,280 @@
|
||||
# Custom Diffusion training example
|
||||
|
||||
[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
|
||||
The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install clip-retrieval
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
### Cat example 😺
|
||||
|
||||
Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
|
||||
|
||||
We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
|
||||
The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
|
||||
```
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="./data/cat"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>"
|
||||
```
|
||||
|
||||
**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
|
||||
|
||||
To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps:
|
||||
|
||||
* Install `wandb`: `pip install wandb`.
|
||||
* Authorize: `wandb login`.
|
||||
* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments:
|
||||
* `num_validation_images`
|
||||
* `validation_steps`
|
||||
|
||||
Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>" \
|
||||
--validation_prompt="<new1> cat sitting in a bucket" \
|
||||
--report_to="wandb"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details.
|
||||
|
||||
If you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat).
|
||||
|
||||
### Training on multiple concepts 🐱🪵
|
||||
|
||||
Provide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py).
|
||||
|
||||
To collect the real images run this command for each concept in the json file.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
|
||||
```
|
||||
|
||||
And then we're ready to start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--concepts_list=./concept_list.json \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--num_class_images=200 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>+<new2>"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details.
|
||||
|
||||
### Training on human faces
|
||||
|
||||
For fine-tuning on human faces we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn` with at least 15-20 images.
|
||||
|
||||
To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
|
||||
```
|
||||
|
||||
Then start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="path-to-images"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_person/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="person" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> person" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=5e-6 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=1000 \
|
||||
--scale_lr --hflip --noaug \
|
||||
--freeze_model crossattn \
|
||||
--modifier_token "<new1>" \
|
||||
--enable_xformers_memory_efficient_attention
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a model using the above command, you can run inference using the below command. Make sure to include the `modifier token` (e.g. \<new1\> in above example) in your prompt.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipe.unet.load_attn_procs(
|
||||
"path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin"
|
||||
)
|
||||
pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
It's possible to directly load these parameters from a Hub repository:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(
|
||||
"cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
Here is an example of performing inference with multiple concepts:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(
|
||||
"cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
|
||||
|
||||
image = pipe(
|
||||
"the <new1> cat sculpture in the style of a <new2> wooden pot",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("multi-subject.png")
|
||||
```
|
||||
|
||||
Here, `cat` and `wooden pot` refer to the multiple concepts.
|
||||
|
||||
### Inference from a training checkpoint
|
||||
|
||||
You can also perform inference from one of the complete checkpoint saved during the training process, if you used the `--checkpointing_steps` argument.
|
||||
|
||||
TODO.
|
||||
|
||||
## Set grads to none
|
||||
To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
|
||||
|
||||
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
|
||||
|
||||
## Experimental results
|
||||
You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail.
|
||||
@@ -0,0 +1,6 @@
|
||||
accelerate
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright 2023 Custom Diffusion authors. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from clip_retrieval.clip_client import ClipClient
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def retrieve(class_prompt, class_data_dir, num_class_images):
|
||||
factor = 1.5
|
||||
num_images = int(factor * num_class_images)
|
||||
client = ClipClient(
|
||||
url="https://knn.laion.ai/knn-service", indice_name="laion_400m", num_images=num_images, aesthetic_weight=0.1
|
||||
)
|
||||
|
||||
os.makedirs(f"{class_data_dir}/images", exist_ok=True)
|
||||
if len(list(Path(f"{class_data_dir}/images").iterdir())) >= num_class_images:
|
||||
return
|
||||
|
||||
while True:
|
||||
class_images = client.query(text=class_prompt)
|
||||
if len(class_images) >= factor * num_class_images or num_images > 1e4:
|
||||
break
|
||||
else:
|
||||
num_images = int(factor * num_images)
|
||||
client = ClipClient(
|
||||
url="https://knn.laion.ai/knn-service",
|
||||
indice_name="laion_400m",
|
||||
num_images=num_images,
|
||||
aesthetic_weight=0.1,
|
||||
)
|
||||
|
||||
count = 0
|
||||
total = 0
|
||||
pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
|
||||
|
||||
with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
|
||||
f"{class_data_dir}/images.txt", "w"
|
||||
) as f3:
|
||||
while total < num_class_images:
|
||||
images = class_images[count]
|
||||
count += 1
|
||||
try:
|
||||
img = requests.get(images["url"])
|
||||
if img.status_code == 200:
|
||||
_ = Image.open(BytesIO(img.content))
|
||||
with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f:
|
||||
f.write(img.content)
|
||||
f1.write(images["caption"] + "\n")
|
||||
f2.write(images["url"] + "\n")
|
||||
f3.write(f"{class_data_dir}/images/{total}.jpg" + "\n")
|
||||
total += 1
|
||||
pbar.update(1)
|
||||
else:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
return
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("", add_help=False)
|
||||
parser.add_argument("--class_prompt", help="text prompt to retrieve images", required=True, type=str)
|
||||
parser.add_argument("--class_data_dir", help="path to save images", required=True, type=str)
|
||||
parser.add_argument("--num_class_images", help="number of images to download", default=200, type=int)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
retrieve(args.class_prompt, args.class_data_dir, args.num_class_images)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -45,15 +45,28 @@ write_basic_config()
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
And launch the training using
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
And launch the training using:
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_dreambooth.py \
|
||||
@@ -77,7 +90,7 @@ According to the paper, it's recommended to generate `num_epochs * num_samples`
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -108,7 +121,7 @@ To install `bitandbytes` please refer to this [readme](https://github.com/TimDet
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -141,7 +154,7 @@ It is possible to run dreambooth on a 12GB GPU by using the following optimizati
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -185,7 +198,7 @@ does not seem to be compatible with DeepSpeed at the moment.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -217,7 +230,7 @@ ___Note: Training text encoder requires more memory, with this option the traini
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -300,7 +313,7 @@ Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
```
|
||||
|
||||
@@ -342,6 +355,12 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
|
||||
The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
|
||||
You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
|
||||
|
||||
Optionally, we can also train additional LoRA layers for the text encoder. Specify the `train_text_encoder` argument above for that. If you're interested to know more about how we
|
||||
enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).
|
||||
|
||||
With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
|
||||
|
||||
|
||||
### Inference
|
||||
|
||||
After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
|
||||
@@ -386,7 +405,7 @@ pip install -U -r requirements_flax.txt
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
python train_dreambooth_flax.py \
|
||||
@@ -405,7 +424,7 @@ python train_dreambooth_flax.py \
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
@@ -429,7 +448,7 @@ python train_dreambooth_flax.py \
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export INSTANCE_DIR="dog"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -43,22 +44,23 @@ from diffusers import (
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
DPMSolverMultistepScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
@@ -83,6 +85,8 @@ inference: true
|
||||
|
||||
These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
|
||||
{img_str}
|
||||
|
||||
LoRA for the text encoder was enabled: {train_text_encoder}.
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
@@ -219,6 +223,11 @@ def parse_args(input_args=None):
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
@@ -547,7 +556,13 @@ def main(args):
|
||||
|
||||
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
||||
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
||||
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
# TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
|
||||
raise ValueError(
|
||||
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
||||
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -691,7 +706,7 @@ def main(args):
|
||||
# => 32 layers
|
||||
|
||||
# Set correct lora layers
|
||||
lora_attn_procs = {}
|
||||
unet_lora_attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
@@ -703,12 +718,33 @@ def main(args):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
||||
unet_lora_attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
|
||||
unet.set_attn_processor(lora_attn_procs)
|
||||
lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
unet.set_attn_processor(unet_lora_attn_procs)
|
||||
unet_lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
accelerator.register_for_checkpointing(unet_lora_layers)
|
||||
|
||||
accelerator.register_for_checkpointing(lora_layers)
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
|
||||
# we first load a dummy pipeline with the text encoder and then do the monkey-patching.
|
||||
text_encoder_lora_layers = None
|
||||
if args.train_text_encoder:
|
||||
text_lora_attn_procs = {}
|
||||
for name, module in text_encoder.named_modules():
|
||||
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
|
||||
text_lora_attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=module.out_features, cross_attention_dim=None
|
||||
)
|
||||
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
|
||||
temp_pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, text_encoder=text_encoder
|
||||
)
|
||||
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
|
||||
text_encoder = temp_pipeline.text_encoder
|
||||
accelerator.register_for_checkpointing(unet_lora_layers)
|
||||
del temp_pipeline
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
@@ -739,8 +775,13 @@ def main(args):
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
# Optimizer creation
|
||||
params_to_optimize = (
|
||||
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
|
||||
if args.train_text_encoder
|
||||
else unet_lora_layers.parameters()
|
||||
)
|
||||
optimizer = optimizer_class(
|
||||
lora_layers.parameters(),
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
@@ -784,9 +825,14 @@ def main(args):
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
lora_layers, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet_lora_layers, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -845,6 +891,8 @@ def main(args):
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
@@ -900,7 +948,11 @@ def main(args):
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = lora_layers.parameters()
|
||||
params_to_clip = (
|
||||
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
|
||||
if args.train_text_encoder
|
||||
else unet_lora_layers.parameters()
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
@@ -914,7 +966,14 @@ def main(args):
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
# We combine the text encoder and UNet LoRA parameters with a simple
|
||||
# custom logic. `accelerator.save_state()` won't know that. So,
|
||||
# use `LoraLoaderMixin.save_lora_weights()`.
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=save_path,
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
@@ -970,7 +1029,12 @@ def main(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = unet.to(torch.float32)
|
||||
unet.save_attn_procs(args.output_dir)
|
||||
text_encoder = text_encoder.to(torch.float32)
|
||||
LoraLoaderMixin.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
@@ -981,7 +1045,7 @@ def main(args):
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
|
||||
# load attention processors
|
||||
pipeline.unet.load_attn_procs(args.output_dir)
|
||||
pipeline.load_attn_procs(args.output_dir)
|
||||
|
||||
# run inference
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
@@ -1010,6 +1074,7 @@ def main(args):
|
||||
repo_id,
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
prompt=args.instance_prompt,
|
||||
repo_folder=args.output_dir,
|
||||
)
|
||||
|
||||
@@ -113,6 +113,27 @@ accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
|
||||
|
||||
***Note: In the original paper, the authors observed that even when the model is trained with an image resolution of 256x256, it generalizes well to bigger resolutions such as 512x512. This is likely because of the larger dataset they used during training.***
|
||||
|
||||
## Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py \
|
||||
--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
|
||||
--dataset_name=sayakpaul/instructpix2pix-1000-samples \
|
||||
--use_ema \
|
||||
--enable_xformers_memory_efficient_attention \
|
||||
--resolution=512 --random_flip \
|
||||
--train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
|
||||
--max_train_steps=15000 \
|
||||
--checkpointing_steps=5000 --checkpoints_total_limit=1 \
|
||||
--learning_rate=5e-05 --lr_warmup_steps=0 \
|
||||
--conditioning_dropout_prob=0.05 \
|
||||
--mixed_precision=fp16 \
|
||||
--seed=42
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once training is complete, we can perform inference:
|
||||
|
||||
@@ -51,7 +51,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# Distillation for quantization on Textual Inversion models to personalize text2image
|
||||
|
||||
[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images._By using just 3-5 images new concepts can be taught to Stable Diffusion and the model personalized on your own images_
|
||||
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
We have enabled distillation for quantization in `textual_inversion.py` to do quantization aware training as well as distillation on the model generated by Textual Inversion method.
|
||||
|
||||
## Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Prepare Datasets
|
||||
|
||||
One picture which is from the huggingface datasets [sd-concepts-library/dicoo2](https://huggingface.co/sd-concepts-library/dicoo2) is needed, and save it to the `./dicoo` directory. The picture is shown below:
|
||||
|
||||
<a href="https://huggingface.co/sd-concepts-library/dicoo2/blob/main/concept_images/1.jpeg">
|
||||
<img src="https://huggingface.co/sd-concepts-library/dicoo2/resolve/main/concept_images/1.jpeg" width = "300" height="300">
|
||||
</a>
|
||||
|
||||
## Get a FP32 Textual Inversion model
|
||||
|
||||
Use the following command to fine-tune the Stable Diffusion model on the above dataset to obtain the FP32 Textual Inversion model.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export DATA_DIR="./dicoo"
|
||||
|
||||
accelerate launch textual_inversion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_data_dir=$DATA_DIR \
|
||||
--learnable_property="object" \
|
||||
--placeholder_token="<dicoo>" --initializer_token="toy" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--max_train_steps=3000 \
|
||||
--learning_rate=5.0e-04 --scale_lr \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="dicoo_model"
|
||||
```
|
||||
|
||||
## Do distillation for quantization
|
||||
|
||||
Distillation for quantization is a method that combines [intermediate layer knowledge distillation](https://github.com/intel/neural-compressor/blob/master/docs/source/distillation.md#intermediate-layer-knowledge-distillation) and [quantization aware training](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization.md#quantization-aware-training) in the same training process to improve the performance of the quantized model. Provided a FP32 model, the distillation for quantization approach will take this model itself as the teacher model and transfer the knowledges of the specified layers to the student model, i.e. quantized version of the FP32 model, during the quantization aware training process.
|
||||
|
||||
Once you have the FP32 Textual Inversion model, the following command will take the FP32 Textual Inversion model as input to do distillation for quantization and generate the INT8 Textual Inversion model.
|
||||
|
||||
```bash
|
||||
export FP32_MODEL_NAME="./dicoo_model"
|
||||
export DATA_DIR="./dicoo"
|
||||
|
||||
accelerate launch textual_inversion.py \
|
||||
--pretrained_model_name_or_path=$FP32_MODEL_NAME \
|
||||
--train_data_dir=$DATA_DIR \
|
||||
--use_ema --learnable_property="object" \
|
||||
--placeholder_token="<dicoo>" --initializer_token="toy" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--max_train_steps=300 \
|
||||
--learning_rate=5.0e-04 --max_grad_norm=3 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--output_dir="int8_model" \
|
||||
--do_quantization --do_distillation --verify_loading
|
||||
```
|
||||
|
||||
After the distillation for quantization process, the quantized UNet would be 4 times smaller (3279MB -> 827MB).
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a INT8 model with the above command, the inference can be done simply using the `text2images.py` script. Make sure to include the `placeholder_token` in your prompt.
|
||||
|
||||
```bash
|
||||
export INT8_MODEL_NAME="./int8_model"
|
||||
|
||||
python text2images.py \
|
||||
--pretrained_model_name_or_path=$INT8_MODEL_NAME \
|
||||
--caption "a lovely <dicoo> in red dress and hat, in the snowly and brightly night, with many brighly buildings." \
|
||||
--images_num 4
|
||||
```
|
||||
|
||||
Here is the comparison of images generated by the FP32 model (left) and INT8 model (right) respectively:
|
||||
|
||||
<p float="left">
|
||||
<img src="https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/FP32.png" width = "300" height = "300" alt="FP32" align=center />
|
||||
<img src="https://huggingface.co/datasets/Intel/textual_inversion_dicoo_dfq/resolve/main/INT8.png" width = "300" height = "300" alt="INT8" align=center />
|
||||
</p>
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
accelerate
|
||||
torchvision
|
||||
transformers>=4.25.0
|
||||
ftfy
|
||||
tensorboard
|
||||
modelcards
|
||||
neural-compressor
|
||||
@@ -0,0 +1,112 @@
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
from neural_compressor.utils.pytorch import load
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--caption",
|
||||
type=str,
|
||||
default="robotic cat with wings",
|
||||
help="Text used to generate images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--images_num",
|
||||
type=int,
|
||||
default=4,
|
||||
help="How much images to generate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Seed for random process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ci",
|
||||
"--cuda_id",
|
||||
type=int,
|
||||
default=0,
|
||||
help="cuda_id.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def image_grid(imgs, rows, cols):
|
||||
if not len(imgs) == rows * cols:
|
||||
raise ValueError("The specified number of rows and columns are not correct.")
|
||||
|
||||
w, h = imgs[0].size
|
||||
grid = Image.new("RGB", size=(cols * w, rows * h))
|
||||
grid_w, grid_h = grid.size
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
return grid
|
||||
|
||||
|
||||
def generate_images(
|
||||
pipeline,
|
||||
prompt="robotic cat with wings",
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=50,
|
||||
num_images_per_prompt=1,
|
||||
seed=42,
|
||||
):
|
||||
generator = torch.Generator(pipeline.device).manual_seed(seed)
|
||||
images = pipeline(
|
||||
prompt,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=generator,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
).images
|
||||
_rows = int(math.sqrt(num_images_per_prompt))
|
||||
grid = image_grid(images, rows=_rows, cols=num_images_per_prompt // _rows)
|
||||
return grid, images
|
||||
|
||||
|
||||
args = parse_args()
|
||||
# Load models and create wrapper for stable diffusion
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, text_encoder=text_encoder, vae=vae, unet=unet, tokenizer=tokenizer
|
||||
)
|
||||
pipeline.safety_checker = lambda images, clip_input: (images, False)
|
||||
if os.path.exists(os.path.join(args.pretrained_model_name_or_path, "best_model.pt")):
|
||||
unet = load(args.pretrained_model_name_or_path, model=unet)
|
||||
unet.eval()
|
||||
setattr(pipeline, "unet", unet)
|
||||
else:
|
||||
unet = unet.to(torch.device("cuda", args.cuda_id))
|
||||
pipeline = pipeline.to(unet.device)
|
||||
grid, images = generate_images(pipeline, prompt=args.caption, num_images_per_prompt=args.images_num, seed=args.seed)
|
||||
grid.save(os.path.join(args.pretrained_model_name_or_path, "{}.png".format("_".join(args.caption.split()))))
|
||||
dirname = os.path.join(args.pretrained_model_name_or_path, "_".join(args.caption.split()))
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
for idx, image in enumerate(images):
|
||||
image.save(os.path.join(dirname, "{}.png".format(idx + 1)))
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,6 +23,7 @@ import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
from diffusers import DiffusionPipeline, UNet2DConditionModel
|
||||
@@ -221,6 +222,92 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
|
||||
|
||||
def test_dreambooth_lora(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"unet"` in their names.
|
||||
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_unet)
|
||||
|
||||
def test_dreambooth_lora_with_text_encoder(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/dreambooth/train_dreambooth_lora.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt photo
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--train_text_encoder
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
|
||||
|
||||
# the names of the keys of the state dict should either start with `unet`
|
||||
# or `text_encoder`.
|
||||
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
|
||||
keys = lora_state_dict.keys()
|
||||
is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
|
||||
self.assertTrue(is_correct_naming)
|
||||
|
||||
def test_custom_diffusion(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/custom_diffusion/train_custom_diffusion.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt <new1>
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 1.0e-05
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--modifier_token <new1>
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "<new1>.bin")))
|
||||
|
||||
def test_text_to_image(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
|
||||
@@ -111,6 +111,31 @@ image = pipe(prompt="yoda").images[0]
|
||||
image.save("yoda-pokemon.png")
|
||||
```
|
||||
|
||||
#### Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export dataset_name="lambdalabs/pokemon-blip-captions"
|
||||
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$dataset_name \
|
||||
--use_ema \
|
||||
--resolution=512 --center_crop --random_flip \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--max_train_steps=15000 \
|
||||
--learning_rate=1e-05 \
|
||||
--max_grad_norm=1 \
|
||||
--lr_scheduler="constant" --lr_warmup_steps=0 \
|
||||
--output_dir="sd-pokemon-model"
|
||||
```
|
||||
|
||||
|
||||
#### Training with Min-SNR weighting
|
||||
|
||||
We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
@@ -64,8 +64,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
vae=accelerator.unwrap_model(vae),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
tokenizer=tokenizer,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
safety_checker=None,
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -340,11 +340,10 @@ def main():
|
||||
|
||||
return examples
|
||||
|
||||
if jax.process_index() == 0:
|
||||
if args.max_train_samples is not None:
|
||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
if args.max_train_samples is not None:
|
||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
# Set the training transforms
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
|
||||
@@ -47,7 +47,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -39,29 +39,31 @@ accelerate config
|
||||
|
||||
### Cat toy example
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
Run the following command to authenticate your token
|
||||
First, let's login so that we can upload the checkpoint to the Hub during training:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
If you have already cloned the repo, then you won't need to go through these steps.
|
||||
Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .
|
||||
|
||||
<br>
|
||||
Let's first download it locally:
|
||||
|
||||
Now let's get our dataset.Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
|
||||
```py
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
And launch the training using
|
||||
local_dir = "./cat"
|
||||
snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes")
|
||||
```
|
||||
|
||||
This will be our training data.
|
||||
Now we can launch the training using
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
||||
export DATA_DIR="path-to-dir-containing-images"
|
||||
export DATA_DIR="./cat"
|
||||
|
||||
accelerate launch textual_inversion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
@@ -75,6 +77,7 @@ accelerate launch textual_inversion.py \
|
||||
--learning_rate=5.0e-04 --scale_lr \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--push_to_hub \
|
||||
--output_dir="textual_inversion_cat"
|
||||
```
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
## Training examples
|
||||
## Training an unconditional diffusion model
|
||||
|
||||
Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).
|
||||
|
||||
@@ -76,6 +76,27 @@ A full training run takes 2 hours on 4xV100 GPUs.
|
||||
|
||||
<img src="https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png" width="700" />
|
||||
|
||||
### Training with multiple GPUs
|
||||
|
||||
`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch)
|
||||
for running distributed training with `accelerate`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
|
||||
--dataset_name="huggan/pokemon" \
|
||||
--resolution=64 --center_crop --random_flip \
|
||||
--output_dir="ddpm-ema-pokemon-64" \
|
||||
--train_batch_size=16 \
|
||||
--num_epochs=100 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--use_ema \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_warmup_steps=500 \
|
||||
--mixed_precision="fp16" \
|
||||
--logger="wandb"
|
||||
```
|
||||
|
||||
To be able to use Weights and Biases (`wandb`) as a logger you need to install the library: `pip install wandb`.
|
||||
|
||||
### Using your own data
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.15.0")
|
||||
check_min_version("0.16.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ _deps = [
|
||||
"filelock",
|
||||
"flax>=0.4.1",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"huggingface-hub>=0.13.2",
|
||||
"huggingface-hub==0.14.0rc1",
|
||||
"requests-mock==1.10.0",
|
||||
"importlib_metadata",
|
||||
"isort>=5.5.4",
|
||||
@@ -226,7 +226,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.15.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.16.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.15.0"
|
||||
__version__ = "0.16.0.dev0"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .utils import (
|
||||
@@ -109,7 +109,6 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .loaders import TextualInversionLoaderMixin
|
||||
from .pipelines import (
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
AltDiffusionPipeline,
|
||||
|
||||
@@ -118,6 +118,24 @@ class ConfigMixin:
|
||||
|
||||
self._internal_dict = FrozenDict(internal_dict)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
||||
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
|
||||
|
||||
Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
|
||||
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
||||
"""
|
||||
|
||||
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
||||
is_attribute = name in self.__dict__
|
||||
|
||||
if is_in_config and not is_attribute:
|
||||
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
|
||||
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
|
||||
return self._internal_dict[name]
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
||||
|
||||
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
||||
|
||||
+259
-9
@@ -13,11 +13,17 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .models.attention_processor import LoRAAttnProcessor
|
||||
from .models.attention_processor import (
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
LoRAAttnProcessor,
|
||||
)
|
||||
from .utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
HF_HUB_OFFLINE,
|
||||
@@ -46,6 +52,9 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
TEXT_INVERSION_NAME = "learned_embeds.bin"
|
||||
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
|
||||
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
||||
|
||||
|
||||
class AttnProcsLayers(torch.nn.Module):
|
||||
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
||||
@@ -213,6 +222,7 @@ class UNet2DConditionLoadersMixin:
|
||||
attn_processors = {}
|
||||
|
||||
is_lora = all("lora" in k for k in state_dict.keys())
|
||||
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
||||
|
||||
if is_lora:
|
||||
lora_grouped_dict = defaultdict(dict)
|
||||
@@ -229,9 +239,38 @@ class UNet2DConditionLoadersMixin:
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
elif is_custom_diffusion:
|
||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||
for key, value in state_dict.items():
|
||||
if len(value) == 0:
|
||||
custom_diffusion_grouped_dict[key] = {}
|
||||
else:
|
||||
if "to_out" in key:
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
||||
else:
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
|
||||
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
|
||||
|
||||
for key, value_dict in custom_diffusion_grouped_dict.items():
|
||||
if len(value_dict) == 0:
|
||||
attn_processors[key] = CustomDiffusionAttnProcessor(
|
||||
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
|
||||
)
|
||||
else:
|
||||
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
|
||||
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
|
||||
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
|
||||
attn_processors[key] = CustomDiffusionAttnProcessor(
|
||||
train_kv=True,
|
||||
train_q_out=train_q_out,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
else:
|
||||
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
|
||||
raise ValueError(
|
||||
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
|
||||
)
|
||||
|
||||
# set correct dtype & device
|
||||
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
|
||||
@@ -285,16 +324,31 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
model_to_save = AttnProcsLayers(self.attn_processors)
|
||||
|
||||
# Save the model
|
||||
state_dict = model_to_save.state_dict()
|
||||
is_custom_diffusion = any(
|
||||
isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
||||
for (_, x) in self.attn_processors.items()
|
||||
)
|
||||
if is_custom_diffusion:
|
||||
model_to_save = AttnProcsLayers(
|
||||
{
|
||||
y: x
|
||||
for (y, x) in self.attn_processors.items()
|
||||
if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
||||
}
|
||||
)
|
||||
state_dict = model_to_save.state_dict()
|
||||
for name, attn in self.attn_processors.items():
|
||||
if len(attn.state_dict()) == 0:
|
||||
state_dict[name] = {}
|
||||
else:
|
||||
model_to_save = AttnProcsLayers(self.attn_processors)
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
|
||||
|
||||
# Save the model
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
@@ -431,6 +485,7 @@ class TextualInversionLoaderMixin:
|
||||
Example:
|
||||
|
||||
To load a textual inversion embedding vector in `diffusers` format:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
@@ -463,6 +518,7 @@ class TextualInversionLoaderMixin:
|
||||
image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
image.save("character.png")
|
||||
```
|
||||
|
||||
"""
|
||||
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
|
||||
raise ValueError(
|
||||
@@ -792,7 +848,7 @@ class LoraLoaderMixin:
|
||||
"""
|
||||
# Loop over the original attention modules.
|
||||
for name, _ in self.text_encoder.named_modules():
|
||||
if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
|
||||
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
|
||||
# Retrieve the module and its corresponding LoRA processor.
|
||||
module = self.text_encoder.get_submodule(name)
|
||||
# Construct a new function that performs the LoRA merging. We will monkey patch
|
||||
@@ -1051,3 +1107,197 @@ class LoraLoaderMixin:
|
||||
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
||||
|
||||
|
||||
class FromCkptMixin:
|
||||
"""This helper class allows to directly load .ckpt stable diffusion file_extension
|
||||
into the respective classes."""
|
||||
|
||||
@classmethod
|
||||
def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original .ckpt format.
|
||||
|
||||
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the .ckpt file on the Hub. Should be in the format
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"`
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
will be automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `huggingface-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
use_safetensors (`bool`, *optional* ):
|
||||
If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
|
||||
default). The pipeline will load using `safetensors` if the safetensors weights are available *and* if
|
||||
`safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`.
|
||||
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
|
||||
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults
|
||||
to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
|
||||
inference. Non-EMA weights are usually better to continue fine-tuning.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted. This is necessary when running stable
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
|
||||
Base. Use 768 for Stable Diffusion v2.
|
||||
prediction_type (`str`, *optional*):
|
||||
The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
|
||||
Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
|
||||
num_in_channels (`int`, *optional*, defaults to None):
|
||||
The number of input channels. If `None`, it will be automatically inferred.
|
||||
scheduler_type (`str`, *optional*, defaults to 'pndm'):
|
||||
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
||||
"ddim"]`.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
||||
Whether to load the safety checker or not. Defaults to `True`.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
||||
specific pipeline class. The overwritten components are then directly passed to the pipelines
|
||||
`__init__` method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> # Download pipeline from huggingface.co and cache.
|
||||
>>> pipeline = StableDiffusionPipeline.from_ckpt(
|
||||
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
||||
... )
|
||||
|
||||
>>> # Download pipeline from local file
|
||||
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
||||
>>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly")
|
||||
|
||||
>>> # Enable float16 and move to GPU
|
||||
>>> pipeline = StableDiffusionPipeline.from_ckpt(
|
||||
... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
||||
... torch_dtype=torch.float16,
|
||||
... )
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
"""
|
||||
# import here to avoid circular dependency
|
||||
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
extract_ema = kwargs.pop("extract_ema", False)
|
||||
image_size = kwargs.pop("image_size", 512)
|
||||
scheduler_type = kwargs.pop("scheduler_type", "pndm")
|
||||
num_in_channels = kwargs.pop("num_in_channels", None)
|
||||
upcast_attention = kwargs.pop("upcast_attention", None)
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
||||
prediction_type = kwargs.pop("prediction_type", None)
|
||||
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
||||
|
||||
pipeline_name = cls.__name__
|
||||
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
||||
from_safetensors = file_extension == "safetensors"
|
||||
|
||||
if from_safetensors and use_safetensors is True:
|
||||
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
||||
|
||||
# TODO: For now we only support stable diffusion
|
||||
stable_unclip = None
|
||||
controlnet = False
|
||||
|
||||
if pipeline_name == "StableDiffusionControlNetPipeline":
|
||||
model_type = "FrozenCLIPEmbedder"
|
||||
controlnet = True
|
||||
elif "StableDiffusion" in pipeline_name:
|
||||
model_type = "FrozenCLIPEmbedder"
|
||||
elif pipeline_name == "StableUnCLIPPipeline":
|
||||
model_type == "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "txt2img"
|
||||
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
|
||||
model_type == "FrozenOpenCLIPEmbedder"
|
||||
stable_unclip = "img2img"
|
||||
elif pipeline_name == "PaintByExamplePipeline":
|
||||
model_type == "PaintByExample"
|
||||
elif pipeline_name == "LDMTextToImagePipeline":
|
||||
model_type == "LDMTextToImage"
|
||||
else:
|
||||
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
||||
|
||||
# remove huggingface url
|
||||
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
||||
if pretrained_model_link_or_path.startswith(prefix):
|
||||
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
||||
|
||||
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
||||
ckpt_path = Path(pretrained_model_link_or_path)
|
||||
if not ckpt_path.is_file():
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = str(Path().joinpath(*ckpt_path.parts[:2]))
|
||||
file_path = str(Path().joinpath(*ckpt_path.parts[2:]))
|
||||
|
||||
if file_path.startswith("blob/"):
|
||||
file_path = file_path[len("blob/") :]
|
||||
|
||||
if file_path.startswith("main/"):
|
||||
file_path = file_path[len("main/") :]
|
||||
|
||||
pretrained_model_link_or_path = hf_hub_download(
|
||||
repo_id,
|
||||
filename=file_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
force_download=force_download,
|
||||
)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
pipeline_class=cls,
|
||||
model_type=model_type,
|
||||
stable_unclip=stable_unclip,
|
||||
controlnet=controlnet,
|
||||
from_safetensors=from_safetensors,
|
||||
extract_ema=extract_ema,
|
||||
image_size=image_size,
|
||||
scheduler_type=scheduler_type,
|
||||
num_in_channels=num_in_channels,
|
||||
upcast_attention=upcast_attention,
|
||||
load_safety_checker=load_safety_checker,
|
||||
prediction_type=prediction_type,
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(torch_dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -149,6 +149,9 @@ class Attention(nn.Module):
|
||||
is_lora = hasattr(self, "processor") and isinstance(
|
||||
self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
|
||||
)
|
||||
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
||||
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
||||
)
|
||||
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if self.added_kv_proj_dim is not None:
|
||||
@@ -192,6 +195,17 @@ class Attention(nn.Module):
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
processor.to(self.processor.to_q_lora.up.weight.device)
|
||||
elif is_custom_diffusion:
|
||||
processor = CustomDiffusionXFormersAttnProcessor(
|
||||
train_kv=self.processor.train_kv,
|
||||
train_q_out=self.processor.train_q_out,
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
attention_op=attention_op,
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_custom_diffusion"):
|
||||
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
||||
else:
|
||||
processor = XFormersAttnProcessor(attention_op=attention_op)
|
||||
else:
|
||||
@@ -203,6 +217,16 @@ class Attention(nn.Module):
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
processor.to(self.processor.to_q_lora.up.weight.device)
|
||||
elif is_custom_diffusion:
|
||||
processor = CustomDiffusionAttnProcessor(
|
||||
train_kv=self.processor.train_kv,
|
||||
train_q_out=self.processor.train_q_out,
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_custom_diffusion"):
|
||||
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
||||
else:
|
||||
processor = AttnProcessor()
|
||||
|
||||
@@ -459,6 +483,84 @@ class LoRAAttnProcessor(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDiffusionAttnProcessor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
train_kv=True,
|
||||
train_q_out=True,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.train_kv = train_kv
|
||||
self.train_q_out = train_q_out
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
|
||||
# `_custom_diffusion` id for easy serialization and loading.
|
||||
if self.train_kv:
|
||||
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
if self.train_q_out:
|
||||
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.to_out_custom_diffusion = nn.ModuleList([])
|
||||
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
||||
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
if self.train_q_out:
|
||||
query = self.to_q_custom_diffusion(hidden_states)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
crossattn = False
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
crossattn = True
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
if self.train_kv:
|
||||
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
||||
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
||||
else:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if crossattn:
|
||||
detach = torch.ones_like(key)
|
||||
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
||||
key = detach * key + (1 - detach) * key.detach()
|
||||
value = detach * value + (1 - detach) * value.detach()
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if self.train_q_out:
|
||||
# linear proj
|
||||
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
||||
else:
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnAddedKVProcessor:
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
residual = hidden_states
|
||||
@@ -699,6 +801,91 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
train_kv=True,
|
||||
train_q_out=False,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
dropout=0.0,
|
||||
attention_op: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.train_kv = train_kv
|
||||
self.train_q_out = train_q_out
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.attention_op = attention_op
|
||||
|
||||
# `_custom_diffusion` id for easy serialization and loading.
|
||||
if self.train_kv:
|
||||
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
if self.train_q_out:
|
||||
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.to_out_custom_diffusion = nn.ModuleList([])
|
||||
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
||||
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if self.train_q_out:
|
||||
query = self.to_q_custom_diffusion(hidden_states)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
crossattn = False
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
crossattn = True
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
if self.train_kv:
|
||||
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
||||
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
||||
else:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if crossattn:
|
||||
detach = torch.ones_like(key)
|
||||
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
||||
key = detach * key + (1 - detach) * key.detach()
|
||||
value = detach * value + (1 - detach) * value.detach()
|
||||
|
||||
query = attn.head_to_batch_dim(query).contiguous()
|
||||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
value = attn.head_to_batch_dim(value).contiguous()
|
||||
|
||||
hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
||||
)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if self.train_q_out:
|
||||
# linear proj
|
||||
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
||||
else:
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SlicedAttnProcessor:
|
||||
def __init__(self, slice_size):
|
||||
self.slice_size = slice_size
|
||||
@@ -834,4 +1021,6 @@ AttentionProcessor = Union[
|
||||
AttnAddedKVProcessor2_0,
|
||||
LoRAAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
]
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, apply_forward_hook, deprecate
|
||||
from ..utils import BaseOutput, apply_forward_hook
|
||||
from .modeling_utils import ModelMixin
|
||||
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
@@ -123,16 +123,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
@property
|
||||
def block_out_channels(self):
|
||||
deprecate(
|
||||
"block_out_channels",
|
||||
"1.0.0",
|
||||
"Accessing `block_out_channels` directly via vae.block_out_channels is deprecated. Please use `vae.config.block_out_channels instead`",
|
||||
standard_warn=False,
|
||||
)
|
||||
return self.config.block_out_channels
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (Encoder, Decoder)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@@ -119,6 +119,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -456,6 +457,7 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
# check channel order
|
||||
@@ -556,8 +558,20 @@ class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample *= conditioning_scale
|
||||
if guess_mode:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) # 0.1 to 1.0
|
||||
scales *= conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample *= scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample *= conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import inspect
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device
|
||||
@@ -32,6 +32,7 @@ from ..utils import (
|
||||
WEIGHTS_NAME,
|
||||
_add_variant,
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
@@ -156,6 +157,24 @@ class ModelMixin(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
||||
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
||||
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
||||
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
||||
"""
|
||||
|
||||
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
||||
is_attribute = name in self.__dict__
|
||||
|
||||
if is_in_config and not is_attribute:
|
||||
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
||||
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
||||
return self._internal_dict[name]
|
||||
|
||||
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
||||
return super().__getattr__(name)
|
||||
|
||||
@property
|
||||
def is_gradient_checkpointing(self) -> bool:
|
||||
"""
|
||||
|
||||
@@ -225,7 +225,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, deprecate
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
||||
@@ -190,16 +190,6 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
fc_dim=block_out_channels[-1] // 4,
|
||||
)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
deprecate(
|
||||
"in_channels",
|
||||
"1.0.0",
|
||||
"Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
|
||||
standard_warn=False,
|
||||
)
|
||||
return self.config.in_channels
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, deprecate
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
@@ -216,16 +216,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
deprecate(
|
||||
"in_channels",
|
||||
"1.0.0",
|
||||
"Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
|
||||
standard_warn=False,
|
||||
)
|
||||
return self.config.in_channels
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput, deprecate, logging
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import AttentionProcessor, AttnProcessor
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
@@ -248,7 +248,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
@@ -437,7 +437,18 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
if act_fn == "swish":
|
||||
self.conv_act = lambda x: F.silu(x)
|
||||
elif act_fn == "mish":
|
||||
self.conv_act = nn.Mish()
|
||||
elif act_fn == "silu":
|
||||
self.conv_act = nn.SiLU()
|
||||
elif act_fn == "gelu":
|
||||
self.conv_act = nn.GELU()
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function: {act_fn}")
|
||||
|
||||
else:
|
||||
self.conv_norm_out = None
|
||||
self.conv_act = None
|
||||
@@ -447,16 +458,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
||||
)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
deprecate(
|
||||
"in_channels",
|
||||
"1.0.0",
|
||||
"Accessing `in_channels` directly via unet.in_channels is deprecated. Please use `unet.config.in_channels` instead",
|
||||
standard_warn=False,
|
||||
)
|
||||
return self.config.in_channels
|
||||
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
@@ -658,7 +659,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
@@ -672,6 +673,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# there might be better ways to encapsulate this.
|
||||
class_labels = class_labels.to(dtype=sample.dtype)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
|
||||
if self.config.class_embeddings_concat:
|
||||
|
||||
@@ -56,7 +56,7 @@ class RobertaSeriesConfig(XLMRobertaConfig):
|
||||
|
||||
|
||||
class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"logit_scale"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
base_model_prefix = "roberta"
|
||||
config_class = RobertaSeriesConfig
|
||||
@@ -65,6 +65,10 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.roberta = XLMRobertaModel(config)
|
||||
self.transformation = nn.Linear(config.hidden_size, config.project_dim)
|
||||
self.has_pre_transformation = getattr(config, "has_pre_transformation", False)
|
||||
if self.has_pre_transformation:
|
||||
self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
|
||||
self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
@@ -95,15 +99,26 @@ class RobertaSeriesModelWithTransformation(RobertaPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
output_hidden_states=True if self.has_pre_transformation else output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
projection_state = self.transformation(outputs.last_hidden_state)
|
||||
if self.has_pre_transformation:
|
||||
sequence_output2 = outputs["hidden_states"][-2]
|
||||
sequence_output2 = self.pre_LN(sequence_output2)
|
||||
projection_state2 = self.transformation_pre(sequence_output2)
|
||||
|
||||
return TransformationModelOutput(
|
||||
projection_state=projection_state,
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
return TransformationModelOutput(
|
||||
projection_state=projection_state2,
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
projection_state = self.transformation(outputs.last_hidden_state)
|
||||
return TransformationModelOutput(
|
||||
projection_state=projection_state,
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@@ -57,6 +57,14 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
|
||||
@@ -96,6 +96,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
@@ -503,7 +511,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import importlib
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -200,24 +201,24 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
|
||||
# .bin, .safetensors, ...
|
||||
weight_suffixs = [w.split(".")[-1] for w in weight_names]
|
||||
# -00001-of-00002
|
||||
transformers_index_format = "\d{5}-of-\d{5}"
|
||||
transformers_index_format = r"\d{5}-of-\d{5}"
|
||||
|
||||
if variant is not None:
|
||||
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetenstors`
|
||||
variant_file_re = re.compile(
|
||||
f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
|
||||
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
# `text_encoder/pytorch_model.bin.index.fp16.json`
|
||||
variant_index_re = re.compile(
|
||||
f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
||||
)
|
||||
|
||||
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetenstors`
|
||||
non_variant_file_re = re.compile(
|
||||
f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
|
||||
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
|
||||
)
|
||||
# `text_encoder/pytorch_model.bin.index.json`
|
||||
non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
||||
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
||||
|
||||
if variant is not None:
|
||||
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
|
||||
@@ -507,7 +508,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
setattr(self, name, module)
|
||||
|
||||
def __setattr__(self, name: str, value: Any):
|
||||
if hasattr(self, name) and hasattr(self.config, name):
|
||||
if name in self.__dict__ and hasattr(self.config, name):
|
||||
# We need to overwrite the config if name exists in config
|
||||
if isinstance(getattr(self.config, name), (tuple, list)):
|
||||
if value is not None and self.config[name][0] is not None:
|
||||
@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||
"""
|
||||
self.save_config(save_directory)
|
||||
|
||||
model_index_dict = dict(self.config)
|
||||
model_index_dict.pop("_class_name")
|
||||
model_index_dict.pop("_diffusers_version")
|
||||
model_index_dict.pop("_class_name", None)
|
||||
model_index_dict.pop("_diffusers_version", None)
|
||||
model_index_dict.pop("_module", None)
|
||||
|
||||
expected_modules, optional_kwargs = self._get_signature_keys(self)
|
||||
@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||
return True
|
||||
|
||||
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
model_cls = sub_model.__class__
|
||||
@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin):
|
||||
save_method_name = None
|
||||
# search for the model's base class in LOADABLE_CLASSES
|
||||
for library_name, library_classes in LOADABLE_CLASSES.items():
|
||||
library = importlib.import_module(library_name)
|
||||
if library_name in sys.modules:
|
||||
library = importlib.import_module(library_name)
|
||||
else:
|
||||
logger.info(
|
||||
f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
|
||||
)
|
||||
|
||||
for base_class, save_load_methods in library_classes.items():
|
||||
class_candidate = getattr(library, base_class, None)
|
||||
if class_candidate is not None and issubclass(model_cls, class_candidate):
|
||||
@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
if save_method_name is not None:
|
||||
break
|
||||
|
||||
if save_method_name is None:
|
||||
logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
|
||||
# make sure that unsaveable components are not tried to be loaded afterward
|
||||
self.register_to_config(**{pipeline_component_name: (None, None)})
|
||||
continue
|
||||
|
||||
save_method = getattr(sub_model, save_method_name)
|
||||
|
||||
# Call the save method with the argument safe_serialization only if it's supported
|
||||
@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
||||
|
||||
# finally save the config
|
||||
self.save_config(save_directory)
|
||||
|
||||
def to(
|
||||
self,
|
||||
torch_device: Optional[Union[str, torch.device]] = None,
|
||||
@@ -635,26 +648,25 @@ class DiffusionPipeline(ConfigMixin):
|
||||
)
|
||||
|
||||
module_names, _ = self._get_signature_keys(self)
|
||||
module_names = [m for m in module_names if hasattr(self, m)]
|
||||
modules = [getattr(self, n, None) for n in module_names]
|
||||
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
|
||||
|
||||
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
|
||||
for name in module_names:
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
module.to(torch_device, torch_dtype)
|
||||
if (
|
||||
module.dtype == torch.float16
|
||||
and str(torch_device) in ["cpu"]
|
||||
and not silence_dtype_warnings
|
||||
and not is_offloaded
|
||||
):
|
||||
logger.warning(
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
||||
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
|
||||
" support for`float16` operations on this device in PyTorch. Please, remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use another device for inference."
|
||||
)
|
||||
for module in modules:
|
||||
module.to(torch_device, torch_dtype)
|
||||
if (
|
||||
module.dtype == torch.float16
|
||||
and str(torch_device) in ["cpu"]
|
||||
and not silence_dtype_warnings
|
||||
and not is_offloaded
|
||||
):
|
||||
logger.warning(
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
||||
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
|
||||
" support for`float16` operations on this device in PyTorch. Please, remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use another device for inference."
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
@@ -664,12 +676,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
`torch.device`: The torch device on which the pipeline is located.
|
||||
"""
|
||||
module_names, _ = self._get_signature_keys(self)
|
||||
module_names = [m for m in module_names if hasattr(self, m)]
|
||||
modules = [getattr(self, n, None) for n in module_names]
|
||||
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
|
||||
|
||||
for module in modules:
|
||||
return module.device
|
||||
|
||||
for name in module_names:
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
return module.device
|
||||
return torch.device("cpu")
|
||||
|
||||
@classmethod
|
||||
@@ -1046,7 +1058,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
return_cached_folder = kwargs.pop("return_cached_folder", False)
|
||||
if return_cached_folder:
|
||||
message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
|
||||
deprecate("return_cached_folder", "0.17.0", message, take_from=kwargs)
|
||||
deprecate("return_cached_folder", "0.17.0", message)
|
||||
return model, cached_folder
|
||||
|
||||
return model
|
||||
@@ -1438,13 +1450,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
for child in module.children():
|
||||
fn_recursive_set_mem_eff(child)
|
||||
|
||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||
module_names = [m for m in module_names if hasattr(self, m)]
|
||||
module_names, _ = self._get_signature_keys(self)
|
||||
modules = [getattr(self, n, None) for n in module_names]
|
||||
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
|
||||
|
||||
for module_name in module_names:
|
||||
module = getattr(self, module_name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
fn_recursive_set_mem_eff(module)
|
||||
for module in modules:
|
||||
fn_recursive_set_mem_eff(module)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
@@ -1471,10 +1482,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def set_attention_slice(self, slice_size: Optional[int]):
|
||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||
module_names = [m for m in module_names if hasattr(self, m)]
|
||||
module_names, _ = self._get_signature_keys(self)
|
||||
modules = [getattr(self, n, None) for n in module_names]
|
||||
modules = [m for m in modules if isinstance(m, torch.nn.Module) and hasattr(m, "set_attention_slice")]
|
||||
|
||||
for module_name in module_names:
|
||||
module = getattr(self, module_name)
|
||||
if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size)
|
||||
for module in modules:
|
||||
module.set_attention_slice(slice_size)
|
||||
|
||||
@@ -31,33 +31,30 @@ from transformers import (
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
from ...models import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
PriorTransformer,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
PriorTransformer,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
UnCLIPScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||
|
||||
from ...utils import is_omegaconf_available, is_safetensors_available, logging
|
||||
from ...utils.import_utils import BACKENDS_MAPPING
|
||||
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from ..paint_by_example import PaintByExampleImageEncoder
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -990,7 +987,8 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
clip_stats_path: Optional[str] = None,
|
||||
controlnet: Optional[bool] = None,
|
||||
load_safety_checker: bool = True,
|
||||
) -> StableDiffusionPipeline:
|
||||
pipeline_class: DiffusionPipeline = None,
|
||||
) -> DiffusionPipeline:
|
||||
"""
|
||||
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
|
||||
config file.
|
||||
@@ -1018,6 +1016,8 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
model_type (`str`, *optional*, defaults to `None`):
|
||||
The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
|
||||
"FrozenCLIPEmbedder", "PaintByExample"]`.
|
||||
is_img2img (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model should be loaded as an img2img pipeline.
|
||||
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
|
||||
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
|
||||
`False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
|
||||
@@ -1026,12 +1026,29 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
Whether the attention computation should always be upcasted. This is necessary when running stable
|
||||
diffusion 2.1.
|
||||
device (`str`, *optional*, defaults to `None`):
|
||||
The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
|
||||
in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
|
||||
StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
||||
The device to use. Pass `None` to determine automatically.
|
||||
from_safetensors (`str`, *optional*, defaults to `False`):
|
||||
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
||||
Whether to load the safety checker or not. Defaults to `True`.
|
||||
pipeline_class (`str`, *optional*, defaults to `None`):
|
||||
The pipeline class to use. Pass `None` to determine automatically.
|
||||
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
||||
"""
|
||||
|
||||
# import pipelines here to avoid circular import error when using from_ckpt method
|
||||
from diffusers import (
|
||||
LDMTextToImagePipeline,
|
||||
PaintByExamplePipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
|
||||
if pipeline_class is None:
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
|
||||
if prediction_type == "v-prediction":
|
||||
prediction_type = "v_prediction"
|
||||
|
||||
@@ -1193,7 +1210,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -1293,7 +1310,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline(
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -1320,7 +1337,7 @@ def download_controlnet_from_original_ckpt(
|
||||
upcast_attention: Optional[bool] = None,
|
||||
device: str = None,
|
||||
from_safetensors: bool = False,
|
||||
) -> StableDiffusionPipeline:
|
||||
) -> DiffusionPipeline:
|
||||
if not is_omegaconf_available():
|
||||
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
|
||||
|
||||
|
||||
@@ -528,7 +528,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
+1
-1
@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
|
||||
... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
|
||||
... )
|
||||
>>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
|
||||
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
|
||||
... )
|
||||
>>> params["controlnet"] = controlnet_params
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -53,13 +53,21 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
|
||||
+10
-8
@@ -14,7 +14,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -76,7 +76,7 @@ class AttentionStore:
|
||||
|
||||
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
||||
if self.cur_att_layer >= 0 and is_cross:
|
||||
if attn.shape[1] == self.attn_res**2:
|
||||
if attn.shape[1] == np.prod(self.attn_res):
|
||||
self.step_store[place_in_unet].append(attn)
|
||||
|
||||
self.cur_att_layer += 1
|
||||
@@ -98,7 +98,7 @@ class AttentionStore:
|
||||
attention_maps = self.get_average_attention()
|
||||
for location in from_where:
|
||||
for item in attention_maps[location]:
|
||||
cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
|
||||
cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
|
||||
out.append(cross_maps)
|
||||
out = torch.cat(out, dim=0)
|
||||
out = out.sum(0) / out.shape[0]
|
||||
@@ -109,7 +109,7 @@ class AttentionStore:
|
||||
self.step_store = self.get_empty_store()
|
||||
self.attention_store = {}
|
||||
|
||||
def __init__(self, attn_res=16):
|
||||
def __init__(self, attn_res):
|
||||
"""
|
||||
Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
|
||||
process
|
||||
@@ -724,7 +724,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
max_iter_to_alter: int = 25,
|
||||
thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
|
||||
scale_factor: int = 20,
|
||||
attn_res: int = 16,
|
||||
attn_res: Optional[Tuple[int]] = (16, 16),
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -796,8 +796,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
|
||||
scale_factor (`int`, *optional*, default to 20):
|
||||
Scale factor that controls the step size of each Attend and Excite update.
|
||||
attn_res (`int`, *optional*, default to 16):
|
||||
The resolution of most semantic attention map.
|
||||
attn_res (`tuple`, *optional*, default computed from width and height):
|
||||
The 2D resolution of the semantic attention map.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -870,7 +870,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
self.attention_store = AttentionStore(attn_res=attn_res)
|
||||
if attn_res is None:
|
||||
attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
|
||||
self.attention_store = AttentionStore(attn_res)
|
||||
self.register_attention_control()
|
||||
|
||||
# default config for step size from original repo
|
||||
|
||||
@@ -118,6 +118,7 @@ class MultiControlNetModel(ModelMixin):
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
||||
@@ -131,6 +132,7 @@ class MultiControlNetModel(ModelMixin):
|
||||
timestep_cond,
|
||||
attention_mask,
|
||||
cross_attention_kwargs,
|
||||
guess_mode,
|
||||
return_dict,
|
||||
)
|
||||
|
||||
@@ -154,6 +156,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
@@ -627,7 +632,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
)
|
||||
|
||||
def prepare_image(
|
||||
self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
if not isinstance(image, torch.Tensor):
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
@@ -664,7 +678,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
@@ -747,6 +761,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
guess_mode: bool = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -819,6 +834,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
|
||||
corresponding scale as a list.
|
||||
guess_mode (`bool`, *optional*, defaults to `False`):
|
||||
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
|
||||
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
@@ -883,6 +902,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
device=device,
|
||||
dtype=self.controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
elif isinstance(self.controlnet, MultiControlNetModel):
|
||||
images = []
|
||||
@@ -897,6 +917,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
device=device,
|
||||
dtype=self.controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
images.append(image_)
|
||||
@@ -934,15 +955,31 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
controlnet_latent_model_input = latents
|
||||
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
||||
else:
|
||||
controlnet_latent_model_input = latent_model_input
|
||||
controlnet_prompt_embeds = prompt_embeds
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
latent_model_input,
|
||||
controlnet_latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states=controlnet_prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
guess_mode=guess_mode,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infered ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
||||
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
|
||||
@@ -23,7 +23,7 @@ from packaging import version
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...loaders import TextualInversionLoaderMixin
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
|
||||
@@ -55,13 +55,20 @@ def preprocess(image):
|
||||
return image
|
||||
|
||||
|
||||
class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-guided image to image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
@@ -390,7 +397,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import TextualInversionLoaderMixin
|
||||
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -92,13 +92,21 @@ def preprocess(image):
|
||||
return image
|
||||
|
||||
|
||||
class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
|
||||
r"""
|
||||
Pipeline for text-guided image to image generation using Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
@@ -511,7 +519,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...loaders import TextualInversionLoaderMixin
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
|
||||
@@ -138,13 +138,20 @@ def prepare_mask_and_masked_image(image, mask):
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
|
||||
+23
-13
@@ -22,7 +22,7 @@ from packaging import version
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...loaders import TextualInversionLoaderMixin
|
||||
from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -41,17 +41,17 @@ from .safety_checker import StableDiffusionSafetyChecker
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess_image(image):
|
||||
def preprocess_image(image, batch_size):
|
||||
w, h = image.size
|
||||
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
|
||||
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
def preprocess_mask(mask, scale_factor=8):
|
||||
def preprocess_mask(mask, batch_size, scale_factor=8):
|
||||
if not isinstance(mask, torch.FloatTensor):
|
||||
mask = mask.convert("L")
|
||||
w, h = mask.size
|
||||
@@ -59,7 +59,7 @@ def preprocess_mask(mask, scale_factor=8):
|
||||
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
||||
mask = np.vstack([mask[None]] * batch_size)
|
||||
mask = 1 - mask # repaint white, keep black
|
||||
mask = torch.from_numpy(mask)
|
||||
return mask
|
||||
@@ -82,13 +82,23 @@ def preprocess_mask(mask, scale_factor=8):
|
||||
return mask
|
||||
|
||||
|
||||
class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
class StableDiffusionInpaintPipelineLegacy(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
@@ -507,18 +517,18 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
|
||||
def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
|
||||
image = image.to(device=self.device, dtype=dtype)
|
||||
init_latent_dist = self.vae.encode(image).latent_dist
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
# Expand init_latents for batch_size and num_images_per_prompt
|
||||
init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
|
||||
init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
|
||||
init_latents_orig = init_latents
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
@@ -649,9 +659,9 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo
|
||||
|
||||
# 4. Preprocess image and mask
|
||||
if not isinstance(image, torch.FloatTensor):
|
||||
image = preprocess_image(image)
|
||||
image = preprocess_image(image, batch_size)
|
||||
|
||||
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
|
||||
mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -661,12 +671,12 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLo
|
||||
# 6. Prepare latent variables
|
||||
# encode the init image into latents and scale the latents
|
||||
latents, init_latents_orig, noise = self.prepare_latents(
|
||||
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
|
||||
image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
|
||||
)
|
||||
|
||||
# 7. Prepare mask latent
|
||||
mask = mask_image.to(device=self.device, dtype=latents.dtype)
|
||||
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
|
||||
mask = torch.cat([mask] * num_images_per_prompt)
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
+9
-2
@@ -20,7 +20,7 @@ import PIL
|
||||
import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...loaders import TextualInversionLoaderMixin
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
@@ -61,13 +61,20 @@ def preprocess(image):
|
||||
return image
|
||||
|
||||
|
||||
class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
|
||||
as well as the following saving methods:
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
|
||||
+31
-27
@@ -36,6 +36,7 @@ from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
|
||||
from ...utils import (
|
||||
PIL_INTERPOLATION,
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
@@ -721,23 +722,31 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
if isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
|
||||
latents = torch.cat(latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
latents = self.vae.encode(image).latent_dist.sample(generator)
|
||||
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
latents = self.vae.config.scaling_factor * latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
if batch_size != latents.shape[0]:
|
||||
if batch_size % latents.shape[0] == 0:
|
||||
# expand image_latents for batch_size
|
||||
deprecation_message = (
|
||||
f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
|
||||
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
|
||||
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
|
||||
" your script to pass as many initial images as text prompts to suppress this warning."
|
||||
)
|
||||
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
|
||||
additional_latents_per_image = batch_size // latents.shape[0]
|
||||
latents = torch.cat([latents] * additional_latents_per_image, dim=0)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
latents = init_latents
|
||||
latents = torch.cat([latents], dim=0)
|
||||
|
||||
return latents
|
||||
|
||||
@@ -759,23 +768,18 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
def auto_corr_loss(self, hidden_states, generator=None):
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
if batch_size > 1:
|
||||
raise ValueError("Only batch_size 1 is supported for now")
|
||||
|
||||
hidden_states = hidden_states.squeeze(0)
|
||||
# hidden_states must be shape [C,H,W] now
|
||||
reg_loss = 0.0
|
||||
for i in range(hidden_states.shape[0]):
|
||||
noise = hidden_states[i][None, None, :, :]
|
||||
while True:
|
||||
roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
|
||||
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
|
||||
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
|
||||
for j in range(hidden_states.shape[1]):
|
||||
noise = hidden_states[i : i + 1, j : j + 1, :, :]
|
||||
while True:
|
||||
roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
|
||||
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
|
||||
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
|
||||
|
||||
if noise.shape[2] <= 8:
|
||||
break
|
||||
noise = F.avg_pool2d(noise, kernel_size=2)
|
||||
if noise.shape[2] <= 8:
|
||||
break
|
||||
noise = F.avg_pool2d(noise, kernel_size=2)
|
||||
return reg_loss
|
||||
|
||||
def kl_divergence(self, hidden_states):
|
||||
|
||||
@@ -85,7 +85,10 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
|
||||
|
||||
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
|
||||
if has_nsfw_concept:
|
||||
images[idx] = np.zeros(images[idx].shape) # black image
|
||||
if torch.is_tensor(images) or torch.is_tensor(images[0]):
|
||||
images[idx] = torch.zeros_like(images[idx]) # black image
|
||||
else:
|
||||
images[idx] = np.zeros(images[idx].shape) # black image
|
||||
|
||||
if any(has_nsfw_concepts):
|
||||
logger.warning(
|
||||
|
||||
@@ -441,7 +441,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# Prepare latent variables
|
||||
num_channels_latents = self.unet.in_channels
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
|
||||
@@ -413,9 +413,9 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
|
||||
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
num_channels_latents = self.decoder.config.in_channels
|
||||
height = self.decoder.config.sample_size
|
||||
width = self.decoder.config.sample_size
|
||||
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
@@ -466,9 +466,9 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
|
||||
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
|
||||
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
channels = self.super_res_first.config.in_channels // 2
|
||||
height = self.super_res_first.config.sample_size
|
||||
width = self.super_res_first.config.sample_size
|
||||
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
|
||||
@@ -339,9 +339,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
|
||||
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
num_channels_latents = self.decoder.config.in_channels
|
||||
height = self.decoder.config.sample_size
|
||||
width = self.decoder.config.sample_size
|
||||
|
||||
if decoder_latents is None:
|
||||
decoder_latents = self.prepare_latents(
|
||||
@@ -393,9 +393,9 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
|
||||
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
|
||||
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
channels = self.super_res_first.config.in_channels // 2
|
||||
height = self.super_res_first.config.sample_size
|
||||
width = self.super_res_first.config.sample_size
|
||||
|
||||
if super_res_latents is None:
|
||||
super_res_latents = self.prepare_latents(
|
||||
|
||||
@@ -18,7 +18,7 @@ from ...models.dual_transformer_2d import DualTransformer2DModel
|
||||
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from ...models.transformer_2d import Transformer2DModel
|
||||
from ...models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import deprecate, logging
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -345,7 +345,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
@@ -534,7 +534,18 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
if act_fn == "swish":
|
||||
self.conv_act = lambda x: F.silu(x)
|
||||
elif act_fn == "mish":
|
||||
self.conv_act = nn.Mish()
|
||||
elif act_fn == "silu":
|
||||
self.conv_act = nn.SiLU()
|
||||
elif act_fn == "gelu":
|
||||
self.conv_act = nn.GELU()
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function: {act_fn}")
|
||||
|
||||
else:
|
||||
self.conv_norm_out = None
|
||||
self.conv_act = None
|
||||
@@ -544,19 +555,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
||||
)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
deprecate(
|
||||
"in_channels",
|
||||
"1.0.0",
|
||||
(
|
||||
"Accessing `in_channels` directly via unet.in_channels is deprecated. Please use"
|
||||
" `unet.config.in_channels` instead"
|
||||
),
|
||||
standard_warn=False,
|
||||
)
|
||||
return self.config.in_channels
|
||||
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
@@ -758,7 +756,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
@@ -772,6 +770,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# there might be better ways to encapsulate this.
|
||||
class_labels = class_labels.to(dtype=sample.dtype)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
|
||||
if self.config.class_embeddings_concat:
|
||||
|
||||
+1
-1
@@ -533,7 +533,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.image_unet.in_channels
|
||||
num_channels_latents = self.image_unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
|
||||
+1
-1
@@ -378,7 +378,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.image_unet.in_channels
|
||||
num_channels_latents = self.image_unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
|
||||
+1
-1
@@ -452,7 +452,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.image_unet.in_channels
|
||||
num_channels_latents = self.image_unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
|
||||
@@ -22,7 +22,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, deprecate, randn_tensor
|
||||
from ..utils import BaseOutput, randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
||||
|
||||
|
||||
@@ -162,21 +162,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# setable values
|
||||
self.custom_timesteps = False
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
|
||||
|
||||
self.variance_type = variance_type
|
||||
|
||||
@property
|
||||
def num_train_timesteps(self):
|
||||
deprecate(
|
||||
"num_train_timesteps",
|
||||
"1.0.0",
|
||||
"Accessing `num_train_timesteps` directly via scheduler.num_train_timesteps is deprecated. Please use `scheduler.config.num_train_timesteps instead`",
|
||||
standard_warn=False,
|
||||
)
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
@@ -191,31 +182,62 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
num_inference_steps (`Optional[int]`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
|
||||
`timesteps` must be `None`.
|
||||
device (`str` or `torch.device`, optional):
|
||||
the device to which the timesteps are moved to.
|
||||
custom_timesteps (`List[int]`, optional):
|
||||
custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
|
||||
timestep spacing strategy of equal spacing between timesteps is used. If passed, `num_inference_steps`
|
||||
must be `None`.
|
||||
|
||||
"""
|
||||
if num_inference_steps is not None and timesteps is not None:
|
||||
raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
|
||||
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
if timesteps is not None:
|
||||
for i in range(1, len(timesteps)):
|
||||
if timesteps[i] >= timesteps[i - 1]:
|
||||
raise ValueError("`custom_timesteps` must be in descending order.")
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
if timesteps[0] >= self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`timesteps` must start before `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps}."
|
||||
)
|
||||
|
||||
timesteps = np.array(timesteps, dtype=np.int64)
|
||||
self.custom_timesteps = True
|
||||
else:
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
|
||||
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
||||
self.custom_timesteps = False
|
||||
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
def _get_variance(self, t, predicted_variance=None, variance_type=None):
|
||||
num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
|
||||
prev_t = t - self.config.num_train_timesteps // num_inference_steps
|
||||
prev_t = self.previous_timestep(t)
|
||||
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
|
||||
current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
|
||||
@@ -314,8 +336,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
"""
|
||||
t = timestep
|
||||
num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
|
||||
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
|
||||
|
||||
prev_t = self.previous_timestep(t)
|
||||
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
||||
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
|
||||
@@ -428,3 +450,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
def previous_timestep(self, timestep):
|
||||
if self.custom_timesteps:
|
||||
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
|
||||
if index == self.timesteps.shape[0] - 1:
|
||||
prev_t = torch.tensor(-1)
|
||||
else:
|
||||
prev_t = self.timesteps[index + 1]
|
||||
else:
|
||||
num_inference_steps = (
|
||||
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
|
||||
)
|
||||
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
|
||||
|
||||
return prev_t
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any, Dict, Optional, Union
|
||||
from packaging import version
|
||||
|
||||
|
||||
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True):
|
||||
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
|
||||
from .. import __version__
|
||||
|
||||
deprecated_kwargs = take_from
|
||||
@@ -32,7 +32,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn
|
||||
|
||||
if warning is not None:
|
||||
warning = warning + " " if standard_warn else ""
|
||||
warnings.warn(warning + message, FutureWarning, stacklevel=2)
|
||||
warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)
|
||||
|
||||
if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
|
||||
call_frame = inspect.getouterframes(inspect.currentframe())[1]
|
||||
|
||||
@@ -2,21 +2,6 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class TextualInversionLoaderMixin(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -279,6 +279,16 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||
return image
|
||||
|
||||
|
||||
def preprocess_image(image: PIL.Image, batch_size: int):
|
||||
w, h = image.size
|
||||
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
|
||||
if is_opencv_available():
|
||||
import cv2
|
||||
|
||||
@@ -46,7 +46,7 @@ def create_unet_lora_layers(unet: nn.Module):
|
||||
def create_text_encoder_lora_layers(text_encoder: nn.Module):
|
||||
text_lora_attn_procs = {}
|
||||
for name, module in text_encoder.named_modules():
|
||||
if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
|
||||
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
|
||||
text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
|
||||
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
|
||||
return text_encoder_lora_layers
|
||||
@@ -71,6 +71,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
steps_offset=1,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
@@ -26,8 +26,8 @@ from requests.exceptions import HTTPError
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from diffusers.utils import logging, torch_device
|
||||
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
|
||||
|
||||
|
||||
class ModelUtilsTest(unittest.TestCase):
|
||||
@@ -155,6 +155,49 @@ class ModelTesterMixin:
|
||||
max_diff = (image - new_image).abs().sum().item()
|
||||
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
|
||||
|
||||
def test_getattr_is_correct(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
# save some things to test
|
||||
model.dummy_attribute = 5
|
||||
model.register_to_config(test_attribute=5)
|
||||
|
||||
logger = logging.get_logger("diffusers.models.modeling_utils")
|
||||
# 30 for warning
|
||||
logger.setLevel(30)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
assert hasattr(model, "dummy_attribute")
|
||||
assert getattr(model, "dummy_attribute") == 5
|
||||
assert model.dummy_attribute == 5
|
||||
|
||||
# no warning should be thrown
|
||||
assert cap_logger.out == ""
|
||||
|
||||
logger = logging.get_logger("diffusers.models.modeling_utils")
|
||||
# 30 for warning
|
||||
logger.setLevel(30)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
assert hasattr(model, "save_pretrained")
|
||||
fn = model.save_pretrained
|
||||
fn_1 = getattr(model, "save_pretrained")
|
||||
|
||||
assert fn == fn_1
|
||||
# no warning should be thrown
|
||||
assert cap_logger.out == ""
|
||||
|
||||
# warning should be thrown
|
||||
with self.assertWarns(FutureWarning):
|
||||
assert model.test_attribute == 5
|
||||
|
||||
with self.assertWarns(FutureWarning):
|
||||
assert getattr(model, "test_attribute") == 5
|
||||
|
||||
with self.assertRaises(AttributeError) as error:
|
||||
model.does_not_exist
|
||||
|
||||
assert str(error.exception) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
|
||||
|
||||
def test_from_save_pretrained_variant(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from diffusers import UNet1DModel
|
||||
from diffusers.utils import floats_tensor, slow, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
from diffusers import UNet2DModel
|
||||
from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, LoRAAttnProcessor
|
||||
from diffusers.utils import (
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
@@ -34,7 +34,7 @@ from diffusers.utils import (
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -68,6 +68,55 @@ def create_lora_layers(model, mock_weights: bool = True):
|
||||
return lora_attn_procs
|
||||
|
||||
|
||||
def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
train_kv = True
|
||||
train_q_out = True
|
||||
custom_diffusion_attn_procs = {}
|
||||
|
||||
st = model.state_dict()
|
||||
for name, _ in model.attn_processors.items():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = model.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = model.config.block_out_channels[block_id]
|
||||
layer_name = name.split(".processor")[0]
|
||||
weights = {
|
||||
"to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
|
||||
"to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
|
||||
}
|
||||
if train_q_out:
|
||||
weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
|
||||
weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"]
|
||||
weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"]
|
||||
if cross_attention_dim is not None:
|
||||
custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor(
|
||||
train_kv=train_kv,
|
||||
train_q_out=train_q_out,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
).to(model.device)
|
||||
custom_diffusion_attn_procs[name].load_state_dict(weights)
|
||||
if mock_weights:
|
||||
# add 1 to weights to mock trained weights
|
||||
with torch.no_grad():
|
||||
custom_diffusion_attn_procs[name].to_k_custom_diffusion.weight += 1
|
||||
custom_diffusion_attn_procs[name].to_v_custom_diffusion.weight += 1
|
||||
else:
|
||||
custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor(
|
||||
train_kv=False,
|
||||
train_q_out=False,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
del st
|
||||
return custom_diffusion_attn_procs
|
||||
|
||||
|
||||
class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
@@ -569,6 +618,96 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
assert (sample - on_sample).abs().max() < 1e-4
|
||||
assert (sample - off_sample).abs().max() < 1e-4
|
||||
|
||||
def test_custom_diffusion_processors(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
sample1 = model(**inputs_dict).sample
|
||||
|
||||
custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False)
|
||||
|
||||
# make sure we can set a list of attention processors
|
||||
model.set_attn_processor(custom_diffusion_attn_procs)
|
||||
model.to(torch_device)
|
||||
|
||||
# test that attn processors can be set to itself
|
||||
model.set_attn_processor(model.attn_processors)
|
||||
|
||||
with torch.no_grad():
|
||||
sample2 = model(**inputs_dict).sample
|
||||
|
||||
assert (sample1 - sample2).abs().max() < 1e-4
|
||||
|
||||
def test_custom_diffusion_save_load(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
old_sample = model(**inputs_dict).sample
|
||||
|
||||
custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False)
|
||||
model.set_attn_processor(custom_diffusion_attn_procs)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.to(torch_device)
|
||||
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
|
||||
with torch.no_grad():
|
||||
new_sample = new_model(**inputs_dict).sample
|
||||
|
||||
assert (sample - new_sample).abs().max() < 1e-4
|
||||
|
||||
# custom diffusion and no custom diffusion should be the same
|
||||
assert (sample - old_sample).abs().max() < 1e-4
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_custom_diffusion_xformers_on_off(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False)
|
||||
model.set_attn_processor(custom_diffusion_attn_procs)
|
||||
|
||||
# default
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
on_sample = model(**inputs_dict).sample
|
||||
|
||||
model.disable_xformers_memory_efficient_attention()
|
||||
off_sample = model(**inputs_dict).sample
|
||||
|
||||
assert (sample - on_sample).abs().max() < 1e-4
|
||||
assert (sample - off_sample).abs().max() < 1e-4
|
||||
|
||||
|
||||
@slow
|
||||
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
@@ -30,7 +30,7 @@ from diffusers.utils import (
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -22,7 +22,7 @@ from parameterized import parameterized
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user