Compare commits

..

46 Commits

Author SHA1 Message Date
Dhruv Nair f8c53ee022 update 2024-01-25 06:24:04 +00:00
Dhruv Nair d1272550d6 update 2024-01-24 17:46:37 +00:00
Dhruv Nair 75001f620e update 2024-01-24 17:44:26 +00:00
Dhruv Nair fee93c81eb [Refactor] Update from single file (#6428)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update'

* update

* update

* update

* update

* update

* update

* up

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* up

* update

* update

* update

* update

* update'

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* clean

* update

* update

* clean up

* clean up

* update

* clean

* clean

* update

* updaet

* clean up

* fix docs

* update

* update

* Revert "update"

This reverts commit dbfb8f1ea9.

* update

* update

* update

* update

* fix controlnet

* fix scheduler

* fix controlnet tests
2024-01-23 14:42:03 +05:30
Sayak Paul 5308cce994 [Tests] Test for passing local config file to from_single_file() (#6638)
make config file local too.
2024-01-23 14:21:23 +05:30
YiYi Xu 318556b20e fix dpm related slow test failure (#6680)
fix

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2024-01-22 18:52:05 -10:00
Dhruv Nair 6620eda357 Standardise outputs for video pipelines (#6626)
* update

* update

* update

* update

* update

* update

* update

* clean up

* clean up
2024-01-23 10:07:07 +05:30
Sayak Paul 1f0705adcf [Big refactor] move unets to unets module 🦋 (#6630)
* move unets to  module 🦋

* parameterize unet-level import.

* fix flax unet2dcondition model import

* models __init__

* mildly depcrecating models.unet_2d_blocks in favor of models.unets.unet_2d_blocks.

* noqa

* correct depcrecation behaviour

* inherit from the actual classes.

* Empty-Commit

* backwards compatibility for unet_2d.py

* backward compatibility for unet_2d_condition

* bc for unet_1d

* bc for unet_1d_blocks
2024-01-23 08:57:58 +05:30
M. Tolga Cangöz 5e96333cb2 Update README (#6669)
Update number of checkpoints and repositories in README
2024-01-22 08:08:07 -08:00
Sayak Paul da95a28ff6 [Diffusion DPO] apply fixes from #6547 (#6668)
apply fixes from #6547
2024-01-22 20:14:54 +05:30
Dhruv Nair d66d554dc2 Add tearDown method to LoRA tests. (#6660)
* update

* update
2024-01-22 14:00:37 +05:30
Junsong Chen c7df846dec add Sa-Solver (#5975)
* add Sa-Solver



---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: scxue <xueshuchen17@mails.ucas.edu.cn>
Co-authored-by: jschen <chenjunsong4@h-partners.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
2024-01-21 21:37:44 -10:00
Vinh H. Pham 8e7bbfbe5a add padding_mask_crop to all inpaint pipelines (#6360)
* add padding_mask_crop
---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-01-21 19:42:22 -10:00
YiYi Xu e2773c6255 fix SDXL-kdiffusion tests (#6647)
🤞🤞🤞

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2024-01-19 17:37:29 -10:00
YiYi Xu ac61eefc9f fix DPM Scheduler with use_karras_sigmas option (#6477)
* fix

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2024-01-19 07:08:22 -10:00
HelloWorldBeginner f95615b823 Fixed the bug related to saving DeepSpeed models. (#6628)
* Fixed the bug related to saving DeepSpeed models.

* Add information about training SD models using DeepSpeed to the README.

* Apply suggestions from code review

---------

Co-authored-by: mhh001 <mahonghao1@huawei.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-01-19 19:21:57 +05:30
SangKim a9288b49c9 Modularize InstructPix2Pix SDXL inferencing during and after training in examples (#6569) 2024-01-19 15:47:34 +05:30
elucida c54419658b refactor: extract init/forward function in UNet2DConditionModel (#6478)
* - extract function for stage in UNet2DConditionModel init & forward
- Add new function get_mid_block() to unet_2d_blocks.py

* add type hint to get_mid_block aligned with get_up_block and get_down_block; rename _set_xxx function

* add type hint and  use keyword arguments

* remove `copy from` in versatile diffusion
2024-01-19 12:13:34 +02:00
Aryan V S 6382663dc8 [Community] Experimental AnimateDiff Image to Video (open to improvements) (#6509)
* add animatediff img2vid

* fix

* Update examples/community/README.md

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix code snippet between ip adapter face id and animatediff img2vid

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2024-01-19 12:05:41 +02:00
spezialspezial 58b8dce129 Update convert_from_ckpt.py / read checkpoint config yaml contents (#6633)
Update convert_from_ckpt.py / read config yaml contents

Added missing reading of config yaml file contents
2024-01-19 08:25:03 +05:30
Dhruv Nair a65ca8a059 Fix failing tests due to Posix Path (#6627)
update
2024-01-18 19:19:57 +05:30
Steven Liu 5ca062e011 [docs] Fix missing API function (#6604)
fix?
2024-01-17 13:59:09 -08:00
Linoy Tsaban 619e3ab6f6 [bug fix] advanced dreambooth lora sdxl - fixes bugs described in #6486 (#6599)
* fixes bugs:
1. redundant retraction
2. param clone
3. stopping optimization of text encoder params

* param upscaling

* style
2024-01-17 20:11:45 +05:30
Patrick von Platen 9e2804f720 Update pr_test_peft_backend.yml to use 1 process for testing (#6613) 2024-01-17 19:25:30 +05:30
Aryan V S 9112028ed8 FreeInit (#6315)
* freeinit

* update freeinit implementation based on review

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* fix

* another fix

* refactor

* fix timesteps missing bug

* apply suggestions from review

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* add test for freeinit

* apply suggestions from review

Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>

* refactor

* fix test

* fix tensor not on same device

* update

* remove return_intermediate_results

* fix broken freeinit test

* update animatediff docs

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-01-17 17:17:07 +05:30
Steve Rhoades dce06680d2 Fixes torch.compile() compatible training (#6589)
resolve conflicts
2024-01-17 07:47:03 +05:30
Bhavay Malhotra dd63168319 Update installation.md (#6438)
* Update installation.md

* Update installation.md

* Update installation.md
2024-01-16 13:41:38 -08:00
Celestial Phineas 1040dfd9cc [Fix] Multiple image conditionings in a single batch for StableDiffusionControlNetPipeline (#6334)
* [Fix] Multiple image conditionings in a single batch for `StableDiffusionControlNetPipeline`.

* Refactor `check_inputs` in `StableDiffusionControlNetPipeline` to avoid redundant codes.

* Make the behavior of MultiControlNetModel to be the same to the original ControlNetModel

* Keep the code change minimum for nested list support

* Add fast test `test_inference_nested_image_input`

* Remove redundant check for nested image condition in `check_inputs`

Remove `len(image) == len(prompt)` check out of `check_image()`

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Better `ValueError` message for incompatible nested image list size

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Fix syntax error in `check_inputs`

* Remove warning message for multi-ControlNets with multiple prompts

* Fix a typo in test_controlnet.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Add test case for multiple prompts, single image conditioning in `StableDiffusionMultiControlNetPipelineFastTests`

* Improved `ValueError` message for nested `controlnet_conditioning_scale`

* Documenting the behavior of image list as `StableDiffusionControlNetPipeline` input

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-01-16 10:40:55 -10:00
Sayak Paul 49a4b377c1 remove omegaconf from the residues 👋 (#6600)
remove omegaconf 👋
2024-01-16 21:55:41 +05:30
JuanCarlosPi dff35a86e4 Change in ip-adapter docs. CLIPVisionModelWithProjection should be im… (#6597)
Change in ip-adapter docs. CLIPVisionModelWithProjection should be imported from transformers, not diffusers
2024-01-16 08:18:13 -08:00
Yondon Fu 8842bcadb9 [SVD] Return np.ndarray when output_type="np" (#6507)
[SVD] Fix output_type="np"
2024-01-16 14:51:36 +05:30
Steve Rhoades 181280baba Fixes training resuming: Advanced Dreambooth LoRa Training (#6566)
* Fixes #6418 Advanced Dreambooth LoRa Training

* change order of import to fix nit

* fix nit, use cast_training_params

* remove torch.compile fix, will move to a new PR

* remove unnecessary import
2024-01-16 14:30:49 +05:30
Charchit Sharma 53f498d2a4 Use of Posix to better support Windows compatibility in testing_utils (#6587)
* changes in utils

* removed loc
2024-01-16 10:27:29 +02:00
Charchit Sharma 990860911f change to posix for better Windows support for lora loaders (#6590)
* posix lora

* changes and style fix
2024-01-16 13:46:29 +05:30
Fabio Rigano 23eed39702 Fix path generation in IP Adapter (#6564)
* Fix path generation on Windows

* Update set_default_attn_processors

* Use pathlib

* Fix quality

* Fix copy

* Revert changes in set_default_attn_processors

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-01-16 11:25:22 +05:30
YiYi Xu fefed44543 update slow test for SDXL k-diffusion pipeline (#6588)
update expected slice
2024-01-15 18:54:33 -10:00
Dong 814f56d2fe 🐛 fix ip-adapter controlnet img2img missing code (#6528)
* 🐛 fix ip-adapter controlnet img2img missing code

* 📝 edit test

* 📝 edit test

* 📝 run make style and quality

* 🎨 remove slow tests
2024-01-15 16:41:09 -10:00
SangKim 96d6e16550 Enable image resizing to adjust its height and width in StableDiffusionXLInstructPix2PixPipeline (#6581)
* Enable image resizing to adjust its height and width in StableDiffusionXLInstructPix2PixPipeline

* Ensure that validation is performed at every 'validation_step', not at every step
2024-01-16 07:50:34 +05:30
Aryan V S c11de13588 [training] fix training resuming problem for fp16 (SD LoRA DreamBooth) (#6554)
* fix training resume

* update

* update
2024-01-16 07:27:06 +05:30
Patrick von Platen 357855f8fc [Docs] Fix controlnet indent (#6578) 2024-01-15 18:12:55 +02:00
Fabio Rigano f825221b5d [Community Pipeline] IPAdapter FaceID (#6276)
* Add support for IPAdapter FaceID

* Add docs

* Move subfolder to kwargs

* Fix quality

* Fix image encoder loading

* Fix loading + add test

* Move to community folder

* Fix style

* Revert constant update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-01-15 16:43:54 +02:00
Aryan V S 119d734f6e [AnimateDiff+Controlnet] Fix multicontrolnet support (#6551)
* fix multicontrolnet support

* update README with multicontrolnet example
2024-01-15 16:36:54 +02:00
Sayak Paul cb4b3f0b78 [OmegaConf] replace it with yaml (#6488)
* remove omegaconf from convert_from_ckpt.

* remove from single_file.

* change to string based ubscription.

* style

* okay

* fix: vae_param

* no . indexing.

* style

* style

* turn getattrs into explicit if/else

* style

* propagate changes to ldm_uncond.

* propagate to gligen

* propagate to if.

* fix: quotes.

* propagate to audioldm.

* propagate to audioldm2

* propagate to musicldm.

* propagate to vq_diffusion

* propagate to zero123.

* remove omegaconf from diffusers codebase.
2024-01-15 20:02:10 +05:30
Haofan Wang 3d574b3bbe Fix a bug of flip in SDXL training script (#6547)
* Update train_text_to_image_sdxl.py

* Update train_text_to_image_lora_sdxl.py
2024-01-15 16:28:04 +02:00
Charchit Sharma 09903774d9 Make T2I Adapter SDXL Training Script torch.compile compatible (#6577)
update for t2i_adapter
2024-01-15 19:42:56 +05:30
dependabot[bot] d6a70d8ba8 Bump jinja2 from 3.1.2 to 3.1.3 in /examples/research_projects/realfill (#6539)
Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.2 to 3.1.3.
- [Release notes](https://github.com/pallets/jinja/releases)
- [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst)
- [Commits](https://github.com/pallets/jinja/compare/3.1.2...3.1.3)

---
updated-dependencies:
- dependency-name: jinja2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-01-15 16:10:10 +02:00
159 changed files with 14288 additions and 9355 deletions
+1 -1
View File
@@ -59,7 +59,7 @@ jobs:
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_${{ matrix.config.report }} \
tests/lora/test_lora_layers_peft.py
+2 -2
View File
@@ -77,7 +77,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
## Quickstart
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 16000+ checkpoints):
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 19000+ checkpoints):
```python
from diffusers import DiffusionPipeline
@@ -219,7 +219,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
- https://github.com/deep-floyd/IF
- https://github.com/bentoml/BentoML
- https://github.com/bmaltais/kohya_ss
- +7000 other amazing GitHub repositories 💪
- +8000 other amazing GitHub repositories 💪
Thank you for using us ❤️.
@@ -40,7 +40,6 @@ RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
numpy \
scipy \
tensorboard \
transformers \
omegaconf
transformers
CMD ["/bin/bash"]
-1
View File
@@ -40,7 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
scipy \
tensorboard \
transformers \
omegaconf \
pytorch-lightning
CMD ["/bin/bash"]
@@ -40,7 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
scipy \
tensorboard \
transformers \
omegaconf \
xformers
CMD ["/bin/bash"]
+2 -2
View File
@@ -30,8 +30,8 @@ To learn more about how to load single file weights, see the [Load different Sta
## FromOriginalVAEMixin
[[autodoc]] loaders.single_file.FromOriginalVAEMixin
[[autodoc]] loaders.autoencoder.FromOriginalVAEMixin
## FromOriginalControlnetMixin
[[autodoc]] loaders.single_file.FromOriginalControlnetMixin
[[autodoc]] loaders.controlnet.FromOriginalControlNetMixin
@@ -33,6 +33,9 @@ model = AutoencoderKL.from_single_file(url)
## AutoencoderKL
[[autodoc]] AutoencoderKL
- decode
- encode
- all
## AutoencoderKLOutput
+1 -1
View File
@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNetMotionModel
## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput
+1 -1
View File
@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNet1DModel
## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput
[[autodoc]] models.unets.unet_1d.UNet1DOutput
+3 -3
View File
@@ -22,10 +22,10 @@ The abstract from the paper is:
[[autodoc]] UNet2DConditionModel
## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
[[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput
## FlaxUNet2DConditionModel
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionModel
## FlaxUNet2DConditionOutput
[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput
+1 -1
View File
@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNet2DModel
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput
[[autodoc]] models.unets.unet_2d.UNet2DOutput
+1 -1
View File
@@ -22,4 +22,4 @@ The abstract from the paper is:
[[autodoc]] UNet3DConditionModel
## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
[[autodoc]] models.unets.unet_3d_condition.UNet3DConditionOutput
@@ -235,6 +235,62 @@ export_to_gif(frames, "animation.gif")
</tr>
</table>
## Using FreeInit
[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.
FreeInit is an effective method that improves temporal consistency and overall quality of videos generated using video-diffusion-models without any addition training. It can be applied to AnimateDiff, ModelScope, VideoCrafter and various other video generation models seamlessly at inference time, and works by iteratively refining the latent-initialization noise. More details can be found it the paper.
The following example demonstrates the usage of FreeInit.
```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from diffusers.utils import export_to_gif
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
beta_schedule="linear",
clip_sample=False,
timestep_spacing="linspace",
steps_offset=1
)
# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
# enable FreeInit
# Refer to the enable_free_init documentation for a full list of configurable parameters
pipe.enable_free_init(method="butterworth", use_fast_sampling=True)
# run inference
output = pipe(
prompt="a panda playing a guitar, on a boat, in the ocean, high quality",
negative_prompt="bad quality, worse quality",
num_frames=16,
guidance_scale=7.5,
num_inference_steps=20,
generator=torch.Generator("cpu").manual_seed(666),
)
# disable FreeInit
pipe.disable_free_init()
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
<Tip warning={true}>
FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).
</Tip>
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
@@ -248,6 +304,8 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- __call__
- enable_freeu
- disable_freeu
- enable_free_init
- disable_free_init
- enable_vae_slicing
- disable_vae_slicing
- enable_vae_tiling
+2
View File
@@ -37,8 +37,10 @@ source .env/bin/activate
You should also install 🤗 Transformers because 🤗 Diffusers relies on its models:
<frameworkcontent>
<pt>
Note - PyTorch only supports Python 3.8 - 3.11 on Windows.
```bash
pip install diffusers["torch"] transformers
```
@@ -344,7 +344,8 @@ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-a
IP-Adapter relies on an image encoder to generate the image features, if your IP-Adapter weights folder contains a "image_encoder" subfolder, the image encoder will be automatically loaded and registered to the pipeline. Otherwise you can so load a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to a Stable Diffusion pipeline when you create it.
```py
from diffusers import AutoPipelineForText2Image, CLIPVisionModelWithProjection
from diffusers import AutoPipelineForText2Image
from transformers import CLIPVisionModelWithProjection
import torch
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+1 -1
View File
@@ -26,7 +26,7 @@ Before you begin, make sure you have the following libraries installed:
```py
# uncomment to install the necessary libraries in Colab
#!pip install -q diffusers transformers accelerate omegaconf invisible-watermark>=0.2.0
#!pip install -q diffusers transformers accelerate invisible-watermark>=0.2.0
```
<Tip warning={true}>
+1 -1
View File
@@ -23,7 +23,7 @@ Before you begin, make sure you have the following libraries installed:
```py
# uncomment to install the necessary libraries in Colab
#!pip install -q diffusers transformers accelerate omegaconf
#!pip install -q diffusers transformers accelerate
```
## Load model checkpoints
@@ -38,7 +38,7 @@ from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft import LoraConfig, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -58,15 +58,17 @@ from diffusers import (
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
from diffusers.utils import (
check_min_version,
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1277,7 +1279,7 @@ def main(args):
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param = param.to(dtype=torch.float32)
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
@@ -1286,22 +1288,16 @@ def main(args):
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param = param.to(dtype=torch.float32)
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
param.requires_grad = False
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
for model in models:
for param in model.parameters():
# only upcast trainable parameters (LoRA) into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
@@ -1313,14 +1309,14 @@ def main(args):
text_encoder_two_lora_layers_to_save = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
if isinstance(model, type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
elif isinstance(model, type(unwrap_model(text_encoder_one))):
if args.train_text_encoder:
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
elif isinstance(model, type(unwrap_model(text_encoder_two))):
if args.train_text_encoder:
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
@@ -1348,27 +1344,44 @@ def main(args):
while len(models) > 0:
model = models.pop()
if isinstance(model, type(accelerator.unwrap_model(unet))):
if isinstance(model, type(unwrap_model(unet))):
unet_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
if args.train_text_encoder:
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
_set_state_dict_into_text_encoder(
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
models = [unet_]
if args.train_text_encoder:
models.extend([text_encoder_one_, text_encoder_two_])
cast_training_params(models)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1383,6 +1396,13 @@ def main(args):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.extend([text_encoder_one, text_encoder_two])
cast_training_params(models, dtype=torch.float32)
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
if args.train_text_encoder:
@@ -1705,19 +1725,19 @@ def main(args):
num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
elif args.train_text_encoder_ti: # args.train_text_encoder_ti
num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
# flag used for textual inversion
pivoted = False
for epoch in range(first_epoch, args.num_train_epochs):
# if performing any kind of optimization of text_encoder params
if args.train_text_encoder or args.train_text_encoder_ti:
if epoch == num_train_epochs_text_encoder:
print("PIVOT HALFWAY", epoch)
# stopping optimization of text_encoder params
# re setting the optimizer to optimize only on unet params
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0
# this flag is used to reset the optimizer to optimize only on unet params
pivoted = True
else:
# still optimizng the text encoder
# still optimizing the text encoder
text_encoder_one.train()
text_encoder_two.train()
# set top parameter requires_grad = True for gradient checkpointing works
@@ -1727,6 +1747,12 @@ def main(args):
unet.train()
for step, batch in enumerate(train_dataloader):
if pivoted:
# stopping optimization of text_encoder params
# re setting the optimizer to optimize only on unet params
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0
with accelerator.accumulate(unet):
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
@@ -1865,8 +1891,7 @@ def main(args):
# every step, we reset the embeddings to the original embeddings.
if args.train_text_encoder_ti:
for idx, text_encoder in enumerate(text_encoders):
embedding_handler.retract_embeddings()
embedding_handler.retract_embeddings()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
+163 -3
View File
@@ -58,6 +58,9 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
| Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
| Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender_A_Video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -2989,7 +2992,7 @@ pipe = DiffusionPipeline.from_pretrained(
custom_pipeline="pipeline_animatediff_controlnet",
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
model_id, subfolder="scheduler", beta_schedule="linear", clip_sample=False, timestep_spacing="linspace", steps_offset=1
)
pipe.enable_vae_slicing()
@@ -3005,7 +3008,7 @@ result = pipe(
width=512,
height=768,
conditioning_frames=conditioning_frames,
num_inference_steps=12,
num_inference_steps=20,
).frames[0]
from diffusers.utils import export_to_gif
@@ -3029,6 +3032,79 @@ export_to_gif(result.frames[0], "result.gif")
</tr>
</table>
You can also use multiple controlnets at once!
```python
import torch
from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import DPMSolverMultistepScheduler
from PIL import Image
motion_id = "guoyww/animatediff-motion-adapter-v1-5-2"
adapter = MotionAdapter.from_pretrained(motion_id)
controlnet1 = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16)
controlnet2 = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
pipe = DiffusionPipeline.from_pretrained(
model_id,
motion_adapter=adapter,
controlnet=[controlnet1, controlnet2],
vae=vae,
custom_pipeline="pipeline_animatediff_controlnet",
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
)
pipe.enable_vae_slicing()
def load_video(file_path: str):
images = []
if file_path.startswith(('http://', 'https://')):
# If the file_path is a URL
response = requests.get(file_path)
response.raise_for_status()
content = BytesIO(response.content)
vid = imageio.get_reader(content)
else:
# Assuming it's a local file path
vid = imageio.get_reader(file_path)
for frame in vid:
pil_image = Image.fromarray(frame)
images.append(pil_image)
return images
video = load_video("dance.gif")
# You need to install it using `pip install controlnet_aux`
from controlnet_aux.processor import Processor
p1 = Processor("openpose_full")
cn1 = [p1(frame) for frame in video]
p2 = Processor("canny")
cn2 = [p2(frame) for frame in video]
prompt = "astronaut in space, dancing"
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
result = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=512,
height=768,
conditioning_frames=[cn1, cn2],
num_inference_steps=20,
)
from diffusers.utils import export_to_gif
export_to_gif(result.frames[0], "result.gif")
```
### DemoFusion
This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
@@ -3333,4 +3409,88 @@ images = pipe(
# Disable StyleAligned if you do not wish to use it anymore
pipe.disable_style_aligned()
```
```
### AnimateDiff Image-To-Video Pipeline
This pipeline adds experimental support for the image-to-video task using AnimateDiff. Refer to [this](https://github.com/huggingface/diffusers/pull/6328) PR for more examples and results.
```py
import torch
from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
from diffusers.utils import export_to_gif, load_image
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
image = load_image("snail.png")
output = pipe(
image=image,
prompt="A snail moving on the ground",
strength=0.8,
latent_interpolation_method="slerp", # can be lerp, slerp, or your own callback
)
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
### IP Adapter Face ID
IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
You have to disable PEFT BACKEND in order to load weights.
```py
import diffusers
diffusers.utils.USE_PEFT_BACKEND = False
import torch
from diffusers.utils import load_image
import cv2
import numpy as np
from diffusers import DiffusionPipeline, AutoencoderKL, DDIMScheduler
from insightface.app import FaceAnalysis
noise_scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float16)
pipeline = DiffusionPipeline.from_pretrained(
"SG161222/Realistic_Vision_V4.0_noVAE",
torch_dtype=torch.float16,
scheduler=noise_scheduler,
vae=vae,
custom_pipeline="ip_adapter_face_id"
)
pipeline.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin")
pipeline.to("cuda")
generator = torch.Generator(device="cpu").manual_seed(42)
num_images=2
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
faces = app.get(image)
image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
images = pipeline(
prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
image_embeds=image,
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704,
generator=generator
).images
for i in range(num_images):
images[i].save(f"c{i}.png")
```
File diff suppressed because it is too large Load Diff
@@ -14,7 +14,7 @@
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -26,7 +26,7 @@ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unet_motion_model import MotionAdapter
from diffusers.models.unets.unet_motion_model import MotionAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
@@ -66,7 +66,7 @@ EXAMPLE_DOC_STRING = """
... custom_pipeline="pipeline_animatediff_controlnet",
... ).to(device="cuda", dtype=torch.float16)
>>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(
... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1, beta_schedule="linear",
... )
>>> pipe.enable_vae_slicing()
@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
... height=768,
... conditioning_frames=conditioning_frames,
... num_inference_steps=12,
... ).frames[0]
... )
>>> from diffusers.utils import export_to_gif
>>> export_to_gif(result.frames[0], "result.gif")
@@ -151,7 +151,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
motion_adapter: MotionAdapter,
controlnet: Union[ControlNetModel, MultiControlNetModel],
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
@@ -166,6 +166,9 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
super().__init__()
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -488,6 +491,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
prompt,
height,
width,
num_frames,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
@@ -557,31 +561,21 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
or is_compiled
and isinstance(self.controlnet._orig_mod, ControlNetModel)
):
if isinstance(image, list):
for image_ in image:
self.check_image(image_, prompt, prompt_embeds)
else:
self.check_image(image, prompt, prompt_embeds)
if not isinstance(image, list):
raise TypeError(f"For single controlnet, `image` must be of type `list` but got {type(image)}")
if len(image) != num_frames:
raise ValueError(f"Excepted image to have length {num_frames} but got {len(image)=}")
elif (
isinstance(self.controlnet, MultiControlNetModel)
or is_compiled
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
):
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")
# When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif len(image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
)
for control_ in image:
for image_ in control_:
self.check_image(image_, prompt, prompt_embeds)
if not isinstance(image, list) or not isinstance(image[0], list):
raise TypeError(f"For multiple controlnets: `image` must be type list of lists but got {type(image)=}")
if len(image[0]) != num_frames:
raise ValueError(f"Expected length of image sublist as {num_frames} but got {len(image[0])=}")
if any(len(img) != len(image[0]) for img in image):
raise ValueError("All conditioning frame batches for multicontrolnet must be same size")
else:
assert False
@@ -913,6 +907,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
prompt=prompt,
height=height,
width=width,
num_frames=num_frames,
callback_steps=callback_steps,
negative_prompt=negative_prompt,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -1000,9 +995,7 @@ class AnimateDiffControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMix
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
cond_prepared_frames.append(prepared_frame)
conditioning_frames = cond_prepared_frames
else:
assert False
@@ -0,0 +1,989 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unet_motion_model import MotionAdapter
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import MotionAdapter, DiffusionPipeline, DDIMScheduler
>>> from diffusers.utils import export_to_gif, load_image
>>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
>>> pipe = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, custom_pipeline="pipeline_animatediff_img2video").to("cuda")
>>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
>>> image = load_image("snail.png")
>>> output = pipe(image=image, prompt="A snail moving on the ground", strength=0.8, latent_interpolation_method="slerp")
>>> frames = output.frames[0]
>>> export_to_gif(frames, "animation.gif")
```
"""
def lerp(
v0: torch.Tensor,
v1: torch.Tensor,
t: Union[float, torch.Tensor],
) -> torch.Tensor:
r"""
Linear Interpolation between two tensors.
Args:
v0 (`torch.Tensor`): First tensor.
v1 (`torch.Tensor`): Second tensor.
t: (`float` or `torch.Tensor`): Interpolation factor.
"""
t_is_float = False
input_device = v0.device
v0 = v0.cpu().numpy()
v1 = v1.cpu().numpy()
if isinstance(t, torch.Tensor):
t = t.cpu().numpy()
else:
t_is_float = True
t = np.array([t], dtype=v0.dtype)
t = t[..., None]
v0 = v0[None, ...]
v1 = v1[None, ...]
v2 = (1 - t) * v0 + t * v1
if t_is_float and v0.ndim > 1:
assert v2.shape[0] == 1
v2 = np.squeeze(v2, axis=0)
v2 = torch.from_numpy(v2).to(input_device)
return v2
def slerp(
v0: torch.Tensor,
v1: torch.Tensor,
t: Union[float, torch.Tensor],
DOT_THRESHOLD: float = 0.9995,
) -> torch.Tensor:
r"""
Spherical Linear Interpolation between two tensors.
Args:
v0 (`torch.Tensor`): First tensor.
v1 (`torch.Tensor`): Second tensor.
t: (`float` or `torch.Tensor`): Interpolation factor.
DOT_THRESHOLD (`float`):
Dot product threshold exceeding which linear interpolation will be used
because input tensors are close to parallel.
"""
t_is_float = False
input_device = v0.device
v0 = v0.cpu().numpy()
v1 = v1.cpu().numpy()
if isinstance(t, torch.Tensor):
t = t.cpu().numpy()
else:
t_is_float = True
t = np.array([t], dtype=v0.dtype)
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
if np.abs(dot) > DOT_THRESHOLD:
# v0 and v1 are close to parallel, so use linear interpolation instead
v2 = lerp(v0, v1, t)
else:
theta_0 = np.arccos(dot)
sin_theta_0 = np.sin(theta_0)
theta_t = theta_0 * t
sin_theta_t = np.sin(theta_t)
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
s1 = sin_theta_t / sin_theta_0
s0 = s0[..., None]
s1 = s1[..., None]
v0 = v0[None, ...]
v1 = v1[None, ...]
v2 = s0 * v0 + s1 * v1
if t_is_float and v0.ndim > 1:
assert v2.shape[0] == 1
v2 = np.squeeze(v2, axis=0)
v2 = torch.from_numpy(v2).to(input_device)
return v2
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
@dataclass
class AnimateDiffImgToVideoPipelineOutput(BaseOutput):
frames: Union[torch.Tensor, np.ndarray]
class AnimateDiffImgToVideoPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer (`CLIPTokenizer`):
A [`~transformers.CLIPTokenizer`] to tokenize text.
unet ([`UNet2DConditionModel`]):
A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
motion_adapter ([`MotionAdapter`]):
A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["feature_extractor", "image_encoder"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
motion_adapter: MotionAdapter,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
feature_extractor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModelWithProjection = None,
):
super().__init__()
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
motion_adapter=motion_adapter,
scheduler=scheduler,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
batch_size, channels, num_frames, height, width = latents.shape
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
image = self.vae.decode(latents).sample
video = (
image[None, :]
.reshape(
(
batch_size,
num_frames,
-1,
)
+ image.shape[2:]
)
.permute(0, 2, 1, 3, 4)
)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
video = video.float()
return video
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
latent_interpolation_method=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if latent_interpolation_method is not None:
if latent_interpolation_method not in ["lerp", "slerp"] and not isinstance(
latent_interpolation_method, FunctionType
):
raise ValueError(
"`latent_interpolation_method` must be one of `lerp`, `slerp` or a Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]"
)
def prepare_latents(
self,
image,
strength,
batch_size,
num_channels_latents,
num_frames,
height,
width,
dtype,
device,
generator,
latents=None,
latent_interpolation_method="slerp",
):
shape = (
batch_size,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
image = image.to(device=device, dtype=dtype)
if image.shape[1] == 4:
latents = image
else:
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
self.vae.to(dtype=torch.float32)
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
init_latents = init_latents.to(dtype)
init_latents = self.vae.config.scaling_factor * init_latents
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = latents * self.scheduler.init_noise_sigma
if latent_interpolation_method == "lerp":
def latent_cls(v0, v1, index):
return lerp(v0, v1, index / num_frames * (1 - strength))
elif latent_interpolation_method == "slerp":
def latent_cls(v0, v1, index):
return slerp(v0, v1, index / num_frames * (1 - strength))
else:
latent_cls = latent_interpolation_method
for i in range(num_frames):
latents[:, :, i, :, :] = latent_cls(latents[:, :, i, :, :], init_latents, i)
else:
if shape != latents.shape:
# [B, C, F, H, W]
raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
latents = latents.to(device, dtype=dtype)
return latents
@torch.no_grad()
def __call__(
self,
image: PipelineImageInput,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: int = 16,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 7.5,
strength: float = 0.8,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
latent_interpolation_method: Union[str, Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]] = "slerp",
):
r"""
The call function to the pipeline for generation.
Args:
image (`PipelineImageInput`):
The input image to condition the generation on.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated video.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated video.
num_frames (`int`, *optional*, defaults to 16):
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
amounts to 2 seconds of video.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
expense of slower inference.
strength (`float`, *optional*, defaults to 0.8):
Higher strength leads to more differences between original image and generated video.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`. Latents should be of shape
`(batch_size, num_channel, num_frames, height, width)`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`AnimateDiffImgToVideoPipelineOutput`] instead
of a plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
latent_interpolation_method (`str` or `Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]]`, *optional*):
Must be one of "lerp", "slerp" or a callable that takes in a random noisy latent, image latent and a frame index
as input and returns an initial latent for sampling.
Examples:
Returns:
[`AnimateDiffImgToVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`AnimateDiffImgToVideoPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
callback_steps=callback_steps,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
latent_interpolation_method=latent_interpolation_method,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
)
if do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Preprocess image
image = self.image_processor.preprocess(image, height=height, width=width)
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
image=image,
strength=strength,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
num_frames=num_frames,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
latent_interpolation_method=latent_interpolation_method,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if output_type == "latent":
return AnimateDiffImgToVideoPipelineOutput(frames=latents)
# 10. Post-processing
video_tensor = self.decode_latents(latents)
if output_type == "pt":
video = video_tensor
else:
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
# 11. Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return AnimateDiffImgToVideoPipelineOutput(frames=video)
@@ -8,7 +8,7 @@ import torch
from diffusers import StableDiffusionControlNetPipeline
from diffusers.models import ControlNetModel
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import logging
@@ -7,7 +7,7 @@ import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from diffusers.utils import PIL_INTERPOLATION, logging
@@ -8,7 +8,7 @@ import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import (
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
+44 -6
View File
@@ -35,7 +35,7 @@ from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
@@ -54,7 +54,13 @@ from diffusers import (
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
from diffusers.utils import (
check_min_version,
convert_state_dict_to_diffusers,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -892,10 +898,33 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
if args.train_text_encoder:
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
models = [unet_]
if args.train_text_encoder:
models.append(text_encoder_)
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models, dtype=torch.float32)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -910,6 +939,15 @@ def main(args):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [unet]
if args.train_text_encoder:
models.append(text_encoder)
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(models, dtype=torch.float32)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
@@ -55,6 +55,9 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0")
@@ -67,6 +70,57 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
def log_validation(
pipeline,
args,
accelerator,
generator,
global_step,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
val_save_dir = os.path.join(args.output_dir, "validation_images")
if not os.path.exists(val_save_dir):
os.makedirs(val_save_dir)
original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)
with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
edited_images = []
# Run inference
for val_img_idx in range(args.num_validation_images):
a_val_img = pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
edited_images.append(a_val_img)
# Save validation images
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
logger_name = "test" if is_final_validation else "validation"
tracker.log({logger_name: wandb_table})
def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
@@ -447,11 +501,6 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -1109,13 +1158,8 @@ def main():
progress_bar.set_postfix(**logs)
### BEGIN: Perform validation every `validation_epochs` steps
if global_step % args.validation_steps == 0 or global_step == 1:
if global_step % args.validation_steps == 0:
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1179,16 @@ def main():
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
# Save validation images
val_save_dir = os.path.join(args.output_dir, "validation_images")
if not os.path.exists(val_save_dir):
os.makedirs(val_save_dir)
log_validation(
pipeline,
args,
accelerator,
generator,
global_step,
is_final_validation=False,
)
original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)
with torch.autocast(
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
):
edited_images = []
for val_img_idx in range(args.num_validation_images):
a_val_img = pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
edited_images.append(a_val_img)
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"validation": wandb_table})
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
@@ -1187,7 +1203,6 @@ def main():
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())
@@ -1198,10 +1213,11 @@ def main():
tokenizer=tokenizer_1,
tokenizer_2=tokenizer_2,
vae=vae,
unet=unet,
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub:
@@ -1212,30 +1228,15 @@ def main():
ignore_patterns=["step_*", "epoch_*"],
)
if args.validation_prompt is not None:
edited_images = []
pipeline = pipeline.to(accelerator.device)
with torch.autocast(str(accelerator.device).replace(":0", "")):
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"test": wandb_table})
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
log_validation(
pipeline,
args,
accelerator,
generator,
global_step,
is_final_validation=True,
)
accelerator.end_training()
@@ -26,7 +26,7 @@ from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProc
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.lora import LoRACompatibleConv
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import (
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
@@ -36,7 +36,7 @@ from diffusers.models.unet_2d_blocks import (
UpBlock2D,
Upsample2D,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.utils import BaseOutput, logging
@@ -740,6 +740,10 @@ def main(args):
# Resize.
combined_im = train_resize(combined_im)
# Flipping.
if not args.no_flip and random.random() < 0.5:
combined_im = train_flip(combined_im)
# Cropping.
if not args.random_crop:
y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
@@ -749,11 +753,6 @@ def main(args):
y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
combined_im = crop(combined_im, y1, x1, h, w)
# Flipping.
if random.random() < 0.5:
x1 = combined_im.shape[2] - x1
combined_im = train_flip(combined_im)
crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left)
combined_im = normalize(combined_im)
@@ -6,4 +6,4 @@ torch==2.0.1
torchvision>=0.16
ftfy==6.1.1
tensorboard==2.14.0
Jinja2==3.1.2
Jinja2==3.1.3
+11 -4
View File
@@ -50,6 +50,7 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
MAX_SEQ_LENGTH = 77
@@ -926,6 +927,11 @@ def main(args):
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
@@ -935,9 +941,9 @@ def main(args):
" doing mixed precision training, copy of the weights should still be float32."
)
if accelerator.unwrap_model(t2iadapter).dtype != torch.float32:
if unwrap_model(t2iadapter).dtype != torch.float32:
raise ValueError(
f"Controlnet loaded as datatype {accelerator.unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
f"Controlnet loaded as datatype {unwrap_model(t2iadapter).dtype}. {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
@@ -1198,7 +1204,8 @@ def main(args):
encoder_hidden_states=batch["prompt_ids"],
added_cond_kwargs=batch["unet_added_conditions"],
down_block_additional_residuals=down_block_additional_residuals,
).sample
return_dict=False,
)[0]
# Denoise the latents
denoised_latents = model_pred * (-sigmas) + noisy_latents
@@ -1279,7 +1286,7 @@ def main(args):
# Create the pipeline using using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
t2iadapter = accelerator.unwrap_model(t2iadapter)
t2iadapter = unwrap_model(t2iadapter)
t2iadapter.save_pretrained(args.output_dir)
if args.push_to_hub:
+60
View File
@@ -183,6 +183,66 @@ The above command will also run inference as fine-tuning progresses and log the
* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
### Using DeepSpeed
Using DeepSpeed one can reduce the consumption of GPU memory, enabling the training of models on GPUs with smaller memory sizes. DeepSpeed is capable of offloading model parameters to the machine's memory, or it can distribute parameters, gradients, and optimizer states across multiple GPUs. This allows for the training of larger models under the same hardware configuration.
First, you need to use the `accelerate config` command to choose to use DeepSpeed, or manually use the accelerate config file to set up DeepSpeed.
Here is an example of a config file for using DeepSpeed. For more detailed explanations of the configuration, you can refer to this [link](https://huggingface.co/docs/accelerate/usage_guides/deepspeed).
```yaml
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
You need to save the mentioned configuration as an `accelerate_config.yaml` file. Then, you need to input the path of your `accelerate_config.yaml` file into the `ACCELERATE_CONFIG_FILE` parameter. This way you can use DeepSpeed to train your SDXL model in LoRA. Additionally, you can use DeepSpeed to train other SD models in this way.
```shell
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export ACCELERATE_CONFIG_FILE="your accelerate_config.yaml"
accelerate launch --config_file $ACCELERATE_CONFIG_FILE train_text_to_image_lora_sdxl.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--pretrained_vae_model_name_or_path=$VAE_NAME \
--dataset_name=$DATASET_NAME --caption_column="text" \
--resolution=1024 \
--train_batch_size=1 \
--num_train_epochs=2 \
--checkpointing_steps=2 \
--learning_rate=1e-04 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--mixed_precision="fp16" \
--max_train_steps=20 \
--validation_epochs=20 \
--seed=1234 \
--output_dir="sd-pokemon-model-lora-sdxl" \
--validation_prompt="cute dragon creature"
```
### Finetuning the text encoder and UNet
The script also allows you to finetune the `text_encoder` along with the `unet`.
@@ -652,13 +652,13 @@ def main(args):
text_encoder_two_lora_layers_to_save = None
for model in models:
if isinstance(model, type(unwrap_model(unet))):
if isinstance(unwrap_model(model), type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(unwrap_model(text_encoder_one))):
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(unwrap_model(text_encoder_two))):
elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
@@ -666,7 +666,8 @@ def main(args):
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()
StableDiffusionXLPipeline.save_lora_weights(
output_dir,
@@ -836,6 +837,9 @@ def main(args):
for image in images:
original_sizes.append((image.height, image.width))
image = train_resize(image)
if args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
@@ -843,10 +847,6 @@ def main(args):
else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
# flip
x1 = image.width - x1
image = train_flip(image)
crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left)
image = train_transforms(image)
@@ -839,6 +839,9 @@ def main(args):
for image in images:
original_sizes.append((image.height, image.width))
image = train_resize(image)
if args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
@@ -846,10 +849,6 @@ def main(args):
else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
# flip
x1 = image.width - x1
image = train_flip(image)
crop_top_left = (y1, x1)
crop_top_lefts.append(crop_top_left)
image = train_transforms(image)
+7 -7
View File
@@ -1,13 +1,13 @@
import argparse
import OmegaConf
import torch
import yaml
from diffusers import DDIMScheduler, LDMPipeline, UNetLDMModel, VQModel
def convert_ldm_original(checkpoint_path, config_path, output_path):
config = OmegaConf.load(config_path)
config = yaml.safe_load(config_path)
state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]
keys = list(state_dict.keys())
@@ -25,8 +25,8 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = state_dict[key]
vqvae_init_args = config.model.params.first_stage_config.params
unet_init_args = config.model.params.unet_config.params
vqvae_init_args = config["model"]["params"]["first_stage_config"]["params"]
unet_init_args = config["model"]["params"]["unet_config"]["params"]
vqvae = VQModel(**vqvae_init_args).eval()
vqvae.load_state_dict(first_stage_dict)
@@ -35,10 +35,10 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
unet.load_state_dict(unet_state_dict)
noise_scheduler = DDIMScheduler(
timesteps=config.model.params.timesteps,
timesteps=config["model"]["params"]["timesteps"],
beta_schedule="scaled_linear",
beta_start=config.model.params.linear_start,
beta_end=config.model.params.linear_end,
beta_start=config["model"]["params"]["linear_start"],
beta_end=config["model"]["params"]["linear_end"],
clip_sample=False,
)
+1 -1
View File
@@ -10,7 +10,7 @@ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import VQModel
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.uvit_2d import UVit2DModel
from diffusers.models.unets.uvit_2d import UVit2DModel
from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline
from diffusers.schedulers import AmusedScheduler
+1 -1
View File
@@ -14,7 +14,7 @@ from tqdm import tqdm
from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
from diffusers.models.autoencoders.vae import Encoder
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
args = ArgumentParser()
+23 -29
View File
@@ -2,6 +2,7 @@ import argparse
import re
import torch
import yaml
from transformers import (
CLIPProcessor,
CLIPTextModel,
@@ -28,8 +29,6 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
textenc_conversion_map,
textenc_pattern,
)
from diffusers.utils import is_omegaconf_available
from diffusers.utils.import_utils import BACKENDS_MAPPING
def convert_open_clip_checkpoint(checkpoint):
@@ -370,52 +369,52 @@ def convert_gligen_unet_checkpoint(checkpoint, config, path=None, extract_ema=Fa
def create_vae_config(original_config, image_size: int):
vae_params = original_config.autoencoder.params.ddconfig
_ = original_config.autoencoder.params.embed_dim
vae_params = original_config["autoencoder"]["params"]["ddconfig"]
_ = original_config["autoencoder"]["params"]["embed_dim"]
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = {
"sample_size": image_size,
"in_channels": vae_params.in_channels,
"out_channels": vae_params.out_ch,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels,
"layers_per_block": vae_params.num_res_blocks,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
}
return config
def create_unet_config(original_config, image_size: int, attention_type):
unet_params = original_config.model.params
vae_params = original_config.autoencoder.params.ddconfig
unet_params = original_config["model"]["params"]
vae_params = original_config["autoencoder"]["params"]["ddconfig"]
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
if head_dim is None:
@@ -423,11 +422,11 @@ def create_unet_config(original_config, image_size: int, attention_type):
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"in_channels": unet_params["in_channels"],
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"cross_attention_dim": unet_params.context_dim,
"layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": unet_params["context_dim"],
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
"attention_type": attention_type,
@@ -445,11 +444,6 @@ def convert_gligen_to_diffusers(
num_in_channels: int = None,
device: str = None,
):
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -461,14 +455,14 @@ def convert_gligen_to_diffusers(
else:
print("global_step key not found in model")
original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)
if num_in_channels is not None:
original_config["model"]["params"]["in_channels"] = num_in_channels
num_train_timesteps = original_config.diffusion.params.timesteps
beta_start = original_config.diffusion.params.linear_start
beta_end = original_config.diffusion.params.linear_end
num_train_timesteps = original_config["diffusion"]["params"]["timesteps"]
beta_start = original_config["diffusion"]["params"]["linear_start"]
beta_end = original_config["diffusion"]["params"]["linear_end"]
scheduler = DDIMScheduler(
beta_end=beta_end,
+46 -53
View File
@@ -4,6 +4,7 @@ import os
import numpy as np
import torch
import yaml
from torch.nn import functional as F
from transformers import CLIPConfig, CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer
@@ -11,14 +12,6 @@ from diffusers import DDPMScheduler, IFPipeline, IFSuperResolutionPipeline, UNet
from diffusers.pipelines.deepfloyd_if.safety_checker import IFSafetyChecker
try:
from omegaconf import OmegaConf
except ImportError:
raise ImportError(
"OmegaConf is required to convert the IF checkpoints. Please install it with `pip install" " OmegaConf`."
)
def parse_args():
parser = argparse.ArgumentParser()
@@ -143,8 +136,8 @@ def convert_super_res_pipeline(tokenizer, text_encoder, feature_extractor, safet
def get_stage_1_unet(unet_config, unet_checkpoint_path):
original_unet_config = OmegaConf.load(unet_config)
original_unet_config = original_unet_config.params
original_unet_config = yaml.safe_load(unet_config)
original_unet_config = original_unet_config["params"]
unet_diffusers_config = create_unet_diffusers_config(original_unet_config)
@@ -215,11 +208,11 @@ def convert_safety_checker(p_head_path, w_head_path):
def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
attention_resolutions = parse_list(original_unet_config.attention_resolutions)
attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions]
attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]
channel_mult = parse_list(original_unet_config.channel_mult)
block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult]
channel_mult = parse_list(original_unet_config["channel_mult"])
block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
down_block_types = []
resolution = 1
@@ -227,7 +220,7 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
for i in range(len(block_out_channels)):
if resolution in attention_resolutions:
block_type = "SimpleCrossAttnDownBlock2D"
elif original_unet_config.resblock_updown:
elif original_unet_config["resblock_updown"]:
block_type = "ResnetDownsampleBlock2D"
else:
block_type = "DownBlock2D"
@@ -241,17 +234,17 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
for i in range(len(block_out_channels)):
if resolution in attention_resolutions:
block_type = "SimpleCrossAttnUpBlock2D"
elif original_unet_config.resblock_updown:
elif original_unet_config["resblock_updown"]:
block_type = "ResnetUpsampleBlock2D"
else:
block_type = "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
head_dim = original_unet_config.num_head_channels
head_dim = original_unet_config["num_head_channels"]
use_linear_projection = (
original_unet_config.use_linear_in_transformer
original_unet_config["use_linear_in_transformer"]
if "use_linear_in_transformer" in original_unet_config
else False
)
@@ -264,27 +257,27 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
if class_embed_type is None:
if "num_classes" in original_unet_config:
if original_unet_config.num_classes == "sequential":
if original_unet_config["num_classes"] == "sequential":
class_embed_type = "projection"
assert "adm_in_channels" in original_unet_config
projection_class_embeddings_input_dim = original_unet_config.adm_in_channels
projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
else:
raise NotImplementedError(
f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}"
f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
)
config = {
"sample_size": original_unet_config.image_size,
"in_channels": original_unet_config.in_channels,
"sample_size": original_unet_config["image_size"],
"in_channels": original_unet_config["in_channels"],
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": original_unet_config.num_res_blocks,
"cross_attention_dim": original_unet_config.encoder_channels,
"layers_per_block": original_unet_config["num_res_blocks"],
"cross_attention_dim": original_unet_config["encoder_channels"],
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
"class_embed_type": class_embed_type,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
"out_channels": original_unet_config.out_channels,
"out_channels": original_unet_config["out_channels"],
"up_block_types": tuple(up_block_types),
"upcast_attention": False, # TODO: guessing
"cross_attention_norm": "group_norm",
@@ -293,11 +286,11 @@ def create_unet_diffusers_config(original_unet_config, class_embed_type=None):
"act_fn": "gelu",
}
if original_unet_config.use_scale_shift_norm:
if original_unet_config["use_scale_shift_norm"]:
config["resnet_time_scale_shift"] = "scale_shift"
if "encoder_dim" in original_unet_config:
config["encoder_hid_dim"] = original_unet_config.encoder_dim
config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
return config
@@ -725,15 +718,15 @@ def parse_list(value):
def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_size=None):
orig_path = unet_checkpoint_path
original_unet_config = OmegaConf.load(os.path.join(orig_path, "config.yml"))
original_unet_config = original_unet_config.params
original_unet_config = yaml.safe_load(os.path.join(orig_path, "config.yml"))
original_unet_config = original_unet_config["params"]
unet_diffusers_config = superres_create_unet_diffusers_config(original_unet_config)
unet_diffusers_config["time_embedding_dim"] = original_unet_config.model_channels * int(
original_unet_config.channel_mult.split(",")[-1]
unet_diffusers_config["time_embedding_dim"] = original_unet_config["model_channels"] * int(
original_unet_config["channel_mult"].split(",")[-1]
)
if original_unet_config.encoder_dim != original_unet_config.encoder_channels:
unet_diffusers_config["encoder_hid_dim"] = original_unet_config.encoder_dim
if original_unet_config["encoder_dim"] != original_unet_config["encoder_channels"]:
unet_diffusers_config["encoder_hid_dim"] = original_unet_config["encoder_dim"]
unet_diffusers_config["class_embed_type"] = "timestep"
unet_diffusers_config["addition_embed_type"] = "text"
@@ -742,16 +735,16 @@ def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_siz
unet_diffusers_config["resnet_out_scale_factor"] = 1 / 0.7071
unet_diffusers_config["mid_block_scale_factor"] = 1 / 0.7071
unet_diffusers_config["only_cross_attention"] = (
bool(original_unet_config.disable_self_attentions)
bool(original_unet_config["disable_self_attentions"])
if (
"disable_self_attentions" in original_unet_config
and isinstance(original_unet_config.disable_self_attentions, int)
and isinstance(original_unet_config["disable_self_attentions"], int)
)
else True
)
if sample_size is None:
unet_diffusers_config["sample_size"] = original_unet_config.image_size
unet_diffusers_config["sample_size"] = original_unet_config["image_size"]
else:
# The second upscaler unet's sample size is incorrectly specified
# in the config and is instead hardcoded in source
@@ -783,11 +776,11 @@ def get_super_res_unet(unet_checkpoint_path, verify_param_count=True, sample_siz
def superres_create_unet_diffusers_config(original_unet_config):
attention_resolutions = parse_list(original_unet_config.attention_resolutions)
attention_resolutions = [original_unet_config.image_size // int(res) for res in attention_resolutions]
attention_resolutions = parse_list(original_unet_config["attention_resolutions"])
attention_resolutions = [original_unet_config["image_size"] // int(res) for res in attention_resolutions]
channel_mult = parse_list(original_unet_config.channel_mult)
block_out_channels = [original_unet_config.model_channels * mult for mult in channel_mult]
channel_mult = parse_list(original_unet_config["channel_mult"])
block_out_channels = [original_unet_config["model_channels"] * mult for mult in channel_mult]
down_block_types = []
resolution = 1
@@ -795,7 +788,7 @@ def superres_create_unet_diffusers_config(original_unet_config):
for i in range(len(block_out_channels)):
if resolution in attention_resolutions:
block_type = "SimpleCrossAttnDownBlock2D"
elif original_unet_config.resblock_updown:
elif original_unet_config["resblock_updown"]:
block_type = "ResnetDownsampleBlock2D"
else:
block_type = "DownBlock2D"
@@ -809,16 +802,16 @@ def superres_create_unet_diffusers_config(original_unet_config):
for i in range(len(block_out_channels)):
if resolution in attention_resolutions:
block_type = "SimpleCrossAttnUpBlock2D"
elif original_unet_config.resblock_updown:
elif original_unet_config["resblock_updown"]:
block_type = "ResnetUpsampleBlock2D"
else:
block_type = "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
head_dim = original_unet_config.num_head_channels
head_dim = original_unet_config["num_head_channels"]
use_linear_projection = (
original_unet_config.use_linear_in_transformer
original_unet_config["use_linear_in_transformer"]
if "use_linear_in_transformer" in original_unet_config
else False
)
@@ -831,26 +824,26 @@ def superres_create_unet_diffusers_config(original_unet_config):
projection_class_embeddings_input_dim = None
if "num_classes" in original_unet_config:
if original_unet_config.num_classes == "sequential":
if original_unet_config["num_classes"] == "sequential":
class_embed_type = "projection"
assert "adm_in_channels" in original_unet_config
projection_class_embeddings_input_dim = original_unet_config.adm_in_channels
projection_class_embeddings_input_dim = original_unet_config["adm_in_channels"]
else:
raise NotImplementedError(
f"Unknown conditional unet num_classes config: {original_unet_config.num_classes}"
f"Unknown conditional unet num_classes config: {original_unet_config['num_classes']}"
)
config = {
"in_channels": original_unet_config.in_channels,
"in_channels": original_unet_config["in_channels"],
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": tuple(original_unet_config.num_res_blocks),
"cross_attention_dim": original_unet_config.encoder_channels,
"layers_per_block": tuple(original_unet_config["num_res_blocks"]),
"cross_attention_dim": original_unet_config["encoder_channels"],
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
"class_embed_type": class_embed_type,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
"out_channels": original_unet_config.out_channels,
"out_channels": original_unet_config["out_channels"],
"up_block_types": tuple(up_block_types),
"upcast_attention": False, # TODO: guessing
"cross_attention_norm": "group_norm",
@@ -858,7 +851,7 @@ def superres_create_unet_diffusers_config(original_unet_config):
"act_fn": "gelu",
}
if original_unet_config.use_scale_shift_norm:
if original_unet_config["use_scale_shift_norm"]:
config["resnet_time_scale_shift"] = "scale_shift"
return config
@@ -19,6 +19,7 @@ import re
from typing import List, Union
import torch
import yaml
from transformers import (
AutoFeatureExtractor,
AutoTokenizer,
@@ -45,7 +46,7 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import is_omegaconf_available, is_safetensors_available
from diffusers.utils import is_safetensors_available
from diffusers.utils.import_utils import BACKENDS_MAPPING
@@ -212,41 +213,41 @@ def create_unet_diffusers_config(original_config, image_size: int):
"""
Creates a UNet config for diffusers based on the config of the original AudioLDM2 model.
"""
unet_params = original_config.model.params.unet_config.params
vae_params = original_config.model.params.first_stage_config.params.ddconfig
unet_params = original_config["model"]["params"]["unet_config"]["params"]
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
cross_attention_dim = list(unet_params.context_dim) if "context_dim" in unet_params else block_out_channels
cross_attention_dim = list(unet_params["context_dim"]) if "context_dim" in unet_params else block_out_channels
if len(cross_attention_dim) > 1:
# require two or more cross-attention layers per-block, each of different dimension
cross_attention_dim = [cross_attention_dim for _ in range(len(block_out_channels))]
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"out_channels": unet_params.out_channels,
"in_channels": unet_params["in_channels"],
"out_channels": unet_params["out_channels"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"transformer_layers_per_block": unet_params.transformer_depth,
"layers_per_block": unet_params["num_res_blocks"],
"transformer_layers_per_block": unet_params["transformer_depth"],
"cross_attention_dim": tuple(cross_attention_dim),
}
@@ -259,24 +260,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
Creates a VAE config for diffusers based on the config of the original AudioLDM2 model. Compared to the original
Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
"""
vae_params = original_config.model.params.first_stage_config.params.ddconfig
_ = original_config.model.params.first_stage_config.params.embed_dim
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
config = {
"sample_size": image_size,
"in_channels": vae_params.in_channels,
"out_channels": vae_params.out_ch,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels,
"layers_per_block": vae_params.num_res_blocks,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
"scaling_factor": float(scaling_factor),
}
return config
@@ -285,9 +286,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
def create_diffusers_schedular(original_config):
schedular = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps,
beta_start=original_config.model.params.linear_start,
beta_end=original_config.model.params.linear_end,
num_train_timesteps=original_config["model"]["params"]["timesteps"],
beta_start=original_config["model"]["params"]["linear_start"],
beta_end=original_config["model"]["params"]["linear_end"],
beta_schedule="scaled_linear",
)
return schedular
@@ -692,17 +693,17 @@ def create_transformers_vocoder_config(original_config):
"""
Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
"""
vocoder_params = original_config.model.params.vocoder_config.params
vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
config = {
"model_in_dim": vocoder_params.num_mels,
"sampling_rate": vocoder_params.sampling_rate,
"upsample_initial_channel": vocoder_params.upsample_initial_channel,
"upsample_rates": list(vocoder_params.upsample_rates),
"upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes),
"resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes),
"model_in_dim": vocoder_params["num_mels"],
"sampling_rate": vocoder_params["sampling_rate"],
"upsample_initial_channel": vocoder_params["upsample_initial_channel"],
"upsample_rates": list(vocoder_params["upsample_rates"]),
"upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
"resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
"resblock_dilation_sizes": [
list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
],
"normalize_before": False,
}
@@ -876,11 +877,6 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
return: An AudioLDM2Pipeline object representing the passed-in `.ckpt`/`.safetensors` file.
"""
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
if from_safetensors:
if not is_safetensors_available():
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
@@ -903,9 +899,8 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
if original_config_file is None:
original_config = DEFAULT_CONFIG
original_config = OmegaConf.create(original_config)
else:
original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)
if image_size is not None:
original_config["model"]["params"]["unet_config"]["params"]["image_size"] = image_size
@@ -926,9 +921,9 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
if prediction_type is None:
prediction_type = "epsilon"
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
num_train_timesteps = original_config["model"]["params"]["timesteps"]
beta_start = original_config["model"]["params"]["linear_start"]
beta_end = original_config["model"]["params"]["linear_end"]
scheduler = DDIMScheduler(
beta_end=beta_end,
@@ -1026,9 +1021,9 @@ def load_pipeline_from_original_AudioLDM2_ckpt(
# Convert the GPT2 encoder model: AudioLDM2 uses the same configuration as the original GPT2 base model
gpt2_config = GPT2Config.from_pretrained("gpt2")
gpt2_model = GPT2Model(gpt2_config)
gpt2_model.config.max_new_tokens = (
original_config.model.params.cond_stage_config.crossattn_audiomae_generated.params.sequence_gen_length
)
gpt2_model.config.max_new_tokens = original_config["model"]["params"]["cond_stage_config"][
"crossattn_audiomae_generated"
]["params"]["sequence_gen_length"]
converted_gpt2_checkpoint = extract_sub_model(checkpoint, key_prefix="cond_stage_models.0.model.")
gpt2_model.load_state_dict(converted_gpt2_checkpoint)
@@ -18,6 +18,7 @@ import argparse
import re
import torch
import yaml
from transformers import (
AutoTokenizer,
ClapTextConfig,
@@ -38,8 +39,6 @@ from diffusers import (
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import is_omegaconf_available
from diffusers.utils.import_utils import BACKENDS_MAPPING
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
@@ -215,45 +214,45 @@ def create_unet_diffusers_config(original_config, image_size: int):
"""
Creates a UNet config for diffusers based on the config of the original AudioLDM model.
"""
unet_params = original_config.model.params.unet_config.params
vae_params = original_config.model.params.first_stage_config.params.ddconfig
unet_params = original_config["model"]["params"]["unet_config"]["params"]
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
cross_attention_dim = (
unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels
unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels
)
class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None
projection_class_embeddings_input_dim = (
unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None
unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None
)
class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None
class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"out_channels": unet_params.out_channels,
"in_channels": unet_params["in_channels"],
"out_channels": unet_params["out_channels"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": cross_attention_dim,
"class_embed_type": class_embed_type,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
@@ -269,24 +268,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
Creates a VAE config for diffusers based on the config of the original AudioLDM model. Compared to the original
Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
"""
vae_params = original_config.model.params.first_stage_config.params.ddconfig
_ = original_config.model.params.first_stage_config.params.embed_dim
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
config = {
"sample_size": image_size,
"in_channels": vae_params.in_channels,
"out_channels": vae_params.out_ch,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels,
"layers_per_block": vae_params.num_res_blocks,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
"scaling_factor": float(scaling_factor),
}
return config
@@ -295,9 +294,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
def create_diffusers_schedular(original_config):
schedular = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps,
beta_start=original_config.model.params.linear_start,
beta_end=original_config.model.params.linear_end,
num_train_timesteps=original_config["model"]["params"]["timesteps"],
beta_start=original_config["model"]["params"]["linear_start"],
beta_end=original_config["model"]["params"]["linear_end"],
beta_schedule="scaled_linear",
)
return schedular
@@ -668,17 +667,17 @@ def create_transformers_vocoder_config(original_config):
"""
Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
"""
vocoder_params = original_config.model.params.vocoder_config.params
vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
config = {
"model_in_dim": vocoder_params.num_mels,
"sampling_rate": vocoder_params.sampling_rate,
"upsample_initial_channel": vocoder_params.upsample_initial_channel,
"upsample_rates": list(vocoder_params.upsample_rates),
"upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes),
"resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes),
"model_in_dim": vocoder_params["num_mels"],
"sampling_rate": vocoder_params["sampling_rate"],
"upsample_initial_channel": vocoder_params["upsample_initial_channel"],
"upsample_rates": list(vocoder_params["upsample_rates"]),
"upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
"resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
"resblock_dilation_sizes": [
list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
],
"normalize_before": False,
}
@@ -818,11 +817,6 @@ def load_pipeline_from_original_audioldm_ckpt(
return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
"""
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
if from_safetensors:
from safetensors import safe_open
@@ -842,9 +836,8 @@ def load_pipeline_from_original_audioldm_ckpt(
if original_config_file is None:
original_config = DEFAULT_CONFIG
original_config = OmegaConf.create(original_config)
else:
original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
@@ -868,9 +861,9 @@ def load_pipeline_from_original_audioldm_ckpt(
if image_size is None:
image_size = 512
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
num_train_timesteps = original_config["model"]["params"]["timesteps"]
beta_start = original_config["model"]["params"]["linear_start"]
beta_end = original_config["model"]["params"]["linear_end"]
scheduler = DDIMScheduler(
beta_end=beta_end,
@@ -18,6 +18,7 @@ import argparse
import re
import torch
import yaml
from transformers import (
AutoFeatureExtractor,
AutoTokenizer,
@@ -39,8 +40,6 @@ from diffusers import (
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import is_omegaconf_available
from diffusers.utils.import_utils import BACKENDS_MAPPING
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
@@ -212,45 +211,45 @@ def create_unet_diffusers_config(original_config, image_size: int):
"""
Creates a UNet config for diffusers based on the config of the original MusicLDM model.
"""
unet_params = original_config.model.params.unet_config.params
vae_params = original_config.model.params.first_stage_config.params.ddconfig
unet_params = original_config["model"]["params"]["unet_config"]["params"]
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
cross_attention_dim = (
unet_params.cross_attention_dim if "cross_attention_dim" in unet_params else block_out_channels
unet_params["cross_attention_dim"] if "cross_attention_dim" in unet_params else block_out_channels
)
class_embed_type = "simple_projection" if "extra_film_condition_dim" in unet_params else None
projection_class_embeddings_input_dim = (
unet_params.extra_film_condition_dim if "extra_film_condition_dim" in unet_params else None
unet_params["extra_film_condition_dim"] if "extra_film_condition_dim" in unet_params else None
)
class_embeddings_concat = unet_params.extra_film_use_concat if "extra_film_use_concat" in unet_params else None
class_embeddings_concat = unet_params["extra_film_use_concat"] if "extra_film_use_concat" in unet_params else None
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"out_channels": unet_params.out_channels,
"in_channels": unet_params["in_channels"],
"out_channels": unet_params["out_channels"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": cross_attention_dim,
"class_embed_type": class_embed_type,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
@@ -266,24 +265,24 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
Creates a VAE config for diffusers based on the config of the original MusicLDM model. Compared to the original
Stable Diffusion conversion, this function passes a *learnt* VAE scaling factor to the diffusers VAE.
"""
vae_params = original_config.model.params.first_stage_config.params.ddconfig
_ = original_config.model.params.first_stage_config.params.embed_dim
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config.model.params else 0.18215
scaling_factor = checkpoint["scale_factor"] if "scale_by_std" in original_config["model"]["params"] else 0.18215
config = {
"sample_size": image_size,
"in_channels": vae_params.in_channels,
"out_channels": vae_params.out_ch,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels,
"layers_per_block": vae_params.num_res_blocks,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
"scaling_factor": float(scaling_factor),
}
return config
@@ -292,9 +291,9 @@ def create_vae_diffusers_config(original_config, checkpoint, image_size: int):
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_diffusers_schedular
def create_diffusers_schedular(original_config):
schedular = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps,
beta_start=original_config.model.params.linear_start,
beta_end=original_config.model.params.linear_end,
num_train_timesteps=original_config["model"]["params"]["timesteps"],
beta_start=original_config["model"]["params"]["linear_start"],
beta_end=original_config["model"]["params"]["linear_end"],
beta_schedule="scaled_linear",
)
return schedular
@@ -674,17 +673,17 @@ def create_transformers_vocoder_config(original_config):
"""
Creates a config for transformers SpeechT5HifiGan based on the config of the vocoder model.
"""
vocoder_params = original_config.model.params.vocoder_config.params
vocoder_params = original_config["model"]["params"]["vocoder_config"]["params"]
config = {
"model_in_dim": vocoder_params.num_mels,
"sampling_rate": vocoder_params.sampling_rate,
"upsample_initial_channel": vocoder_params.upsample_initial_channel,
"upsample_rates": list(vocoder_params.upsample_rates),
"upsample_kernel_sizes": list(vocoder_params.upsample_kernel_sizes),
"resblock_kernel_sizes": list(vocoder_params.resblock_kernel_sizes),
"model_in_dim": vocoder_params["num_mels"],
"sampling_rate": vocoder_params["sampling_rate"],
"upsample_initial_channel": vocoder_params["upsample_initial_channel"],
"upsample_rates": list(vocoder_params["upsample_rates"]),
"upsample_kernel_sizes": list(vocoder_params["upsample_kernel_sizes"]),
"resblock_kernel_sizes": list(vocoder_params["resblock_kernel_sizes"]),
"resblock_dilation_sizes": [
list(resblock_dilation) for resblock_dilation in vocoder_params.resblock_dilation_sizes
list(resblock_dilation) for resblock_dilation in vocoder_params["resblock_dilation_sizes"]
],
"normalize_before": False,
}
@@ -823,12 +822,6 @@ def load_pipeline_from_original_MusicLDM_ckpt(
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
return: An MusicLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
"""
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
if from_safetensors:
from safetensors import safe_open
@@ -848,9 +841,8 @@ def load_pipeline_from_original_MusicLDM_ckpt(
if original_config_file is None:
original_config = DEFAULT_CONFIG
original_config = OmegaConf.create(original_config)
else:
original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
@@ -874,9 +866,9 @@ def load_pipeline_from_original_MusicLDM_ckpt(
if image_size is None:
image_size = 512
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
num_train_timesteps = original_config["model"]["params"]["timesteps"]
beta_start = original_config["model"]["params"]["linear_start"]
beta_end = original_config["model"]["params"]["linear_end"]
scheduler = DDIMScheduler(
beta_end=beta_end,
+2 -2
View File
@@ -3,7 +3,7 @@ import io
import requests
import torch
from omegaconf import OmegaConf
import yaml
from diffusers import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
@@ -126,7 +126,7 @@ def vae_pt_to_vae_diffuser(
)
io_obj = io.BytesIO(r.content)
original_config = OmegaConf.load(io_obj)
original_config = yaml.safe_load(io_obj)
image_size = 512
device = "cuda" if torch.cuda.is_available() else "cpu"
if checkpoint_path.endswith("safetensors"):
+39 -48
View File
@@ -45,51 +45,45 @@ from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionSchedu
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings
try:
from omegaconf import OmegaConf
except ImportError:
raise ImportError(
"OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
" OmegaConf`."
)
# vqvae model
PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]
def vqvae_model_from_original_config(original_config):
assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."
assert (
original_config["target"] in PORTED_VQVAES
), f"{original_config['target']} has not yet been ported to diffusers."
original_config = original_config.params
original_config = original_config["params"]
original_encoder_config = original_config.encoder_config.params
original_decoder_config = original_config.decoder_config.params
original_encoder_config = original_config["encoder_config"]["params"]
original_decoder_config = original_config["decoder_config"]["params"]
in_channels = original_encoder_config.in_channels
out_channels = original_decoder_config.out_ch
in_channels = original_encoder_config["in_channels"]
out_channels = original_decoder_config["out_ch"]
down_block_types = get_down_block_types(original_encoder_config)
up_block_types = get_up_block_types(original_decoder_config)
assert original_encoder_config.ch == original_decoder_config.ch
assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
assert original_encoder_config["ch"] == original_decoder_config["ch"]
assert original_encoder_config["ch_mult"] == original_decoder_config["ch_mult"]
block_out_channels = tuple(
[original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
[original_encoder_config["ch"] * a_ch_mult for a_ch_mult in original_encoder_config["ch_mult"]]
)
assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
layers_per_block = original_encoder_config.num_res_blocks
assert original_encoder_config["num_res_blocks"] == original_decoder_config["num_res_blocks"]
layers_per_block = original_encoder_config["num_res_blocks"]
assert original_encoder_config.z_channels == original_decoder_config.z_channels
latent_channels = original_encoder_config.z_channels
assert original_encoder_config["z_channels"] == original_decoder_config["z_channels"]
latent_channels = original_encoder_config["z_channels"]
num_vq_embeddings = original_config.n_embed
num_vq_embeddings = original_config["n_embed"]
# Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion
norm_num_groups = 32
e_dim = original_config.embed_dim
e_dim = original_config["embed_dim"]
model = VQModel(
in_channels=in_channels,
@@ -108,9 +102,9 @@ def vqvae_model_from_original_config(original_config):
def get_down_block_types(original_encoder_config):
attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
num_resolutions = len(original_encoder_config.ch_mult)
resolution = coerce_resolution(original_encoder_config.resolution)
attn_resolutions = coerce_attn_resolutions(original_encoder_config["attn_resolutions"])
num_resolutions = len(original_encoder_config["ch_mult"])
resolution = coerce_resolution(original_encoder_config["resolution"])
curr_res = resolution
down_block_types = []
@@ -129,9 +123,9 @@ def get_down_block_types(original_encoder_config):
def get_up_block_types(original_decoder_config):
attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
num_resolutions = len(original_decoder_config.ch_mult)
resolution = coerce_resolution(original_decoder_config.resolution)
attn_resolutions = coerce_attn_resolutions(original_decoder_config["attn_resolutions"])
num_resolutions = len(original_decoder_config["ch_mult"])
resolution = coerce_resolution(original_decoder_config["resolution"])
curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
up_block_types = []
@@ -150,7 +144,7 @@ def get_up_block_types(original_decoder_config):
def coerce_attn_resolutions(attn_resolutions):
attn_resolutions = OmegaConf.to_object(attn_resolutions)
attn_resolutions = list(attn_resolutions)
attn_resolutions_ = []
for ar in attn_resolutions:
if isinstance(ar, (list, tuple)):
@@ -161,7 +155,6 @@ def coerce_attn_resolutions(attn_resolutions):
def coerce_resolution(resolution):
resolution = OmegaConf.to_object(resolution)
if isinstance(resolution, int):
resolution = [resolution, resolution] # H, W
elif isinstance(resolution, (tuple, list)):
@@ -472,18 +465,18 @@ def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config
):
assert (
original_diffusion_config.target in PORTED_DIFFUSIONS
), f"{original_diffusion_config.target} has not yet been ported to diffusers."
original_diffusion_config["target"] in PORTED_DIFFUSIONS
), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
assert (
original_transformer_config.target in PORTED_TRANSFORMERS
), f"{original_transformer_config.target} has not yet been ported to diffusers."
original_transformer_config["target"] in PORTED_TRANSFORMERS
), f"{original_transformer_config['target']} has not yet been ported to diffusers."
assert (
original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
), f"{original_content_embedding_config.target} has not yet been ported to diffusers."
original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
original_diffusion_config = original_diffusion_config.params
original_transformer_config = original_transformer_config.params
original_content_embedding_config = original_content_embedding_config.params
original_diffusion_config = original_diffusion_config["params"]
original_transformer_config = original_transformer_config["params"]
original_content_embedding_config = original_content_embedding_config["params"]
inner_dim = original_transformer_config["n_embd"]
@@ -689,13 +682,11 @@ def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_fee
def read_config_file(filename):
# The yaml file contains annotations that certain values should
# loaded as tuples. By default, OmegaConf will panic when reading
# these. Instead, we can manually read the yaml with the FullLoader and then
# construct the OmegaConf object.
# loaded as tuples.
with open(filename) as f:
original_config = yaml.load(f, FullLoader)
return OmegaConf.create(original_config)
return original_config
# We take separate arguments for the vqvae because the ITHQ vqvae config file
@@ -792,9 +783,9 @@ if __name__ == "__main__":
original_config = read_config_file(args.original_config_file).model
diffusion_config = original_config.params.diffusion_config
transformer_config = original_config.params.diffusion_config.params.transformer_config
content_embedding_config = original_config.params.diffusion_config.params.content_emb_config
diffusion_config = original_config["params"]["diffusion_config"]
transformer_config = original_config["params"]["diffusion_config"]["params"]["transformer_config"]
content_embedding_config = original_config["params"]["diffusion_config"]["params"]["content_emb_config"]
pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)
@@ -831,7 +822,7 @@ if __name__ == "__main__":
# The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
# model, so we pull them off the checkpoint before the checkpoint is deleted.
learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf
learnable_classifier_free_sampling_embeddings = diffusion_config["params"].learnable_cf
if learnable_classifier_free_sampling_embeddings:
learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
+47 -43
View File
@@ -14,6 +14,7 @@ $ python convert_zero123_to_diffusers.py \
import argparse
import torch
import yaml
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from pipeline_zero1to3 import CCProjection, Zero1to3StableDiffusionPipeline
@@ -38,51 +39,54 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
Creates a config for the diffusers based on the config of the LDM model.
"""
if controlnet:
unet_params = original_config.model.params.control_stage_config.params
unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
else:
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
unet_params = original_config.model.params.unet_config.params
if (
"unet_config" in original_config["model"]["params"]
and original_config["model"]["params"]["unet_config"] is not None
):
unet_params = original_config["model"]["params"]["unet_config"]["params"]
else:
unet_params = original_config.model.params.network_config.params
unet_params = original_config["model"]["params"]["network_config"]["params"]
vae_params = original_config.model.params.first_stage_config.params.ddconfig
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
if unet_params.transformer_depth is not None:
if unet_params["transformer_depth"] is not None:
transformer_layers_per_block = (
unet_params.transformer_depth
if isinstance(unet_params.transformer_depth, int)
else list(unet_params.transformer_depth)
unet_params["transformer_depth"]
if isinstance(unet_params["transformer_depth"], int)
else list(unet_params["transformer_depth"])
)
else:
transformer_layers_per_block = 1
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
class_embed_type = None
addition_embed_type = None
@@ -90,13 +94,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
projection_class_embeddings_input_dim = None
context_dim = None
if unet_params.context_dim is not None:
if unet_params["context_dim"] is not None:
context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
unet_params["context_dim"]
if isinstance(unet_params["context_dim"], int)
else unet_params["context_dim"][0]
)
if "num_classes" in unet_params:
if unet_params.num_classes == "sequential":
if unet_params["num_classes"] == "sequential":
if context_dim in [2048, 1280]:
# SDXL
addition_embed_type = "text_time"
@@ -104,16 +110,16 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
else:
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
else:
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params["num_classes"]}")
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"in_channels": unet_params["in_channels"],
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": context_dim,
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
@@ -125,9 +131,9 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
}
if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
config["conditioning_channels"] = unet_params["hint_channels"]
else:
config["out_channels"] = unet_params.out_channels
config["out_channels"] = unet_params["out_channels"]
config["up_block_types"] = tuple(up_block_types)
return config
@@ -487,22 +493,22 @@ def create_vae_diffusers_config(original_config, image_size: int):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
vae_params = original_config.model.params.first_stage_config.params.ddconfig
_ = original_config.model.params.first_stage_config.params.embed_dim
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = {
"sample_size": image_size,
"in_channels": vae_params.in_channels,
"out_channels": vae_params.out_ch,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels,
"layers_per_block": vae_params.num_res_blocks,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
}
return config
@@ -679,18 +685,16 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex
del ckpt
torch.cuda.empty_cache()
from omegaconf import OmegaConf
original_config = OmegaConf.load(original_config_file)
original_config.model.params.cond_stage_config.target.split(".")[-1]
original_config = yaml.safe_load(original_config_file)
original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
num_in_channels = 8
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
prediction_type = "epsilon"
image_size = 256
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000
beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
beta_start = getattr(original_config["model"]["params"], "linear_start", None) or 0.02
beta_end = getattr(original_config["model"]["params"], "linear_end", None) or 0.085
scheduler = DDIMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
@@ -721,10 +725,10 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex
if (
"model" in original_config
and "params" in original_config.model
and "scale_factor" in original_config.model.params
and "params" in original_config["model"]
and "scale_factor" in original_config["model"]["params"]
):
vae_scaling_factor = original_config.model.params.scale_factor
vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
else:
vae_scaling_factor = 0.18215 # default SD scaling factor
-2
View File
@@ -110,7 +110,6 @@ _deps = [
"note_seq",
"librosa",
"numpy",
"omegaconf",
"parameterized",
"peft>=0.6.0",
"protobuf>=3.20.3,<4",
@@ -213,7 +212,6 @@ extras["test"] = deps_list(
"invisible-watermark",
"k-diffusion",
"librosa",
"omegaconf",
"parameterized",
"pytest",
"pytest-timeout",
+4 -2
View File
@@ -153,6 +153,7 @@ else:
"LCMScheduler",
"PNDMScheduler",
"RePaintScheduler",
"SASolverScheduler",
"SchedulerMixin",
"ScoreSdeVeScheduler",
"UnCLIPScheduler",
@@ -381,7 +382,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["models.controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
@@ -530,6 +531,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LCMScheduler,
PNDMScheduler,
RePaintScheduler,
SASolverScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
UnCLIPScheduler,
@@ -709,7 +711,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .models.controlnet_flax import FlaxControlNetModel
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.unets.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
@@ -22,7 +22,6 @@ deps = {
"note_seq": "note_seq",
"librosa": "librosa",
"numpy": "numpy",
"omegaconf": "omegaconf",
"parameterized": "parameterized",
"peft": "peft>=0.6.0",
"protobuf": "protobuf>=3.20.3,<4",
@@ -16,7 +16,7 @@ import numpy as np
import torch
import tqdm
from ...models.unet_1d import UNet1DModel
from ...models.unets.unet_1d import UNet1DModel
from ...pipelines import DiffusionPipeline
from ...utils.dummy_pt_objects import DDPMScheduler
from ...utils.torch_utils import randn_tensor
+6 -4
View File
@@ -54,12 +54,13 @@ if is_transformers_available():
_import_structure = {}
if is_torch_available():
_import_structure["single_file"] = ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"]
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]
_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
_import_structure["single_file"].extend(["FromSingleFileMixin"])
_import_structure["single_file"] = ["FromSingleFileMixin"]
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
@@ -69,7 +70,8 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
from .autoencoder import FromOriginalVAEMixin
from .controlnet import FromOriginalControlNetMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
+126
View File
@@ -0,0 +1,126 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub.utils import validate_hf_hub_args
from .single_file_utils import (
create_diffusers_vae_model_from_ldm,
fetch_ldm_config_and_checkpoint,
)
class FromOriginalVAEMixin:
"""
Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
<Tip warning={true}>
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
a VAE from SDXL or a Stable Diffusion v2 model or higher.
</Tip>
Examples:
```py
from diffusers import AutoencoderKL
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
model = AutoencoderKL.from_single_file(url)
```
"""
original_config_file = kwargs.pop("original_config_file", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", True)
class_name = cls.__name__
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=class_name,
original_config_file=original_config_file,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
use_safetensors=use_safetensors,
cache_dir=cache_dir,
)
image_size = kwargs.pop("image_size", None)
component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
vae = component["vae"]
if torch_dtype is not None:
vae = vae.to(torch_dtype)
return vae
+127
View File
@@ -0,0 +1,127 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub.utils import validate_hf_hub_args
from .single_file_utils import (
create_diffusers_controlnet_model_from_ldm,
fetch_ldm_config_and_checkpoint,
)
class FromOriginalControlNetMixin:
"""
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
upcast_attention (`bool`, *optional*, defaults to `None`):
Whether the attention computation should always be upcasted.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
Examples:
```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
model = ControlNetModel.from_single_file(url)
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
```
"""
original_config_file = kwargs.pop("original_config_file", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", True)
class_name = cls.__name__
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=class_name,
original_config_file=original_config_file,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
use_safetensors=use_safetensors,
cache_dir=cache_dir,
)
upcast_attention = kwargs.pop("upcast_attention", False)
image_size = kwargs.pop("image_size", None)
component = create_diffusers_controlnet_model_from_ldm(
class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
)
controlnet = component["controlnet"]
if torch_dtype is not None:
controlnet = controlnet.to(torch_dtype)
return controlnet
+2 -2
View File
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import Dict, Union
import torch
@@ -138,7 +138,7 @@ class IPAdapterMixin:
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=os.path.join(subfolder, "image_encoder"),
subfolder=Path(subfolder, "image_encoder").as_posix(),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
+4 -2
View File
@@ -14,6 +14,7 @@
import inspect
import os
from contextlib import nullcontext
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import safetensors
@@ -960,8 +961,9 @@ class LoraLoaderMixin:
else:
weight_name = LORA_WEIGHT_NAME
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
save_path = Path(save_directory, weight_name).as_posix()
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
def unload_lora_weights(self):
"""
+173 -532
View File
@@ -11,45 +11,132 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import (
deprecate,
is_accelerate_available,
is_omegaconf_available,
is_transformers_available,
logging,
from ..utils import is_transformers_available, logging
from .single_file_utils import (
create_diffusers_unet_model_from_ldm,
create_diffusers_vae_model_from_ldm,
create_scheduler_from_ldm,
create_text_encoders_and_tokenizers_from_ldm,
fetch_ldm_config_and_checkpoint,
infer_model_type,
)
from ..utils.import_utils import BACKENDS_MAPPING
if is_transformers_available():
pass
if is_accelerate_available():
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)
# Pipelines that support the SDXL Refiner checkpoint
REFINER_PIPELINES = [
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
]
if is_transformers_available():
from transformers import AutoFeatureExtractor
def build_sub_model_components(
pipeline_components,
pipeline_class_name,
component_name,
original_config,
checkpoint,
local_files_only=False,
load_safety_checker=False,
model_type=None,
image_size=None,
**kwargs,
):
if component_name in pipeline_components:
return {}
if component_name == "unet":
num_in_channels = kwargs.pop("num_in_channels", None)
unet_components = create_diffusers_unet_model_from_ldm(
pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size
)
return unet_components
if component_name == "vae":
vae_components = create_diffusers_vae_model_from_ldm(
pipeline_class_name, original_config, checkpoint, image_size
)
return vae_components
if component_name == "scheduler":
scheduler_type = kwargs.get("scheduler_type", "ddim")
prediction_type = kwargs.get("prediction_type", None)
scheduler_components = create_scheduler_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
scheduler_type=scheduler_type,
prediction_type=prediction_type,
model_type=model_type,
)
return scheduler_components
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
text_encoder_components = create_text_encoders_and_tokenizers_from_ldm(
original_config,
checkpoint,
model_type=model_type,
local_files_only=local_files_only,
)
return text_encoder_components
if component_name == "safety_checker":
if load_safety_checker:
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
else:
safety_checker = None
return {"safety_checker": safety_checker}
if component_name == "feature_extractor":
if load_safety_checker:
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
else:
feature_extractor = None
return {"feature_extractor": feature_extractor}
return
def set_additional_components(
pipeline_class_name,
original_config,
model_type=None,
):
components = {}
if pipeline_class_name in REFINER_PIPELINES:
model_type = infer_model_type(original_config, model_type=model_type)
is_refiner = model_type == "SDXL-Refiner"
components.update(
{
"requires_aesthetics_score": is_refiner,
"force_zeros_for_empty_prompt": False if is_refiner else True,
}
)
return components
class FromSingleFileMixin:
"""
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
"""
@classmethod
def from_ckpt(cls, *args, **kwargs):
deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
return cls.from_single_file(*args, **kwargs)
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
@@ -64,8 +151,7 @@ class FromSingleFileMixin:
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -91,42 +177,6 @@ class FromSingleFileMixin:
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
extract_ema (`bool`, *optional*, defaults to `False`):
Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
higher quality images for inference. Non-EMA weights are usually better for continuing finetuning.
upcast_attention (`bool`, *optional*, defaults to `None`):
Whether the attention computation should always be upcasted.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
prediction_type (`str`, *optional*):
The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
num_in_channels (`int`, *optional*, defaults to `None`):
The number of input channels. If `None`, it is automatically inferred.
scheduler_type (`str`, *optional*, defaults to `"pndm"`):
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
"ddim"]`.
load_safety_checker (`bool`, *optional*, defaults to `True`):
Whether to load the safety checker or not.
text_encoder ([`~transformers.CLIPTextModel`], *optional*, defaults to `None`):
An instance of `CLIPTextModel` to use, specifically the
[clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this
parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed.
vae (`AutoencoderKL`, *optional*, defaults to `None`):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
this parameter is `None`, the function will load a new instance of [CLIP] by itself, if needed.
tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
of `CLIPTokenizer` by itself if needed.
original_config_file (`str`):
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be
automatically inferred by looking for a key that only exists in SD2.0 models.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
Examples:
```py
@@ -149,489 +199,80 @@ class FromSingleFileMixin:
>>> pipeline.to("cuda")
```
"""
# import here to avoid circular dependency
from ..pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
original_config_file = kwargs.pop("original_config_file", None)
config_files = kwargs.pop("config_files", None)
cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
extract_ema = kwargs.pop("extract_ema", False)
image_size = kwargs.pop("image_size", None)
scheduler_type = kwargs.pop("scheduler_type", "pndm")
num_in_channels = kwargs.pop("num_in_channels", None)
upcast_attention = kwargs.pop("upcast_attention", None)
load_safety_checker = kwargs.pop("load_safety_checker", True)
prediction_type = kwargs.pop("prediction_type", None)
text_encoder = kwargs.pop("text_encoder", None)
text_encoder_2 = kwargs.pop("text_encoder_2", None)
vae = kwargs.pop("vae", None)
controlnet = kwargs.pop("controlnet", None)
adapter = kwargs.pop("adapter", None)
tokenizer = kwargs.pop("tokenizer", None)
tokenizer_2 = kwargs.pop("tokenizer_2", None)
torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", True)
use_safetensors = kwargs.pop("use_safetensors", None)
class_name = cls.__name__
pipeline_name = cls.__name__
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
from_safetensors = file_extension == "safetensors"
if from_safetensors and use_safetensors is False:
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
# TODO: For now we only support stable diffusion
stable_unclip = None
model_type = None
if pipeline_name in [
"StableDiffusionControlNetPipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
]:
from ..models.controlnet import ControlNetModel
from ..pipelines.controlnet.multicontrolnet import MultiControlNetModel
# list/tuple or a single instance of ControlNetModel or MultiControlNetModel
if not (
isinstance(controlnet, (ControlNetModel, MultiControlNetModel))
or isinstance(controlnet, (list, tuple))
and isinstance(controlnet[0], ControlNetModel)
):
raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
elif "StableDiffusion" in pipeline_name:
# Model type will be inferred from the checkpoint.
pass
elif pipeline_name == "StableUnCLIPPipeline":
model_type = "FrozenOpenCLIPEmbedder"
stable_unclip = "txt2img"
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
model_type = "FrozenOpenCLIPEmbedder"
stable_unclip = "img2img"
elif pipeline_name == "PaintByExamplePipeline":
model_type = "PaintByExample"
elif pipeline_name == "LDMTextToImagePipeline":
model_type = "LDMTextToImage"
else:
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
# remove huggingface url
has_valid_url_prefix = False
valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
for prefix in valid_url_prefixes:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
has_valid_url_prefix = True
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
ckpt_path = Path(pretrained_model_link_or_path)
if not ckpt_path.is_file():
if not has_valid_url_prefix:
raise ValueError(
f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
)
# get repo_id and (potentially nested) file path of ckpt in repo
repo_id = "/".join(ckpt_path.parts[:2])
file_path = "/".join(ckpt_path.parts[2:])
if file_path.startswith("blob/"):
file_path = file_path[len("blob/") :]
if file_path.startswith("main/"):
file_path = file_path[len("main/") :]
pretrained_model_link_or_path = hf_hub_download(
repo_id,
filename=file_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
force_download=force_download,
)
pipe = download_from_original_stable_diffusion_ckpt(
pretrained_model_link_or_path,
pipeline_class=cls,
model_type=model_type,
stable_unclip=stable_unclip,
controlnet=controlnet,
adapter=adapter,
from_safetensors=from_safetensors,
extract_ema=extract_ema,
image_size=image_size,
scheduler_type=scheduler_type,
num_in_channels=num_in_channels,
upcast_attention=upcast_attention,
load_safety_checker=load_safety_checker,
prediction_type=prediction_type,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
vae=vae,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=class_name,
original_config_file=original_config_file,
config_files=config_files,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
use_safetensors=use_safetensors,
cache_dir=cache_dir,
)
from ..pipelines.pipeline_utils import _get_pipeline_class
pipeline_class = _get_pipeline_class(
cls,
config=None,
cache_dir=cache_dir,
)
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
model_type = kwargs.pop("model_type", None)
image_size = kwargs.pop("image_size", None)
load_safety_checker = (kwargs.pop("load_safety_checker", False)) or (
passed_class_obj.get("safety_checker", None) is not None
)
init_kwargs = {}
for name in expected_modules:
if name in passed_class_obj:
init_kwargs[name] = passed_class_obj[name]
else:
components = build_sub_model_components(
init_kwargs,
class_name,
name,
original_config,
checkpoint,
model_type=model_type,
image_size=image_size,
load_safety_checker=load_safety_checker,
local_files_only=local_files_only,
**kwargs,
)
if not components:
continue
init_kwargs.update(components)
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
if additional_components:
init_kwargs.update(additional_components)
init_kwargs.update(passed_pipe_kwargs)
pipe = pipeline_class(**init_kwargs)
if torch_dtype is not None:
pipe.to(dtype=torch_dtype)
return pipe
class FromOriginalVAEMixin:
"""
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into an [`AutoencoderKL`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
upcast_attention (`bool`, *optional*, defaults to `None`):
Whether the attention computation should always be upcasted.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
<Tip warning={true}>
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
a VAE from SDXL or a Stable Diffusion v2 model or higher.
</Tip>
Examples:
```py
from diffusers import AutoencoderKL
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
model = AutoencoderKL.from_single_file(url)
```
"""
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
from ..models import AutoencoderKL
# import here to avoid circular dependency
from ..pipelines.stable_diffusion.convert_from_ckpt import (
convert_ldm_vae_checkpoint,
create_vae_diffusers_config,
)
config_file = kwargs.pop("config_file", None)
cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
image_size = kwargs.pop("image_size", None)
scaling_factor = kwargs.pop("scaling_factor", None)
kwargs.pop("upcast_attention", None)
torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", None)
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
from_safetensors = file_extension == "safetensors"
if from_safetensors and use_safetensors is False:
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
# remove huggingface url
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
ckpt_path = Path(pretrained_model_link_or_path)
if not ckpt_path.is_file():
# get repo_id and (potentially nested) file path of ckpt in repo
repo_id = "/".join(ckpt_path.parts[:2])
file_path = "/".join(ckpt_path.parts[2:])
if file_path.startswith("blob/"):
file_path = file_path[len("blob/") :]
if file_path.startswith("main/"):
file_path = file_path[len("main/") :]
pretrained_model_link_or_path = hf_hub_download(
repo_id,
filename=file_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
force_download=force_download,
)
if from_safetensors:
from safetensors import safe_open
checkpoint = {}
with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
for key in f.keys():
checkpoint[key] = f.get_tensor(key)
else:
checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
if config_file is None:
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(config_file)
# default to sd-v1-5
image_size = image_size or 512
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
if scaling_factor is None:
if (
"model" in original_config
and "params" in original_config.model
and "scale_factor" in original_config.model.params
):
vae_scaling_factor = original_config.model.params.scale_factor
else:
vae_scaling_factor = 0.18215 # default SD scaling factor
vae_config["scaling_factor"] = vae_scaling_factor
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
vae = AutoencoderKL(**vae_config)
if is_accelerate_available():
from ..models.modeling_utils import load_model_dict_into_meta
load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
else:
vae.load_state_dict(converted_vae_checkpoint)
if torch_dtype is not None:
vae.to(dtype=torch_dtype)
return vae
class FromOriginalControlnetMixin:
"""
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
upcast_attention (`bool`, *optional*, defaults to `None`):
Whether the attention computation should always be upcasted.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (for example the pipeline components of the
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
method. See example below for more information.
Examples:
```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
model = ControlNetModel.from_single_file(url)
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
```
"""
# import here to avoid circular dependency
from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
config_file = kwargs.pop("config_file", None)
cache_dir = kwargs.pop("cache_dir", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
num_in_channels = kwargs.pop("num_in_channels", None)
use_linear_projection = kwargs.pop("use_linear_projection", None)
revision = kwargs.pop("revision", None)
extract_ema = kwargs.pop("extract_ema", False)
image_size = kwargs.pop("image_size", None)
upcast_attention = kwargs.pop("upcast_attention", None)
torch_dtype = kwargs.pop("torch_dtype", None)
use_safetensors = kwargs.pop("use_safetensors", None)
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
from_safetensors = file_extension == "safetensors"
if from_safetensors and use_safetensors is False:
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
# remove huggingface url
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
if pretrained_model_link_or_path.startswith(prefix):
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
ckpt_path = Path(pretrained_model_link_or_path)
if not ckpt_path.is_file():
# get repo_id and (potentially nested) file path of ckpt in repo
repo_id = "/".join(ckpt_path.parts[:2])
file_path = "/".join(ckpt_path.parts[2:])
if file_path.startswith("blob/"):
file_path = file_path[len("blob/") :]
if file_path.startswith("main/"):
file_path = file_path[len("main/") :]
pretrained_model_link_or_path = hf_hub_download(
repo_id,
filename=file_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
force_download=force_download,
)
if config_file is None:
config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml"
config_file = BytesIO(requests.get(config_url).content)
image_size = image_size or 512
controlnet = download_controlnet_from_original_ckpt(
pretrained_model_link_or_path,
original_config_file=config_file,
image_size=image_size,
extract_ema=extract_ema,
num_in_channels=num_in_channels,
upcast_attention=upcast_attention,
from_safetensors=from_safetensors,
use_linear_projection=use_linear_projection,
)
if torch_dtype is not None:
controlnet.to(dtype=torch_dtype)
return controlnet
File diff suppressed because it is too large Load Diff
+21 -18
View File
@@ -39,19 +39,19 @@ if is_torch_available():
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unet_1d"] = ["UNet1DModel"]
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["uvit_2d"] = ["UVit2DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unets.unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
_import_structure["unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
@@ -73,19 +73,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .uvit_2d import UVit2DModel
from .unets import (
Kandinsky3UNet,
MotionAdapter,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
UVit2DModel,
)
from .vq_model import VQModel
if is_flax_available():
from .controlnet_flax import FlaxControlNetModel
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
from .unets import FlaxUNet2DConditionModel
from .vae_flax import FlaxAutoencoderKL
else:
@@ -157,7 +157,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
self.use_slicing = False
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -181,7 +181,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -216,7 +216,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -448,7 +448,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return DecoderOutput(sample=dec)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
@@ -472,7 +472,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
@@ -17,13 +17,12 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalVAEMixin
from ...utils import is_torch_version
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -162,7 +161,7 @@ class TemporalDecoder(nn.Module):
return sample
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -242,7 +241,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -266,7 +265,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -31,7 +31,7 @@ from ..attention_processor import (
AttnProcessor,
)
from ..modeling_utils import ModelMixin
from ..unet_2d import UNet2DModel
from ..unets.unet_2d import UNet2DModel
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@@ -187,7 +187,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
self.use_slicing = False
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -246,7 +246,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
+1 -1
View File
@@ -22,7 +22,7 @@ from ...utils import BaseOutput, is_torch_version
from ...utils.torch_utils import randn_tensor
from ..activations import get_activation
from ..attention_processor import SpatialNorm
from ..unet_2d_blocks import (
from ..unets.unet_2d_blocks import (
AutoencoderTinyBlock,
UNetMidBlock2D,
get_down_block,
+14 -8
View File
@@ -19,7 +19,7 @@ from torch import nn
from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlnetMixin
from ..loaders import FromOriginalControlNetMixin
from ..utils import BaseOutput, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -30,8 +30,14 @@ from .attention_processor import (
)
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block
from .unet_2d_condition import UNet2DConditionModel
from .unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
get_down_block,
)
from .unets.unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -102,7 +108,7 @@ class ControlNetConditioningEmbedding(nn.Module):
return embedding
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
"""
A ControlNet model.
@@ -509,7 +515,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
return controlnet
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -533,7 +539,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -568,7 +574,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -584,7 +590,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
self.set_attn_processor(processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.
+4 -4
View File
@@ -23,7 +23,7 @@ from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
from .unets.unet_2d_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxDownBlock2D,
FlaxUNetMidBlock2DCrossAttn,
@@ -329,14 +329,14 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is the sample tensor.
"""
channel_order = self.controlnet_conditioning_channel_order
+1 -1
View File
@@ -120,7 +120,7 @@ class DualTransformer2DModel(nn.Module):
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
+5 -3
View File
@@ -32,6 +32,7 @@ from .. import __version__
from ..utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
SAFETENSORS_FILE_EXTENSION,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_add_variant,
@@ -102,10 +103,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
try:
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
return torch.load(checkpoint_file, map_location="cpu")
else:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
else:
return torch.load(checkpoint_file, map_location="cpu")
except Exception as e:
try:
with open(checkpoint_file) as f:
+3 -3
View File
@@ -167,7 +167,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -191,7 +191,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -226,7 +226,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
+1 -1
View File
@@ -286,7 +286,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
+1 -1
View File
@@ -149,7 +149,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
+8 -237
View File
@@ -12,244 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
from ..utils import deprecate
from .unets.unet_1d import UNet1DModel, UNet1DOutput
@dataclass
class UNet1DOutput(BaseOutput):
"""
The output of [`UNet1DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
class UNet1DOutput(UNet1DOutput):
deprecation_message = "Importing `UNet1DOutput` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DOutput`, instead."
deprecate("UNet1DOutput", "0.29", deprecation_message)
class UNet1DModel(ModelMixin, ConfigMixin):
r"""
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
extra_in_channels (`int`, *optional*, defaults to 0):
Number of additional channels to be added to the input of the first down block. Useful for cases where the
input data has more channels than what the model was initially designed for.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
downsample_each_block (`int`, *optional*, defaults to `False`):
Experimental feature for using a UNet without upsampling.
"""
@register_to_config
def __init__(
self,
sample_size: int = 65536,
sample_rate: Optional[int] = None,
in_channels: int = 2,
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: Tuple[str] = "UNetMidBlock1D",
out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64),
act_fn: str = None,
norm_num_groups: int = 8,
layers_per_block: int = 1,
downsample_each_block: bool = False,
):
super().__init__()
self.sample_size = sample_size
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=timestep_input_dim,
time_embed_dim=time_embed_dim,
act_fn=act_fn,
out_dim=block_out_channels[0],
)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
self.out_block = None
# down
output_channel = in_channels
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
if i == 0:
input_channel += extra_in_channels
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_downsample=not is_final_block or downsample_each_block,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = get_mid_block(
mid_block_type,
in_channels=block_out_channels[-1],
mid_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
embed_dim=block_out_channels[0],
num_layers=layers_per_block,
add_downsample=downsample_each_block,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
if out_block_type is None:
final_upsample_channels = out_channels
else:
final_upsample_channels = block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = (
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
)
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block,
in_channels=prev_output_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_upsample=not is_final_block,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.out_block = get_out_block(
out_block_type=out_block_type,
num_groups_out=num_groups_out,
embed_dim=block_out_channels[0],
out_channels=out_channels,
act_fn=act_fn,
fc_dim=block_out_channels[-1] // 4,
)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet1DOutput, Tuple]:
r"""
The [`UNet1DModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
timestep_embed = self.time_proj(timesteps)
if self.config.use_timestep_embedding:
timestep_embed = self.time_mlp(timestep_embed)
else:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
# 2. down
down_block_res_samples = ()
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
down_block_res_samples += res_samples
# 3. mid
if self.mid_block:
sample = self.mid_block(sample, timestep_embed)
# 4. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-1:]
down_block_res_samples = down_block_res_samples[:-1]
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
# 5. post-process
if self.out_block:
sample = self.out_block(sample, timestep_embed)
if not return_dict:
return (sample,)
return UNet1DOutput(sample=sample)
class UNet1DModel(UNet1DModel):
deprecation_message = "Importing `UNet1DModel` from `diffusers.models.unet_1d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d import UNet1DModel`, instead."
deprecate("UNet1DModel", "0.29", deprecation_message)
+128 -627
View File
@@ -11,616 +11,112 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from .activations import get_activation
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
class DownResnetBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
conv_shortcut: bool = False,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_downsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.add_downsample = add_downsample
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.downsample = None
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
output_states = ()
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
if self.downsample is not None:
hidden_states = self.downsample(hidden_states)
return hidden_states, output_states
class UpResnetBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_upsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.time_embedding_norm = time_embedding_norm
self.add_upsample = add_upsample
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.upsample = None
if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
hidden_states = self.upsample(hidden_states)
return hidden_states
class ValueFunctionMidBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.embed_dim = embed_dim
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
x = self.res1(x, temb)
x = self.down1(x)
x = self.res2(x, temb)
x = self.down2(x)
return x
class MidResTemporalBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
embed_dim: int,
num_layers: int = 1,
add_downsample: bool = False,
add_upsample: bool = False,
non_linearity: Optional[str] = None,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.add_downsample = add_downsample
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.upsample = None
if add_upsample:
self.upsample = Downsample1D(out_channels, use_conv=True)
self.downsample = None
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True)
if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample")
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
if self.upsample:
hidden_states = self.upsample(hidden_states)
if self.downsample:
self.downsample = self.downsample(hidden_states)
return hidden_states
class OutConv1DBlock(nn.Module):
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
super().__init__()
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_act(hidden_states)
hidden_states = self.final_conv1d_2(hidden_states)
return hidden_states
class OutValueFunctionBlock(nn.Module):
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
super().__init__()
self.final_block = nn.ModuleList(
[
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
get_activation(act_fn),
nn.Linear(fc_dim // 2, 1),
]
)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block:
hidden_states = layer(hidden_states)
return hidden_states
_kernels = {
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
"lanczos3": [
0.003689131001010537,
0.015056144446134567,
-0.03399861603975296,
-0.066637322306633,
0.13550527393817902,
0.44638532400131226,
0.44638532400131226,
0.13550527393817902,
-0.066637322306633,
-0.03399861603975296,
0.015056144446134567,
0.003689131001010537,
],
}
class Downsample1d(nn.Module):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv1d(hidden_states, weight, stride=2)
class Upsample1d(nn.Module):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
class SelfAttention1d(nn.Module):
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
super().__init__()
self.channels = in_channels
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
self.num_heads = n_head
self.query = nn.Linear(self.channels, self.channels)
self.key = nn.Linear(self.channels, self.channels)
self.value = nn.Linear(self.channels, self.channels)
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
self.dropout = nn.Dropout(dropout_rate, inplace=True)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = hidden_states
batch, channel_dim, seq = hidden_states.shape
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_probs = torch.softmax(attention_scores, dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.dropout(hidden_states)
output = hidden_states + residual
return output
class ResConvBlock(nn.Module):
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
super().__init__()
self.is_last = is_last
self.has_conv_skip = in_channels != out_channels
if self.has_conv_skip:
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
self.gelu_1 = nn.GELU()
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
if not self.is_last:
self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states)
hidden_states = self.group_norm_1(hidden_states)
hidden_states = self.gelu_1(hidden_states)
hidden_states = self.conv_2(hidden_states)
from ..utils import deprecate
from .unets.unet_1d_blocks import (
AttnDownBlock1D,
AttnUpBlock1D,
DownBlock1D,
DownBlock1DNoSkip,
DownResnetBlock1D,
Downsample1d,
MidResTemporalBlock1D,
OutConv1DBlock,
OutValueFunctionBlock,
ResConvBlock,
SelfAttention1d,
UNetMidBlock1D,
UpBlock1D,
UpBlock1DNoSkip,
UpResnetBlock1D,
Upsample1d,
ValueFunctionMidBlock1D,
)
if not self.is_last:
hidden_states = self.group_norm_2(hidden_states)
hidden_states = self.gelu_2(hidden_states)
output = hidden_states + residual
return output
class DownResnetBlock1D(DownResnetBlock1D):
deprecation_message = "Importing `DownResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownResnetBlock1D`, instead."
deprecate("DownResnetBlock1D", "0.29", deprecation_message)
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
super().__init__()
class UpResnetBlock1D(UpResnetBlock1D):
deprecation_message = "Importing `UpResnetBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpResnetBlock1D`, instead."
deprecate("UpResnetBlock1D", "0.29", deprecation_message)
out_channels = in_channels if out_channels is None else out_channels
# there is always at least one resnet
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.up = Upsample1d(kernel="cubic")
class ValueFunctionMidBlock1D(ValueFunctionMidBlock1D):
deprecation_message = "Importing `ValueFunctionMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ValueFunctionMidBlock1D`, instead."
deprecate("ValueFunctionMidBlock1D", "0.29", deprecation_message)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
class OutConv1DBlock(OutConv1DBlock):
deprecation_message = "Importing `OutConv1DBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutConv1DBlock`, instead."
deprecate("OutConv1DBlock", "0.29", deprecation_message)
hidden_states = self.up(hidden_states)
return hidden_states
class OutValueFunctionBlock(OutValueFunctionBlock):
deprecation_message = "Importing `OutValueFunctionBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import OutValueFunctionBlock`, instead."
deprecate("OutValueFunctionBlock", "0.29", deprecation_message)
class AttnDownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
class Downsample1d(Downsample1d):
deprecation_message = "Importing `Downsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Downsample1d`, instead."
deprecate("Downsample1d", "0.29", deprecation_message)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
class Upsample1d(Upsample1d):
deprecation_message = "Importing `Upsample1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import Upsample1d`, instead."
deprecate("Upsample1d", "0.29", deprecation_message)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
return hidden_states, (hidden_states,)
class DownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states, (hidden_states,)
class DownBlock1DNoSkip(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states, (hidden_states,)
class AttnUpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
class SelfAttention1d(SelfAttention1d):
deprecation_message = "Importing `SelfAttention1d` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import SelfAttention1d`, instead."
deprecate("SelfAttention1d", "0.29", deprecation_message)
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
class ResConvBlock(ResConvBlock):
deprecation_message = "Importing `ResConvBlock` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import ResConvBlock`, instead."
deprecate("ResConvBlock", "0.29", deprecation_message)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
class UNetMidBlock1D(UNetMidBlock1D):
deprecation_message = "Importing `UNetMidBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UNetMidBlock1D`, instead."
deprecate("UNetMidBlock1D", "0.29", deprecation_message)
hidden_states = self.up(hidden_states)
return hidden_states
class AttnDownBlock1D(AttnDownBlock1D):
deprecation_message = "Importing `AttnDownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnDownBlock1D`, instead."
deprecate("AttnDownBlock1D", "0.29", deprecation_message)
class UpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
class DownBlock1D(DownBlock1D):
deprecation_message = "Importing `DownBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1D`, instead."
deprecate("DownBlock1D", "0.29", deprecation_message)
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
class DownBlock1DNoSkip(DownBlock1DNoSkip):
deprecation_message = "Importing `DownBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import DownBlock1DNoSkip`, instead."
deprecate("DownBlock1DNoSkip", "0.29", deprecation_message)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
class AttnUpBlock1D(AttnUpBlock1D):
deprecation_message = "Importing `AttnUpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import AttnUpBlock1D`, instead."
deprecate("AttnUpBlock1D", "0.29", deprecation_message)
hidden_states = self.up(hidden_states)
return hidden_states
class UpBlock1D(UpBlock1D):
deprecation_message = "Importing `UpBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1D`, instead."
deprecate("UpBlock1D", "0.29", deprecation_message)
class UpBlock1DNoSkip(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
class UpBlock1DNoSkip(UpBlock1DNoSkip):
deprecation_message = "Importing `UpBlock1DNoSkip` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import UpBlock1DNoSkip`, instead."
deprecate("UpBlock1DNoSkip", "0.29", deprecation_message)
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
]
self.resnets = nn.ModuleList(resnets)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
class MidResTemporalBlock1D(MidResTemporalBlock1D):
deprecation_message = "Importing `MidResTemporalBlock1D` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import MidResTemporalBlock1D`, instead."
deprecate("MidResTemporalBlock1D", "0.29", deprecation_message)
def get_down_block(
@@ -630,42 +126,38 @@ def get_down_block(
out_channels: int,
temb_channels: int,
add_downsample: bool,
) -> DownBlockType:
if down_block_type == "DownResnetBlock1D":
return DownResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
)
elif down_block_type == "DownBlock1D":
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "AttnDownBlock1D":
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "DownBlock1DNoSkip":
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
raise ValueError(f"{down_block_type} does not exist.")
):
deprecation_message = "Importing `get_down_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_down_block`, instead."
deprecate("get_down_block", "0.29", deprecation_message)
from .unets.unet_1d_blocks import get_down_block
return get_down_block(
down_block_type=down_block_type,
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
)
def get_up_block(
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
) -> UpBlockType:
if up_block_type == "UpResnetBlock1D":
return UpResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
)
elif up_block_type == "UpBlock1D":
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "AttnUpBlock1D":
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "UpBlock1DNoSkip":
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
raise ValueError(f"{up_block_type} does not exist.")
):
deprecation_message = "Importing `get_up_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_up_block`, instead."
deprecate("get_up_block", "0.29", deprecation_message)
from .unets.unet_1d_blocks import get_up_block
return get_up_block(
up_block_type=up_block_type,
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
)
def get_mid_block(
@@ -676,27 +168,36 @@ def get_mid_block(
out_channels: int,
embed_dim: int,
add_downsample: bool,
) -> MidBlockType:
if mid_block_type == "MidResTemporalBlock1D":
return MidResTemporalBlock1D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
embed_dim=embed_dim,
add_downsample=add_downsample,
)
elif mid_block_type == "ValueFunctionMidBlock1D":
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
elif mid_block_type == "UNetMidBlock1D":
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
raise ValueError(f"{mid_block_type} does not exist.")
):
deprecation_message = "Importing `get_mid_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_mid_block`, instead."
deprecate("get_mid_block", "0.29", deprecation_message)
from .unets.unet_1d_blocks import get_mid_block
return get_mid_block(
mid_block_type=mid_block_type,
num_layers=num_layers,
in_channels=in_channels,
mid_channels=mid_channels,
out_channels=out_channels,
embed_dim=embed_dim,
add_downsample=add_downsample,
)
def get_out_block(
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
) -> Optional[OutBlockType]:
if out_block_type == "OutConv1DBlock":
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
elif out_block_type == "ValueFunction":
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
return None
):
deprecation_message = "Importing `get_out_block` from `diffusers.models.unet_1d_blocks` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_1d_blocks import get_out_block`, instead."
deprecate("get_out_block", "0.29", deprecation_message)
from .unets.unet_1d_blocks import get_out_block
return get_out_block(
out_block_type=out_block_type,
num_groups_out=num_groups_out,
embed_dim=embed_dim,
out_channels=out_channels,
act_fn=act_fn,
fc_dim=fc_dim,
)
+8 -327
View File
@@ -11,336 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class UNet2DOutput(BaseOutput):
"""
The output of [`UNet2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
from ..utils import deprecate
from .unets.unet_2d import UNet2DModel, UNet2DOutput
class UNet2DModel(ModelMixin, ConfigMixin):
r"""
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
class UNet2DOutput(UNet2DOutput):
deprecation_message = "Importing `UNet2DOutput` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DOutput`, instead."
deprecate("UNet2DOutput", "0.29", deprecation_message)
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
1)`.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
Tuple of block output channels.
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
downsample_type (`str`, *optional*, defaults to `conv`):
The downsample type for downsampling layers. Choose between "conv" and "resnet"
upsample_type (`str`, *optional*, defaults to `conv`):
The upsample type for upsampling layers. Choose between "conv" and "resnet"
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
given number of groups. If left as `None`, the group norm layer will only be created if
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
conditioning with `class_embed_type` equal to `None`.
"""
@register_to_config
def __init__(
self,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 3,
out_channels: int = 3,
center_input_sample: bool = False,
time_embedding_type: str = "positional",
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int] = (224, 448, 672, 896),
layers_per_block: int = 2,
mid_block_scale_factor: float = 1,
downsample_padding: int = 1,
downsample_type: str = "conv",
upsample_type: str = "conv",
dropout: float = 0.0,
act_fn: str = "silu",
attention_head_dim: Optional[int] = 8,
norm_num_groups: int = 32,
attn_norm_num_groups: Optional[int] = None,
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
add_attention: bool = True,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
num_train_timesteps: Optional[int] = None,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
# input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
elif time_embedding_type == "learned":
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
dropout=dropout,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
r"""
The [`UNet2DModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when doing class conditioning")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
elif self.class_embedding is None and class_labels is not None:
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "skip_conv"):
sample, res_samples, skip_sample = downsample_block(
hidden_states=sample, temb=emb, skip_sample=skip_sample
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "skip_conv"):
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
else:
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if skip_sample is not None:
sample += skip_sample
if self.config.time_embedding_type == "fourier":
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
if not return_dict:
return (sample,)
return UNet2DOutput(sample=sample)
class UNet2DModel(UNet2DModel):
deprecation_message = "Importing `UNet2DModel` from `diffusers.models.unet_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.unet_2d import UNet2DModel`, instead."
deprecate("UNet2DModel", "0.29", deprecation_message)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+16
View File
@@ -0,0 +1,16 @@
from ...utils import is_flax_available, is_torch_available
if is_torch_available():
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .uvit_2d import UVit2DModel
if is_flax_available():
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
+255
View File
@@ -0,0 +1,255 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
@dataclass
class UNet1DOutput(BaseOutput):
"""
The output of [`UNet1DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
class UNet1DModel(ModelMixin, ConfigMixin):
r"""
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
extra_in_channels (`int`, *optional*, defaults to 0):
Number of additional channels to be added to the input of the first down block. Useful for cases where the
input data has more channels than what the model was initially designed for.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
downsample_each_block (`int`, *optional*, defaults to `False`):
Experimental feature for using a UNet without upsampling.
"""
@register_to_config
def __init__(
self,
sample_size: int = 65536,
sample_rate: Optional[int] = None,
in_channels: int = 2,
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: Tuple[str] = "UNetMidBlock1D",
out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64),
act_fn: str = None,
norm_num_groups: int = 8,
layers_per_block: int = 1,
downsample_each_block: bool = False,
):
super().__init__()
self.sample_size = sample_size
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=timestep_input_dim,
time_embed_dim=time_embed_dim,
act_fn=act_fn,
out_dim=block_out_channels[0],
)
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
self.out_block = None
# down
output_channel = in_channels
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
if i == 0:
input_channel += extra_in_channels
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_downsample=not is_final_block or downsample_each_block,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = get_mid_block(
mid_block_type,
in_channels=block_out_channels[-1],
mid_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
embed_dim=block_out_channels[0],
num_layers=layers_per_block,
add_downsample=downsample_each_block,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
if out_block_type is None:
final_upsample_channels = out_channels
else:
final_upsample_channels = block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = (
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
)
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block,
in_channels=prev_output_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_upsample=not is_final_block,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.out_block = get_out_block(
out_block_type=out_block_type,
num_groups_out=num_groups_out,
embed_dim=block_out_channels[0],
out_channels=out_channels,
act_fn=act_fn,
fc_dim=block_out_channels[-1] // 4,
)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet1DOutput, Tuple]:
r"""
The [`UNet1DModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
timestep_embed = self.time_proj(timesteps)
if self.config.use_timestep_embedding:
timestep_embed = self.time_mlp(timestep_embed)
else:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
# 2. down
down_block_res_samples = ()
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
down_block_res_samples += res_samples
# 3. mid
if self.mid_block:
sample = self.mid_block(sample, timestep_embed)
# 4. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-1:]
down_block_res_samples = down_block_res_samples[:-1]
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
# 5. post-process
if self.out_block:
sample = self.out_block(sample, timestep_embed)
if not return_dict:
return (sample,)
return UNet1DOutput(sample=sample)
@@ -0,0 +1,702 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..activations import get_activation
from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
class DownResnetBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
conv_shortcut: bool = False,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_downsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.add_downsample = add_downsample
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.downsample = None
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
output_states = ()
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
if self.downsample is not None:
hidden_states = self.downsample(hidden_states)
return hidden_states, output_states
class UpResnetBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
temb_channels: int = 32,
groups: int = 32,
groups_out: Optional[int] = None,
non_linearity: Optional[str] = None,
time_embedding_norm: str = "default",
output_scale_factor: float = 1.0,
add_upsample: bool = True,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.time_embedding_norm = time_embedding_norm
self.add_upsample = add_upsample
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.upsample = None
if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
if self.nonlinearity is not None:
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
hidden_states = self.upsample(hidden_states)
return hidden_states
class ValueFunctionMidBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.embed_dim = embed_dim
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
x = self.res1(x, temb)
x = self.down1(x)
x = self.res2(x, temb)
x = self.down2(x)
return x
class MidResTemporalBlock1D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
embed_dim: int,
num_layers: int = 1,
add_downsample: bool = False,
add_upsample: bool = False,
non_linearity: Optional[str] = None,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.add_downsample = add_downsample
# there will always be at least one resnet
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
for _ in range(num_layers):
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
self.resnets = nn.ModuleList(resnets)
if non_linearity is None:
self.nonlinearity = None
else:
self.nonlinearity = get_activation(non_linearity)
self.upsample = None
if add_upsample:
self.upsample = Downsample1D(out_channels, use_conv=True)
self.downsample = None
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True)
if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample")
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
if self.upsample:
hidden_states = self.upsample(hidden_states)
if self.downsample:
self.downsample = self.downsample(hidden_states)
return hidden_states
class OutConv1DBlock(nn.Module):
def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
super().__init__()
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_act(hidden_states)
hidden_states = self.final_conv1d_2(hidden_states)
return hidden_states
class OutValueFunctionBlock(nn.Module):
def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
super().__init__()
self.final_block = nn.ModuleList(
[
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
get_activation(act_fn),
nn.Linear(fc_dim // 2, 1),
]
)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block:
hidden_states = layer(hidden_states)
return hidden_states
_kernels = {
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
"cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
"lanczos3": [
0.003689131001010537,
0.015056144446134567,
-0.03399861603975296,
-0.066637322306633,
0.13550527393817902,
0.44638532400131226,
0.44638532400131226,
0.13550527393817902,
-0.066637322306633,
-0.03399861603975296,
0.015056144446134567,
0.003689131001010537,
],
}
class Downsample1d(nn.Module):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv1d(hidden_states, weight, stride=2)
class Upsample1d(nn.Module):
def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
class SelfAttention1d(nn.Module):
def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
super().__init__()
self.channels = in_channels
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
self.num_heads = n_head
self.query = nn.Linear(self.channels, self.channels)
self.key = nn.Linear(self.channels, self.channels)
self.value = nn.Linear(self.channels, self.channels)
self.proj_attn = nn.Linear(self.channels, self.channels, bias=True)
self.dropout = nn.Dropout(dropout_rate, inplace=True)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = hidden_states
batch, channel_dim, seq = hidden_states.shape
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
attention_probs = torch.softmax(attention_scores, dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
hidden_states = self.dropout(hidden_states)
output = hidden_states + residual
return output
class ResConvBlock(nn.Module):
def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
super().__init__()
self.is_last = is_last
self.has_conv_skip = in_channels != out_channels
if self.has_conv_skip:
self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, bias=False)
self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2)
self.group_norm_1 = nn.GroupNorm(1, mid_channels)
self.gelu_1 = nn.GELU()
self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2)
if not self.is_last:
self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states)
hidden_states = self.group_norm_1(hidden_states)
hidden_states = self.gelu_1(hidden_states)
hidden_states = self.conv_2(hidden_states)
if not self.is_last:
hidden_states = self.group_norm_2(hidden_states)
hidden_states = self.gelu_2(hidden_states)
output = hidden_states + residual
return output
class UNetMidBlock1D(nn.Module):
def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
super().__init__()
out_channels = in_channels if out_channels is None else out_channels
# there is always at least one resnet
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.up = Upsample1d(kernel="cubic")
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
hidden_states = self.up(hidden_states)
return hidden_states
class AttnDownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
return hidden_states, (hidden_states,)
class DownBlock1D(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
self.down = Downsample1d("cubic")
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = self.down(hidden_states)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states, (hidden_states,)
class DownBlock1DNoSkip(nn.Module):
def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states, (hidden_states,)
class AttnUpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
attentions = [
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(mid_channels, mid_channels // 32),
SelfAttention1d(out_channels, out_channels // 32),
]
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states)
hidden_states = attn(hidden_states)
hidden_states = self.up(hidden_states)
return hidden_states
class UpBlock1D(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels),
]
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
hidden_states = self.up(hidden_states)
return hidden_states
class UpBlock1DNoSkip(nn.Module):
def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
resnets = [
ResConvBlock(2 * in_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, mid_channels),
ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True),
]
self.resnets = nn.ModuleList(resnets)
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
return hidden_states
DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
def get_down_block(
down_block_type: str,
num_layers: int,
in_channels: int,
out_channels: int,
temb_channels: int,
add_downsample: bool,
) -> DownBlockType:
if down_block_type == "DownResnetBlock1D":
return DownResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
)
elif down_block_type == "DownBlock1D":
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "AttnDownBlock1D":
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
elif down_block_type == "DownBlock1DNoSkip":
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
) -> UpBlockType:
if up_block_type == "UpResnetBlock1D":
return UpResnetBlock1D(
in_channels=in_channels,
num_layers=num_layers,
out_channels=out_channels,
temb_channels=temb_channels,
add_upsample=add_upsample,
)
elif up_block_type == "UpBlock1D":
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "AttnUpBlock1D":
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
elif up_block_type == "UpBlock1DNoSkip":
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
raise ValueError(f"{up_block_type} does not exist.")
def get_mid_block(
mid_block_type: str,
num_layers: int,
in_channels: int,
mid_channels: int,
out_channels: int,
embed_dim: int,
add_downsample: bool,
) -> MidBlockType:
if mid_block_type == "MidResTemporalBlock1D":
return MidResTemporalBlock1D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
embed_dim=embed_dim,
add_downsample=add_downsample,
)
elif mid_block_type == "ValueFunctionMidBlock1D":
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
elif mid_block_type == "UNetMidBlock1D":
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
raise ValueError(f"{mid_block_type} does not exist.")
def get_out_block(
*, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
) -> Optional[OutBlockType]:
if out_block_type == "OutConv1DBlock":
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
elif out_block_type == "ValueFunction":
return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
return None
+346
View File
@@ -0,0 +1,346 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class UNet2DOutput(BaseOutput):
"""
The output of [`UNet2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output from the last layer of the model.
"""
sample: torch.FloatTensor
class UNet2DModel(ModelMixin, ConfigMixin):
r"""
A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) -
1)`.
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
Tuple of block output channels.
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
downsample_type (`str`, *optional*, defaults to `conv`):
The downsample type for downsampling layers. Choose between "conv" and "resnet"
upsample_type (`str`, *optional*, defaults to `conv`):
The upsample type for upsampling layers. Choose between "conv" and "resnet"
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
given number of groups. If left as `None`, the group norm layer will only be created if
`resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
conditioning with `class_embed_type` equal to `None`.
"""
@register_to_config
def __init__(
self,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 3,
out_channels: int = 3,
center_input_sample: bool = False,
time_embedding_type: str = "positional",
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int] = (224, 448, 672, 896),
layers_per_block: int = 2,
mid_block_scale_factor: float = 1,
downsample_padding: int = 1,
downsample_type: str = "conv",
upsample_type: str = "conv",
dropout: float = 0.0,
act_fn: str = "silu",
attention_head_dim: Optional[int] = 8,
norm_num_groups: int = 32,
attn_norm_num_groups: Optional[int] = None,
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
add_attention: bool = True,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
num_train_timesteps: Optional[int] = None,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
# input
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
elif time_embedding_type == "learned":
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
dropout=dropout,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
dropout=dropout,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
r"""
The [`UNet2DModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_2d.UNet2DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when doing class conditioning")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
elif self.class_embedding is None and class_labels is not None:
raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "skip_conv"):
sample, res_samples, skip_sample = downsample_block(
hidden_states=sample, temb=emb, skip_sample=skip_sample
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "skip_conv"):
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
else:
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if skip_sample is not None:
sample += skip_sample
if self.config.time_embedding_type == "fourier":
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
if not return_dict:
return (sample,)
return UNet2DOutput(sample=sample)
File diff suppressed because it is too large Load Diff
@@ -15,8 +15,8 @@
import flax.linen as nn
import jax.numpy as jnp
from .attention_flax import FlaxTransformer2DModel
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
class FlaxCrossAttnDownBlock2D(nn.Module):
File diff suppressed because it is too large Load Diff
@@ -19,10 +19,10 @@ import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...utils import BaseOutput
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
FlaxCrossAttnDownBlock2D,
FlaxCrossAttnUpBlock2D,
@@ -342,14 +342,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
mid_block_additional_residual: (`torch.Tensor`, *optional*):
A tensor that if specified is added to the residual of the middle unet block.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
# 1. time
@@ -17,19 +17,19 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from ..utils import is_torch_version
from ..utils.torch_utils import apply_freeu
from .attention import Attention
from .dual_transformer_2d import DualTransformer2DModel
from .resnet import (
from ...utils import is_torch_version
from ...utils.torch_utils import apply_freeu
from ..attention import Attention
from ..dual_transformer_2d import DualTransformer2DModel
from ..resnet import (
Downsample2D,
ResnetBlock2D,
SpatioTemporalResBlock,
TemporalConvLayer,
Upsample2D,
)
from .transformer_2d import Transformer2DModel
from .transformer_temporal import (
from ..transformer_2d import Transformer2DModel
from ..transformer_temporal import (
TransformerSpatioTemporalModel,
TransformerTemporalModel,
)
@@ -20,20 +20,20 @@ import torch
import torch.nn as nn
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, deprecate, logging
from .activations import get_activation
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, logging
from ..activations import get_activation
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .transformer_temporal import TransformerTemporalModel
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
@@ -284,7 +284,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
)
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -308,7 +308,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.
@@ -374,7 +374,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -449,7 +449,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -469,7 +469,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -494,7 +494,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
@@ -503,7 +503,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unload_lora
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
def unload_lora(self):
"""Unloads LoRA weights."""
deprecate(
@@ -19,11 +19,11 @@ import torch
import torch.utils.checkpoint
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import Attention, AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -17,19 +17,19 @@ import torch
import torch.nn as nn
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import logging
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import logging
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .transformer_temporal import TransformerTemporalModel
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformer_temporal import TransformerTemporalModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_blocks import (
@@ -524,7 +524,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
)
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -548,7 +548,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -583,7 +583,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
@@ -613,7 +613,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self) -> None:
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
@@ -625,7 +625,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self) -> None:
"""
Disables custom attention processors and sets the default attention implementation.
@@ -645,7 +645,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
@@ -670,7 +670,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self) -> None:
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
@@ -4,12 +4,12 @@ from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import BaseOutput, logging
from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
@@ -323,7 +323,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
# Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
@@ -20,20 +20,20 @@ import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import PeftAdapterMixin
from .attention import BasicTransformerBlock, SkipFFTransformerBlock
from .attention_processor import (
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, get_timestep_embedding
from .modeling_utils import ModelMixin
from .normalization import GlobalResponseNorm, RMSNorm
from .resnet import Downsample2D, Upsample2D
from ..embeddings import TimestepEmbedding, get_timestep_embedding
from ..modeling_utils import ModelMixin
from ..normalization import GlobalResponseNorm, RMSNorm
from ..resnet import Downsample2D, Upsample2D
class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
@@ -213,7 +213,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
return logits
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -237,7 +237,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -272,7 +272,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -13,18 +13,20 @@
# limitations under the License.
import inspect
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.fft as fft
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unet_motion_model import MotionAdapter
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
@@ -36,6 +38,7 @@ from ...schedulers import (
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -64,10 +67,7 @@ EXAMPLE_DOC_STRING = """
"""
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
@@ -76,9 +76,83 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
return outputs
def _get_freeinit_freq_filter(
shape: Tuple[int, ...],
device: Union[str, torch.dtype],
filter_type: str,
order: float,
spatial_stop_frequency: float,
temporal_stop_frequency: float,
) -> torch.Tensor:
r"""Returns the FreeInit filter based on filter type and other input conditions."""
T, H, W = shape[-3], shape[-2], shape[-1]
mask = torch.zeros(shape)
if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
return mask
if filter_type == "butterworth":
def retrieve_mask(x):
return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
elif filter_type == "gaussian":
def retrieve_mask(x):
return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
elif filter_type == "ideal":
def retrieve_mask(x):
return 1 if x <= spatial_stop_frequency * 2 else 0
else:
raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
for t in range(T):
for h in range(H):
for w in range(W):
d_square = (
((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / T - 1)) ** 2
+ (2 * h / H - 1) ** 2
+ (2 * w / W - 1) ** 2
)
mask[..., t, h, w] = retrieve_mask(d_square)
return mask.to(device)
def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
r"""Noise reinitialization."""
# FFT
x_freq = fft.fftn(x, dim=(-3, -2, -1))
x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
# frequency mix
HPF = 1 - LPF
x_freq_low = x_freq * LPF
noise_freq_high = noise_freq * HPF
x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain
# IFFT
x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
return x_mixed
@dataclass
class AnimateDiffPipelineOutput(BaseOutput):
frames: Union[torch.Tensor, np.ndarray]
@@ -115,6 +189,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["feature_extractor", "image_encoder"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -442,6 +517,58 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@property
def free_init_enabled(self):
return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None
def enable_free_init(
self,
num_iters: int = 3,
use_fast_sampling: bool = False,
method: str = "butterworth",
order: int = 4,
spatial_stop_frequency: float = 0.25,
temporal_stop_frequency: float = 0.25,
generator: torch.Generator = None,
):
"""Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
Args:
num_iters (`int`, *optional*, defaults to `3`):
Number of FreeInit noise re-initialization iterations.
use_fast_sampling (`bool`, *optional*, defaults to `False`):
Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
method (`str`, *optional*, defaults to `butterworth`):
Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
FreeInit low pass filter.
order (`int`, *optional*, defaults to `4`):
Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
whereas lower values lead to `gaussian` method behaviour.
spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in
the original implementation.
temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in
the original implementation.
generator (`torch.Generator`, *optional*, defaults to `0.25`):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
FreeInit generation deterministic.
"""
self._free_init_num_iters = num_iters
self._free_init_use_fast_sampling = use_fast_sampling
self._free_init_method = method
self._free_init_order = order
self._free_init_spatial_stop_frequency = spatial_stop_frequency
self._free_init_temporal_stop_frequency = temporal_stop_frequency
self._free_init_generator = generator
def disable_free_init(self):
"""Disables the FreeInit mechanism if enabled."""
self._free_init_num_iters = None
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -539,6 +666,181 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
latents = latents * self.scheduler.init_noise_sigma
return latents
def _denoise_loop(
self,
timesteps,
num_inference_steps,
do_classifier_free_guidance,
guidance_scale,
num_warmup_steps,
prompt_embeds,
negative_prompt_embeds,
latents,
cross_attention_kwargs,
added_cond_kwargs,
extra_step_kwargs,
callback,
callback_steps,
callback_on_step_end,
callback_on_step_end_tensor_inputs,
):
"""Denoising loop for AnimateDiff."""
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
return latents
def _free_init_loop(
self,
height,
width,
num_frames,
num_channels_latents,
batch_size,
num_videos_per_prompt,
denoise_args,
device,
):
"""Denoising loop for AnimateDiff using FreeInit noise reinitialization technique."""
latents = denoise_args.get("latents")
prompt_embeds = denoise_args.get("prompt_embeds")
timesteps = denoise_args.get("timesteps")
num_inference_steps = denoise_args.get("num_inference_steps")
latent_shape = (
batch_size * num_videos_per_prompt,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
free_init_filter_shape = (
1,
num_channels_latents,
num_frames,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
free_init_freq_filter = _get_freeinit_freq_filter(
shape=free_init_filter_shape,
device=device,
filter_type=self._free_init_method,
order=self._free_init_order,
spatial_stop_frequency=self._free_init_spatial_stop_frequency,
temporal_stop_frequency=self._free_init_temporal_stop_frequency,
)
with self.progress_bar(total=self._free_init_num_iters) as free_init_progress_bar:
for i in range(self._free_init_num_iters):
# For the first FreeInit iteration, the original latent is used without modification.
# Subsequent iterations apply the noise reinitialization technique.
if i == 0:
initial_noise = latents.detach().clone()
else:
current_diffuse_timestep = (
self.scheduler.config.num_train_timesteps - 1
) # diffuse to t=999 noise level
diffuse_timesteps = torch.full((batch_size,), current_diffuse_timestep).long()
z_T = self.scheduler.add_noise(
original_samples=latents, noise=initial_noise, timesteps=diffuse_timesteps.to(device)
).to(dtype=torch.float32)
z_rand = randn_tensor(
shape=latent_shape,
generator=self._free_init_generator,
device=device,
dtype=torch.float32,
)
latents = _freq_mix_3d(z_T, z_rand, LPF=free_init_freq_filter)
latents = latents.to(prompt_embeds.dtype)
# Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
if self._free_init_use_fast_sampling:
current_num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (i + 1))
self.scheduler.set_timesteps(current_num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
denoise_args.update({"timesteps": timesteps, "num_inference_steps": current_num_inference_steps})
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
denoise_args.update({"latents": latents, "num_warmup_steps": num_warmup_steps})
latents = self._denoise_loop(**denoise_args)
free_init_progress_bar.update()
return latents
def _retrieve_video_frames(self, latents, output_type, return_dict):
"""Helper function to handle latents to output conversion."""
if output_type == "latent":
return AnimateDiffPipelineOutput(frames=latents)
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
if not return_dict:
return (video,)
return AnimateDiffPipelineOutput(frames=video)
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -559,10 +861,11 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -603,25 +906,30 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead
of a plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeine class.
Examples:
Returns:
@@ -629,6 +937,23 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -637,9 +962,20 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -649,30 +985,26 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
@@ -680,12 +1012,13 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_videos_per_prompt, output_hidden_state
)
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -703,55 +1036,47 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7 Add image embeds for IP-Adapter
# 7. Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# Denoising loop
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
denoise_args = {
"timesteps": timesteps,
"num_inference_steps": num_inference_steps,
"do_classifier_free_guidance": self.do_classifier_free_guidance,
"guidance_scale": guidance_scale,
"num_warmup_steps": num_warmup_steps,
"prompt_embeds": prompt_embeds,
"negative_prompt_embeds": negative_prompt_embeds,
"latents": latents,
"cross_attention_kwargs": self.cross_attention_kwargs,
"added_cond_kwargs": added_cond_kwargs,
"extra_step_kwargs": extra_step_kwargs,
"callback": callback,
"callback_steps": callback_steps,
"callback_on_step_end": callback_on_step_end,
"callback_on_step_end_tensor_inputs": callback_on_step_end_tensor_inputs,
}
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if output_type == "latent":
return AnimateDiffPipelineOutput(frames=latents)
# Post-processing
video_tensor = self.decode_latents(latents)
if output_type == "pt":
video = video_tensor
if self.free_init_enabled:
latents = self._free_init_loop(
height=height,
width=width,
num_frames=num_frames,
num_channels_latents=num_channels_latents,
batch_size=batch_size,
num_videos_per_prompt=num_videos_per_prompt,
denoise_args=denoise_args,
device=device,
)
else:
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
latents = self._denoise_loop(**denoise_args)
# Offload all models
video = self._retrieve_video_frames(latents, output_type, return_dict)
# 9. Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return AnimateDiffPipelineOutput(frames=video)
return video
@@ -36,8 +36,8 @@ from ...models.embeddings import (
from ...models.modeling_utils import ModelMixin
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
from ...utils import BaseOutput, is_torch_version, logging
@@ -513,7 +513,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
)
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
@@ -537,7 +537,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
@@ -572,7 +572,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -588,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
self.set_attn_processor(processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
@@ -654,7 +654,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@@ -687,7 +687,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
@@ -700,8 +700,8 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
which adds large negative values to the attention scores corresponding to "discard" tokens.
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
@@ -603,15 +603,6 @@ class StableDiffusionControlNetPipeline(
f" {negative_prompt_embeds.shape}."
)
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)
# Check `image`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -633,7 +624,13 @@ class StableDiffusionControlNetPipeline(
# When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
transposed_image = [list(t) for t in zip(*image)]
if len(transposed_image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets."
)
for image_ in transposed_image:
self.check_image(image_, prompt, prompt_embeds)
elif len(image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
@@ -659,7 +656,10 @@ class StableDiffusionControlNetPipeline(
):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
raise ValueError(
"A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
"The conditioning scale must be fixed across the batch."
)
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
@@ -906,7 +906,9 @@ class StableDiffusionControlNetPipeline(
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
`init`, images must be passed as a list such that each element of the list can be correctly batched for
input to a single ControlNet.
input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet,
each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets,
where a list of image lists can be passed to batch for each prompt and each ControlNet.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -1105,6 +1107,11 @@ class StableDiffusionControlNetPipeline(
elif isinstance(controlnet, MultiControlNetModel):
images = []
# Nested lists as ControlNet condition
if isinstance(image[0], list):
# Transpose the nested image list
image = [list(t) for t in zip(*image)]
for image_ in image:
image_ = self.prepare_image(
image=image_,
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -1087,7 +1087,10 @@ class StableDiffusionControlNetImg2ImgPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
@@ -683,9 +683,11 @@ class StableDiffusionControlNetInpaintPipeline(
self,
prompt,
image,
mask_image,
height,
width,
callback_steps,
output_type,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
@@ -693,6 +695,7 @@ class StableDiffusionControlNetInpaintPipeline(
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
):
if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -736,6 +739,19 @@ class StableDiffusionControlNetInpaintPipeline(
f" {negative_prompt_embeds.shape}."
)
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
)
if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetModel):
@@ -862,7 +878,6 @@ class StableDiffusionControlNetInpaintPipeline(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
def prepare_control_image(
self,
image,
@@ -872,10 +887,14 @@ class StableDiffusionControlNetInpaintPipeline(
num_images_per_prompt,
device,
dtype,
crops_coords,
resize_mode,
do_classifier_free_guidance=False,
guess_mode=False,
):
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image = self.control_image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
@@ -1074,6 +1093,7 @@ class StableDiffusionControlNetInpaintPipeline(
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
@@ -1130,6 +1150,12 @@ class StableDiffusionControlNetInpaintPipeline(
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
padding_mask_crop (`int`, *optional*, defaults to `None`):
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
`padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
and contain information inreleant for inpainging, such as background.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
@@ -1240,9 +1266,11 @@ class StableDiffusionControlNetInpaintPipeline(
self.check_inputs(
prompt,
control_image,
mask_image,
height,
width,
callback_steps,
output_type,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
@@ -1250,6 +1278,7 @@ class StableDiffusionControlNetInpaintPipeline(
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
self._guidance_scale = guidance_scale
@@ -1264,6 +1293,14 @@ class StableDiffusionControlNetInpaintPipeline(
else:
batch_size = prompt_embeds.shape[0]
if padding_mask_crop is not None:
height, width = self.image_processor.get_default_height_width(image, height, width)
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
resize_mode = "fill"
else:
crops_coords = None
resize_mode = "default"
device = self._execution_device
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
@@ -1315,6 +1352,8 @@ class StableDiffusionControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1330,6 +1369,8 @@ class StableDiffusionControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1341,10 +1382,15 @@ class StableDiffusionControlNetInpaintPipeline(
assert False
# 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
init_image = self.image_processor.preprocess(image, height=height, width=width)
original_image = image
init_image = self.image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
)
init_image = init_image.to(dtype=torch.float32)
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
mask = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
masked_image = init_image * (mask < 0.5)
_, _, height, width = init_image.shape
@@ -1534,6 +1580,9 @@ class StableDiffusionControlNetInpaintPipeline(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
if padding_mask_crop is not None:
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
# Offload all models
self.maybe_free_model_hooks()
@@ -557,9 +557,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
prompt,
prompt_2,
image,
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
@@ -570,6 +572,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -632,6 +635,19 @@ class StableDiffusionXLControlNetInpaintPipeline(
f" {negative_prompt_embeds.shape}."
)
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
)
if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
@@ -745,10 +761,14 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_images_per_prompt,
device,
dtype,
crops_coords,
resize_mode,
do_classifier_free_guidance=False,
guess_mode=False,
):
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image = self.control_image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
@@ -1066,6 +1086,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
] = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
strength: float = 0.9999,
num_inference_steps: int = 50,
denoising_start: Optional[float] = None,
@@ -1121,6 +1142,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
padding_mask_crop (`int`, *optional*, defaults to `None`):
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
`padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
and contain information inreleant for inpainging, such as background.
strength (`float`, *optional*, defaults to 0.9999):
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
@@ -1290,9 +1317,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
prompt,
prompt_2,
control_image,
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
@@ -1303,6 +1332,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
self._guidance_scale = guidance_scale
@@ -1370,7 +1400,18 @@ class StableDiffusionXLControlNetInpaintPipeline(
# 5. Preprocess mask and image - resizes image and mask w.r.t height and width
# 5.1 Prepare init image
init_image = self.image_processor.preprocess(image, height=height, width=width)
if padding_mask_crop is not None:
height, width = self.image_processor.get_default_height_width(image, height, width)
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
resize_mode = "fill"
else:
crops_coords = None
resize_mode = "default"
original_image = image
init_image = self.image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
)
init_image = init_image.to(dtype=torch.float32)
# 5.2 Prepare control images
@@ -1383,6 +1424,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1398,6 +1441,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1409,7 +1454,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
raise ValueError(f"{controlnet.__class__} is not supported.")
# 5.3 Prepare mask
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
mask = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
masked_image = init_image * (mask < 0.5)
_, _, height, width = init_image.shape
@@ -1684,6 +1731,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
image = self.image_processor.postprocess(image, output_type=output_type)
if padding_mask_crop is not None:
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
# Offload all models
self.maybe_free_model_hooks()
@@ -33,7 +33,7 @@ from ....models.embeddings import (
)
from ....models.resnet import ResnetBlockCondNorm2D
from ....models.transformer_2d import Transformer2DModel
from ....models.unet_2d_condition import UNet2DConditionOutput
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ....utils.torch_utils import apply_freeu
@@ -268,7 +268,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
return objs
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
@@ -1096,7 +1096,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
@@ -1112,8 +1112,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
@@ -1786,7 +1786,7 @@ class CrossAttnDownBlockFlat(nn.Module):
return hidden_states, output_states
# Copied from diffusers.models.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class UpBlockFlat(nn.Module):
def __init__(
self,
@@ -1897,7 +1897,7 @@ class UpBlockFlat(nn.Module):
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class CrossAttnUpBlockFlat(nn.Module):
def __init__(
self,
@@ -2071,7 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlat(nn.Module):
"""
A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
@@ -2227,7 +2227,7 @@ class UNetMidBlockFlat(nn.Module):
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
def __init__(
self,
@@ -2374,7 +2374,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
def __init__(
self,
+7 -2
View File
@@ -351,7 +351,7 @@ def get_class_obj_and_candidates(
def _get_pipeline_class(
class_obj,
config,
config=None,
load_connected_pipeline=False,
custom_pipeline=None,
repo_id=None,
@@ -389,7 +389,12 @@ def _get_pipeline_class(
return class_obj
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
class_name = config["_class_name"]
class_name = class_name or config["_class_name"]
if not class_name:
raise ValueError(
"The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
)
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
pipeline_cls = getattr(diffusers_module, class_name)
@@ -21,6 +21,7 @@ from typing import Dict, Optional, Union
import requests
import torch
import yaml
from transformers import (
AutoFeatureExtractor,
BertTokenizerFast,
@@ -50,8 +51,7 @@ from ...schedulers import (
PNDMScheduler,
UnCLIPScheduler,
)
from ...utils import is_accelerate_available, is_omegaconf_available, logging
from ...utils.import_utils import BACKENDS_MAPPING
from ...utils import is_accelerate_available, logging
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
from ..pipeline_utils import DiffusionPipeline
@@ -237,51 +237,54 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
Creates a config for the diffusers based on the config of the LDM model.
"""
if controlnet:
unet_params = original_config.model.params.control_stage_config.params
unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
else:
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
unet_params = original_config.model.params.unet_config.params
if (
"unet_config" in original_config["model"]["params"]
and original_config["model"]["params"]["unet_config"] is not None
):
unet_params = original_config["model"]["params"]["unet_config"]["params"]
else:
unet_params = original_config.model.params.network_config.params
unet_params = original_config["model"]["params"]["network_config"]["params"]
vae_params = original_config.model.params.first_stage_config.params.ddconfig
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
if unet_params.transformer_depth is not None:
if unet_params["transformer_depth"] is not None:
transformer_layers_per_block = (
unet_params.transformer_depth
if isinstance(unet_params.transformer_depth, int)
else list(unet_params.transformer_depth)
unet_params["transformer_depth"]
if isinstance(unet_params["transformer_depth"], int)
else list(unet_params["transformer_depth"])
)
else:
transformer_layers_per_block = 1
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
class_embed_type = None
addition_embed_type = None
@@ -289,13 +292,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
projection_class_embeddings_input_dim = None
context_dim = None
if unet_params.context_dim is not None:
if unet_params["context_dim"] is not None:
context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
unet_params["context_dim"]
if isinstance(unet_params["context_dim"], int)
else unet_params["context_dim"][0]
)
if "num_classes" in unet_params:
if unet_params.num_classes == "sequential":
if unet_params["num_classes"] == "sequential":
if context_dim in [2048, 1280]:
# SDXL
addition_embed_type = "text_time"
@@ -303,14 +308,14 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
else:
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"in_channels": unet_params["in_channels"],
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": context_dim,
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
@@ -322,15 +327,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
}
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions
config["only_cross_attention"] = unet_params["disable_self_attentions"]
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
config["num_class_embeds"] = unet_params.num_classes
if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
config["num_class_embeds"] = unet_params["num_classes"]
if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
config["conditioning_channels"] = unet_params["hint_channels"]
else:
config["out_channels"] = unet_params.out_channels
config["out_channels"] = unet_params["out_channels"]
config["up_block_types"] = tuple(up_block_types)
return config
@@ -340,38 +345,38 @@ def create_vae_diffusers_config(original_config, image_size: int):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
vae_params = original_config.model.params.first_stage_config.params.ddconfig
_ = original_config.model.params.first_stage_config.params.embed_dim
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = {
"sample_size": image_size,
"in_channels": vae_params.in_channels,
"out_channels": vae_params.out_ch,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels,
"layers_per_block": vae_params.num_res_blocks,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
}
return config
def create_diffusers_schedular(original_config):
schedular = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps,
beta_start=original_config.model.params.linear_start,
beta_end=original_config.model.params.linear_end,
num_train_timesteps=original_config["model"]["params"]["timesteps"],
beta_start=original_config["model"]["params"]["linear_start"],
beta_end=original_config["model"]["params"]["linear_end"],
beta_schedule="scaled_linear",
)
return schedular
def create_ldm_bert_config(original_config):
bert_params = original_config.model.params.cond_stage_config.params
bert_params = original_config["model"]["params"]["cond_stage_config"]["params"]
config = LDMBertConfig(
d_model=bert_params.n_embed,
encoder_layers=bert_params.n_layer,
@@ -1006,9 +1011,9 @@ def stable_unclip_image_encoder(original_config, local_files_only=False):
encoders.
"""
image_embedder_config = original_config.model.params.embedder_config
image_embedder_config = original_config["model"]["params"]["embedder_config"]
sd_clip_image_embedder_class = image_embedder_config.target
sd_clip_image_embedder_class = image_embedder_config["target"]
sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
if sd_clip_image_embedder_class == "ClipImageEmbedder":
@@ -1047,8 +1052,8 @@ def stable_unclip_image_noising_components(
If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
"""
noise_aug_config = original_config.model.params.noise_aug_config
noise_aug_class = noise_aug_config.target
noise_aug_config = original_config["model"]["params"]["noise_aug_config"]
noise_aug_class = noise_aug_config["target"]
noise_aug_class = noise_aug_class.split(".")[-1]
if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
@@ -1245,11 +1250,6 @@ def download_from_original_stable_diffusion_ckpt(
if prediction_type == "v-prediction":
prediction_type = "v_prediction"
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
if isinstance(checkpoint_path_or_dict, str):
if from_safetensors:
from safetensors.torch import load_file as safe_load
@@ -1317,19 +1317,22 @@ def download_from_original_stable_diffusion_ckpt(
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()
original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)
# Convert the text model.
if (
model_type is None
and "cond_stage_config" in original_config.model.params
and original_config.model.params.cond_stage_config is not None
and "cond_stage_config" in original_config["model"]["params"]
and original_config["model"]["params"]["cond_stage_config"] is not None
):
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
elif model_type is None and original_config.model.params.network_config is not None:
if original_config.model.params.network_config.params.context_dim == 2048:
elif model_type is None and original_config["model"]["params"]["network_config"] is not None:
if original_config["model"]["params"]["network_config"]["params"]["context_dim"] == 2048:
model_type = "SDXL"
else:
model_type = "SDXL-Refiner"
@@ -1354,7 +1357,7 @@ def download_from_original_stable_diffusion_ckpt(
elif num_in_channels is None:
num_in_channels = 4
if "unet_config" in original_config.model.params:
if "unet_config" in original_config["model"]["params"]:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if (
@@ -1375,13 +1378,16 @@ def download_from_original_stable_diffusion_ckpt(
if image_size is None:
image_size = 512
if controlnet is None and "control_stage_config" in original_config.model.params:
if controlnet is None and "control_stage_config" in original_config["model"]["params"]:
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
controlnet = convert_controlnet_checkpoint(
checkpoint, original_config, path, image_size, upcast_attention, extract_ema
)
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
if "timesteps" in original_config["model"]["params"]:
num_train_timesteps = original_config["model"]["params"]["timesteps"]
else:
num_train_timesteps = 1000
if model_type in ["SDXL", "SDXL-Refiner"]:
scheduler_dict = {
@@ -1400,8 +1406,15 @@ def download_from_original_stable_diffusion_ckpt(
scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
scheduler_type = "euler"
else:
beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
if "linear_start" in original_config["model"]["params"]:
beta_start = original_config["model"]["params"]["linear_start"]
else:
beta_start = 0.02
if "linear_end" in original_config["model"]["params"]:
beta_end = original_config["model"]["params"]["linear_end"]
else:
beta_end = 0.085
scheduler = DDIMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
@@ -1435,7 +1448,7 @@ def download_from_original_stable_diffusion_ckpt(
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
if pipeline_class == StableDiffusionUpscalePipeline:
image_size = original_config.model.params.unet_config.params.image_size
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
@@ -1464,10 +1477,10 @@ def download_from_original_stable_diffusion_ckpt(
if (
"model" in original_config
and "params" in original_config.model
and "scale_factor" in original_config.model.params
and "params" in original_config["model"]
and "scale_factor" in original_config["model"]["params"]
):
vae_scaling_factor = original_config.model.params.scale_factor
vae_scaling_factor = original_config["model"]["params"]["scale_factor"]
else:
vae_scaling_factor = 0.18215 # default SD scaling factor
@@ -1803,11 +1816,6 @@ def download_controlnet_from_original_ckpt(
use_linear_projection: Optional[bool] = None,
cross_attention_dim: Optional[bool] = None,
) -> DiffusionPipeline:
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
if from_safetensors:
from safetensors import safe_open
@@ -1827,12 +1835,12 @@ def download_controlnet_from_original_ckpt(
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if "control_stage_config" not in original_config.model.params:
if "control_stage_config" not in original_config["model"]["params"]:
raise ValueError("`control_stage_config` not present in original config")
controlnet = convert_controlnet_checkpoint(
@@ -642,6 +642,7 @@ class StableDiffusionInpaintPipeline(
width,
strength,
callback_steps,
output_type,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
@@ -693,11 +694,6 @@ class StableDiffusionInpaintPipeline(
f" {negative_prompt_embeds.shape}."
)
if padding_mask_crop is not None:
if self.unet.config.in_channels != 4:
raise ValueError(
f"The UNet should have 4 input channels for inpainting mask crop, but has"
f" {self.unet.config.in_channels} input channels."
)
if not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
@@ -707,6 +703,8 @@ class StableDiffusionInpaintPipeline(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
)
if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
def prepare_latents(
self,
@@ -1166,6 +1164,7 @@ class StableDiffusionInpaintPipeline(
width,
strength,
callback_steps,
output_type,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
@@ -744,15 +744,19 @@ class StableDiffusionXLInpaintPipeline(
self,
prompt,
prompt_2,
image,
mask_image,
height,
width,
strength,
callback_steps,
output_type,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -810,6 +814,18 @@ class StableDiffusionXLInpaintPipeline(
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
f"The mask image should be a PIL image when inpainting mask crop, but is of type"
f" {type(mask_image)}."
)
if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
def prepare_latents(
self,
@@ -1225,6 +1241,7 @@ class StableDiffusionXLInpaintPipeline(
masked_image_latents: torch.FloatTensor = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
strength: float = 0.9999,
num_inference_steps: int = 50,
timesteps: List[int] = None,
@@ -1287,6 +1304,12 @@ class StableDiffusionXLInpaintPipeline(
Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
padding_mask_crop (`int`, *optional*, defaults to `None`):
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
`padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
and contain information inreleant for inpainging, such as background.
strength (`float`, *optional*, defaults to 0.9999):
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
@@ -1449,15 +1472,19 @@ class StableDiffusionXLInpaintPipeline(
self.check_inputs(
prompt,
prompt_2,
image,
mask_image,
height,
width,
strength,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
self._guidance_scale = guidance_scale
@@ -1527,10 +1554,22 @@ class StableDiffusionXLInpaintPipeline(
is_strength_max = strength == 1.0
# 5. Preprocess mask and image
init_image = self.image_processor.preprocess(image, height=height, width=width)
if padding_mask_crop is not None:
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
resize_mode = "fill"
else:
crops_coords = None
resize_mode = "default"
original_image = image
init_image = self.image_processor.preprocess(
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
)
init_image = init_image.to(dtype=torch.float32)
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
mask = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
if masked_image_latents is not None:
masked_image = masked_image_latents
@@ -1791,6 +1830,9 @@ class StableDiffusionXLInpaintPipeline(
image = self.image_processor.postprocess(image, output_type=output_type)
if padding_mask_crop is not None:
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
# Offload all models
self.maybe_free_model_hooks()
@@ -858,7 +858,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
)
# 4. Preprocess image
image = self.image_processor.preprocess(image).to(device)
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -40,10 +40,8 @@ def _append_dims(x, target_dims):
return x[(...,) + (None,) * dims_to_append]
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
@@ -52,6 +50,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
return outputs

Some files were not shown because too many files have changed in this diff Show More