Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3e93944bf7 |
@@ -25,7 +25,7 @@ jobs:
|
||||
env:
|
||||
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
|
||||
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# Checkout to correct ref
|
||||
# If workflow dispatch
|
||||
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
@@ -16,7 +16,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
check_dependencies:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
|
||||
@@ -16,7 +16,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
check_flax_dependencies:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
|
||||
@@ -20,7 +20,7 @@ env:
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
|
||||
@@ -29,7 +29,7 @@ env:
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
@@ -49,7 +49,7 @@ jobs:
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
|
||||
@@ -16,7 +16,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
check_torch_dependencies:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
|
||||
@@ -10,7 +10,7 @@ on:
|
||||
|
||||
jobs:
|
||||
find-and-checkout-latest-branch:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
|
||||
steps:
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
|
||||
release:
|
||||
needs: find-and-checkout-latest-branch
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
|
||||
@@ -8,7 +8,7 @@ jobs:
|
||||
close_stale_issues:
|
||||
name: Close Stale Issues
|
||||
if: github.repository == 'huggingface/diffusers'
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
@@ -5,7 +5,7 @@ name: Secret Leaks
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -5,7 +5,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-22.04
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
@@ -3,7 +3,7 @@ import sys
|
||||
|
||||
import pandas as pd
|
||||
from huggingface_hub import hf_hub_download, upload_file
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
from huggingface_hub.utils._errors import EntryNotFoundError
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
@@ -75,8 +75,6 @@
|
||||
title: Outpainting
|
||||
title: Advanced inference
|
||||
- sections:
|
||||
- local: using-diffusers/cogvideox
|
||||
title: CogVideoX
|
||||
- local: using-diffusers/sdxl
|
||||
title: Stable Diffusion XL
|
||||
- local: using-diffusers/sdxl_turbo
|
||||
@@ -131,8 +129,6 @@
|
||||
title: T2I-Adapters
|
||||
- local: training/instructpix2pix
|
||||
title: InstructPix2Pix
|
||||
- local: training/cogvideox
|
||||
title: CogVideoX
|
||||
title: Models
|
||||
- isExpanded: false
|
||||
sections:
|
||||
@@ -246,8 +242,6 @@
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/cogvideox_transformer3d
|
||||
title: CogVideoXTransformer3DModel
|
||||
- local: api/models/cogview3plus_transformer2d
|
||||
title: CogView3PlusTransformer2DModel
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/flux_transformer
|
||||
@@ -326,8 +320,6 @@
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/cogview3
|
||||
title: CogView3
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# CogView3PlusTransformer2DModel
|
||||
|
||||
A Diffusion Transformer model for 2D data from [CogView3Plus](https://github.com/THUDM/CogView3) was introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) by Tsinghua University & ZhipuAI.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import CogView3PlusTransformer2DModel
|
||||
|
||||
vae = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
```
|
||||
|
||||
## CogView3PlusTransformer2DModel
|
||||
|
||||
[[autodoc]] CogView3PlusTransformer2DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -36,10 +36,6 @@ There are two models available that can be used with the text-to-video and video
|
||||
There is one model available that can be used with the image-to-video CogVideoX pipeline:
|
||||
- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.
|
||||
|
||||
There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team):
|
||||
- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`.
|
||||
- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`.
|
||||
|
||||
## Inference
|
||||
|
||||
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
|
||||
@@ -122,12 +118,6 @@ It is also worth noting that torchao quantization is fully compatible with [torc
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CogVideoXFunControlPipeline
|
||||
|
||||
[[autodoc]] CogVideoXFunControlPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CogVideoXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# CogView3Plus
|
||||
|
||||
[CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) from Tsinghua University & ZhipuAI, by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, Jie Tang.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Recent advancements in text-to-image generative systems have been largely driven by diffusion models. However, single-stage text-to-image diffusion models still face challenges, in terms of computational efficiency and the refinement of image details. To tackle the issue, we propose CogView3, an innovative cascaded framework that enhances the performance of text-to-image diffusion. CogView3 is the first model implementing relay diffusion in the realm of text-to-image generation, executing the task by first creating low-resolution images and subsequently applying relay-based super-resolution. This methodology not only results in competitive text-to-image outputs but also greatly reduces both training and inference costs. Our experimental results demonstrate that CogView3 outperforms SDXL, the current state-of-the-art open-source text-to-image diffusion model, by 77.0% in human evaluations, all while requiring only about 1/2 of the inference time. The distilled variant of CogView3 achieves comparable performance while only utilizing 1/10 of the inference time by SDXL.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
|
||||
|
||||
## CogView3PlusPipeline
|
||||
|
||||
[[autodoc]] CogView3PlusPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CogView3PipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cogview3.pipeline_output.CogView3PipelineOutput
|
||||
@@ -53,16 +53,8 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPAGImg2ImgPipeline
|
||||
[[autodoc]] StableDiffusionPAGImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionControlNetPAGPipeline
|
||||
[[autodoc]] StableDiffusionControlNetPAGPipeline
|
||||
|
||||
## StableDiffusionControlNetPAGInpaintPipeline
|
||||
[[autodoc]] StableDiffusionControlNetPAGInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
@@ -52,7 +52,6 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso
|
||||
| sgm_uniform | init with `timestep_spacing="trailing"` |
|
||||
| simple | init with `timestep_spacing="trailing"` |
|
||||
| exponential | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True` |
|
||||
| beta | init with `timestep_spacing="linspace"`, `use_beta_sigmas=True` |
|
||||
|
||||
All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ Optimization orthogonal to parallelization focuses on accelerating single GPU pe
|
||||
The overview of xDiT is shown as follows.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/methods/xdit_overview.png">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/methods/xdit_overview.png">
|
||||
</div>
|
||||
You can install xDiT using the following command:
|
||||
|
||||
@@ -78,36 +78,37 @@ A subset of Diffusers models are supported in xDiT, such as Flux.1, Stable Diffu
|
||||
## Benchmark
|
||||
We tested different models on various machines, and here is some of the benchmark data.
|
||||
|
||||
|
||||
### Flux.1-schnell
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/flux/Flux-2k-L40.png">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/flux/Flux-2k-L40.png">
|
||||
</div>
|
||||
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/flux/Flux-2K-A100.png">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/flux/Flux-2K-A100.png">
|
||||
</div>
|
||||
|
||||
### Stable Diffusion 3
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/sd3/L40-SD3.png">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/sd3/L40-SD3.png">
|
||||
</div>
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/sd3/A100-SD3.png">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/sd3/A100-SD3.png">
|
||||
</div>
|
||||
|
||||
### HunyuanDiT
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/L40-HunyuanDiT.png">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/hunuyuandit/L40-HunyuanDiT.png">
|
||||
</div>
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/V100-HunyuanDiT.png">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/hunuyuandit/A100-HunyuanDiT.png">
|
||||
</div>
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/xDiT/documentation-images/resolve/main/performance/hunuyuandit/T4-HunyuanDiT.png">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/hunuyuandit/T4-HunyuanDiT.png">
|
||||
</div>
|
||||
|
||||
More detailed performance metric can be found on our [github page](https://github.com/xdit-project/xDiT?tab=readme-ov-file#perf).
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
# CogVideoX
|
||||
|
||||
CogVideoX is a text-to-video generation model focused on creating more coherent videos aligned with a prompt. It achieves this using several methods.
|
||||
|
||||
- a 3D variational autoencoder that compresses videos spatially and temporally, improving compression rate and video accuracy.
|
||||
|
||||
- an expert transformer block to help align text and video, and a 3D full attention module for capturing and creating spatially and temporally accurate videos.
|
||||
|
||||
The actual test of the video instruction dimension found that CogVideoX has good effects on consistent theme, dynamic information, consistent background, object information, smooth motion, color, scene, appearance style, and temporal style but cannot achieve good results with human action, spatial relationship, and multiple objects.
|
||||
|
||||
Finetuning with Diffusers can help make up for these poor results.
|
||||
|
||||
## Data Preparation
|
||||
|
||||
The training scripts accepts data in two formats.
|
||||
|
||||
The first format is suited for small-scale training, and the second format uses a CSV format, which is more appropriate for streaming data for large-scale training. In the future, Diffusers will support the `<Video>` tag.
|
||||
|
||||
### Small format
|
||||
|
||||
Two files where one file contains line-separated prompts and another file contains line-separated paths to video data (the path to video files must be relative to the path you pass when specifying `--instance_data_root`). Let's take a look at an example to understand this better!
|
||||
|
||||
Assume you've specified `--instance_data_root` as `/dataset`, and that this directory contains the files: `prompts.txt` and `videos.txt`.
|
||||
|
||||
The `prompts.txt` file should contain line-separated prompts:
|
||||
|
||||
```
|
||||
A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.
|
||||
A black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.
|
||||
...
|
||||
```
|
||||
|
||||
The `videos.txt` file should contain line-separate paths to video files. Note that the path should be _relative_ to the `--instance_data_root` directory.
|
||||
|
||||
```
|
||||
videos/00000.mp4
|
||||
videos/00001.mp4
|
||||
...
|
||||
```
|
||||
|
||||
Overall, this is how your dataset would look like if you ran the `tree` command on the dataset root directory:
|
||||
|
||||
```
|
||||
/dataset
|
||||
├── prompts.txt
|
||||
├── videos.txt
|
||||
├── videos
|
||||
├── videos/00000.mp4
|
||||
├── videos/00001.mp4
|
||||
├── ...
|
||||
```
|
||||
|
||||
When using this format, the `--caption_column` must be `prompts.txt` and `--video_column` must be `videos.txt`.
|
||||
|
||||
### Stream format
|
||||
|
||||
You could use a single CSV file. For the sake of this example, assume you have a `metadata.csv` file. The expected format is:
|
||||
|
||||
```
|
||||
<CAPTION_COLUMN>,<PATH_TO_VIDEO_COLUMN>
|
||||
"""A black and white animated sequence featuring a rabbit, named Rabbity Ribfried, and an anthropomorphic goat in a musical, playful environment, showcasing their evolving interaction.""","""00000.mp4"""
|
||||
"""A black and white animated sequence on a ship's deck features a bulldog character, named Bully Bulldoger, showcasing exaggerated facial expressions and body language. The character progresses from confident to focused, then to strained and distressed, displaying a range of emotions as it navigates challenges. The ship's interior remains static in the background, with minimalistic details such as a bell and open door. The character's dynamic movements and changing expressions drive the narrative, with no camera movement to distract from its evolving reactions and physical gestures.""","""00001.mp4"""
|
||||
...
|
||||
```
|
||||
|
||||
In this case, the `--instance_data_root` should be the location where the videos are stored and `--dataset_name` should be either a path to local folder or a [`~datasets.load_dataset`] compatible dataset hosted on the Hub. Assuming you have videos of Minecraft gameplay at `https://huggingface.co/datasets/my-awesome-username/minecraft-videos`, you would have to specify `my-awesome-username/minecraft-videos`.
|
||||
|
||||
When using this format, the `--caption_column` must be `<CAPTION_COLUMN>` and `--video_column` must be `<PATH_TO_VIDEO_COLUMN>`.
|
||||
|
||||
You are not strictly restricted to the CSV format. Any format works as long as the `load_dataset` method supports the file format to load a basic `<PATH_TO_VIDEO_COLUMN>` and `<CAPTION_COLUMN>`. The reason for going through these dataset organization gymnastics for loading video data is because `load_dataset` does not fully support all kinds of video formats.
|
||||
|
||||
> [!NOTE]
|
||||
> CogVideoX works best with long and descriptive LLM-augmented prompts for video generation. We recommend pre-processing your videos by first generating a summary using a VLM and then augmenting the prompts with an LLM. To generate the above captions, we use [MiniCPM-V-26](https://huggingface.co/openbmb/MiniCPM-V-2_6) and [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct). A very barebones and no-frills example for this is available [here](https://gist.github.com/a-r-r-o-w/4dee20250e82f4e44690a02351324a4a). The official recommendation for augmenting prompts is [ChatGLM](https://huggingface.co/THUDM?search_models=chatglm) and a length of 50-100 words is considered good.
|
||||
|
||||
>![NOTE]
|
||||
> It is expected that your dataset is already pre-processed. If not, some basic pre-processing can be done by playing with the following parameters:
|
||||
> `--height`, `--width`, `--fps`, `--max_num_frames`, `--skip_frames_start` and `--skip_frames_end`.
|
||||
> Presently, all videos in your dataset should contain the same number of video frames when using a training batch size > 1.
|
||||
|
||||
<!-- TODO: Implement frame packing in future to address above issue. -->
|
||||
|
||||
## Training
|
||||
|
||||
You need to setup your development environment by installing the necessary requirements. The following packages are required:
|
||||
- Torch 2.0 or above based on the training features you are utilizing (might require latest or nightly versions for quantized/deepspeed training)
|
||||
- `pip install diffusers transformers accelerate peft huggingface_hub` for all things modeling and training related
|
||||
- `pip install datasets decord` for loading video training data
|
||||
- `pip install bitsandbytes` for using 8-bit Adam or AdamW optimizers for memory-optimized training
|
||||
- `pip install wandb` optionally for monitoring training logs
|
||||
- `pip install deepspeed` optionally for [DeepSpeed](https://github.com/microsoft/DeepSpeed) training
|
||||
- `pip install prodigyopt` optionally if you would like to use the Prodigy optimizer for training
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
Before running the script, make sure you install the library from source:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
|
||||
|
||||
Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
|
||||
|
||||
- PyTorch
|
||||
|
||||
```bash
|
||||
cd examples/cogvideo
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if you use torch.compile, there can be dramatic speedups. The PEFT library is used as a backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
If you would like to push your model to the Hub after training is completed with a neat model card, make sure you're logged in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
|
||||
# Alternatively, you could upload your model manually using:
|
||||
# huggingface-cli upload my-cool-account-name/my-cool-lora-name /path/to/awesome/lora
|
||||
```
|
||||
|
||||
Make sure your data is prepared as described in [Data Preparation](#data-preparation). When ready, you can begin training!
|
||||
|
||||
Assuming you are training on 50 videos of a similar concept, we have found 1500-2000 steps to work well. The official recommendation, however, is 100 videos with a total of 4000 steps. Assuming you are training on a single GPU with a `--train_batch_size` of `1`:
|
||||
- 1500 steps on 50 videos would correspond to `30` training epochs
|
||||
- 4000 steps on 100 videos would correspond to `40` training epochs
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
GPU_IDS="0"
|
||||
|
||||
accelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \
|
||||
--pretrained_model_name_or_path THUDM/CogVideoX-2b \
|
||||
--cache_dir <CACHE_DIR> \
|
||||
--instance_data_root <PATH_TO_WHERE_VIDEO_FILES_ARE_STORED> \
|
||||
--dataset_name my-awesome-name/my-awesome-dataset \
|
||||
--caption_column <CAPTION_COLUMN> \
|
||||
--video_column <PATH_TO_VIDEO_COLUMN> \
|
||||
--id_token <ID_TOKEN> \
|
||||
--validation_prompt "<ID_TOKEN> Spiderman swinging over buildings:::A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance" \
|
||||
--validation_prompt_separator ::: \
|
||||
--num_validation_videos 1 \
|
||||
--validation_epochs 10 \
|
||||
--seed 42 \
|
||||
--rank 64 \
|
||||
--lora_alpha 64 \
|
||||
--mixed_precision fp16 \
|
||||
--output_dir /raid/aryan/cogvideox-lora \
|
||||
--height 480 --width 720 --fps 8 --max_num_frames 49 --skip_frames_start 0 --skip_frames_end 0 \
|
||||
--train_batch_size 1 \
|
||||
--num_train_epochs 30 \
|
||||
--checkpointing_steps 1000 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--learning_rate 1e-3 \
|
||||
--lr_scheduler cosine_with_restarts \
|
||||
--lr_warmup_steps 200 \
|
||||
--lr_num_cycles 1 \
|
||||
--enable_slicing \
|
||||
--enable_tiling \
|
||||
--optimizer Adam \
|
||||
--adam_beta1 0.9 \
|
||||
--adam_beta2 0.95 \
|
||||
--max_grad_norm 1.0 \
|
||||
--report_to wandb
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
* `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Setting the `<ID_TOKEN>` is not necessary. From some limited experimentation, we found it works better (as it resembles [Dreambooth](https://huggingface.co/docs/diffusers/en/training/dreambooth) training) than without. When provided, the `<ID_TOKEN>` is appended to the beginning of each prompt. So, if your `<ID_TOKEN>` was `"DISNEY"` and your prompt was `"Spiderman swinging over buildings"`, the effective prompt used in training would be `"DISNEY Spiderman swinging over buildings"`. When not provided, you would either be training without any additional token or could augment your dataset to apply the token where you wish before starting the training.
|
||||
|
||||
> [!NOTE]
|
||||
> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The following settings have been tested at the time of adding CogVideoX LoRA training support:
|
||||
> - Our testing was primarily done on CogVideoX-2b. We will work on CogVideoX-5b and CogVideoX-5b-I2V soon
|
||||
> - One dataset comprised of 70 training videos of resolutions `200 x 480 x 720` (F x H x W). From this, by using frame skipping in data preprocessing, we created two smaller 49-frame and 16-frame datasets for faster experimentation and because the maximum limit recommended by the CogVideoX team is 49 frames. Out of the 70 videos, we created three groups of 10, 25 and 50 videos. All videos were similar in nature of the concept being trained.
|
||||
> - 25+ videos worked best for training new concepts and styles.
|
||||
> - We found that it is better to train with an identifier token that can be specified as `--id_token`. This is similar to Dreambooth-like training but normal finetuning without such a token works too.
|
||||
> - Trained concept seemed to work decently well when combined with completely unrelated prompts. We expect even better results if CogVideoX-5B is finetuned.
|
||||
> - The original repository uses a `lora_alpha` of `1`. We found this not suitable in many runs, possibly due to difference in modeling backends and training settings. Our recommendation is to set to the `lora_alpha` to either `rank` or `rank // 2`.
|
||||
> - If you're training on data whose captions generate bad results with the original model, a `rank` of 64 and above is good and also the recommendation by the team behind CogVideoX. If the generations are already moderately good on your training captions, a `rank` of 16/32 should work. We found that setting the rank too low, say `4`, is not ideal and doesn't produce promising results.
|
||||
> - The authors of CogVideoX recommend 4000 training steps and 100 training videos overall to achieve the best result. While that might yield the best results, we found from our limited experimentation that 2000 steps and 25 videos could also be sufficient.
|
||||
> - When using the Prodigy opitimizer for training, one can follow the recommendations from [this](https://huggingface.co/blog/sdxl_lora_advanced_script) blog. Prodigy tends to overfit quickly. From my very limited testing, I found a learning rate of `0.5` to be suitable in addition to `--prodigy_use_bias_correction`, `prodigy_safeguard_warmup` and `--prodigy_decouple`.
|
||||
> - The recommended learning rate by the CogVideoX authors and from our experimentation with Adam/AdamW is between `1e-3` and `1e-4` for a dataset of 25+ videos.
|
||||
>
|
||||
> Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.
|
||||
|
||||
<!-- TODO: Test finetuning with CogVideoX-5b and CogVideoX-5b-I2V and update scripts accordingly -->
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a lora model, the inference can be done simply loading the lora weights into the `CogVideoXPipeline`.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
|
||||
# pipe.load_lora_weights("/path/to/lora/weights", adapter_name="cogvideox-lora") # Or,
|
||||
pipe.load_lora_weights("my-awesome-hf-username/my-awesome-lora-name", adapter_name="cogvideox-lora") # If loading from the HF Hub
|
||||
pipe.to("cuda")
|
||||
|
||||
# Assuming lora_alpha=32 and rank=64 for training. If different, set accordingly
|
||||
pipe.set_adapters(["cogvideox-lora"], [32 / 64])
|
||||
|
||||
prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
|
||||
frames = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0]
|
||||
export_to_video(frames, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
|
||||
## Reduce memory usage
|
||||
|
||||
While testing using the diffusers library, all optimizations included in the diffusers library were enabled. This
|
||||
scheme has not been tested for actual memory usage on devices outside of **NVIDIA A100 / H100** architectures.
|
||||
Generally, this scheme can be adapted to all **NVIDIA Ampere architecture** and above devices. If optimizations are
|
||||
disabled, memory consumption will multiply, with peak memory usage being about 3 times the value in the table.
|
||||
However, speed will increase by about 3-4 times. You can selectively disable some optimizations, including:
|
||||
|
||||
```
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.vae.enable_tiling()
|
||||
```
|
||||
|
||||
+ For multi-GPU inference, the `enable_sequential_cpu_offload()` optimization needs to be disabled.
|
||||
+ Using INT8 models will slow down inference, which is done to accommodate lower-memory GPUs while maintaining minimal
|
||||
video quality loss, though inference speed will significantly decrease.
|
||||
+ The CogVideoX-2B model was trained in `FP16` precision, and all CogVideoX-5B models were trained in `BF16` precision.
|
||||
We recommend using the precision in which the model was trained for inference.
|
||||
+ [PytorchAO](https://github.com/pytorch/ao) and [Optimum-quanto](https://github.com/huggingface/optimum-quanto/) can be
|
||||
used to quantize the text encoder, transformer, and VAE modules to reduce the memory requirements of CogVideoX. This
|
||||
allows the model to run on free T4 Colabs or GPUs with smaller memory! Also, note that TorchAO quantization is fully
|
||||
compatible with `torch.compile`, which can significantly improve inference speed. FP8 precision must be used on
|
||||
devices with NVIDIA H100 and above, requiring source installation of `torch`, `torchao`, `diffusers`, and `accelerate`
|
||||
Python packages. CUDA 12.4 is recommended.
|
||||
+ The inference speed tests also used the above memory optimization scheme. Without memory optimization, inference speed
|
||||
increases by about 10%. Only the `diffusers` version of the model supports quantization.
|
||||
+ The model only supports English input; other languages can be translated into English for use via large model
|
||||
refinement.
|
||||
+ The memory usage of model fine-tuning is tested in an `8 * H100` environment, and the program automatically
|
||||
uses `Zero 2` optimization. If a specific number of GPUs is marked in the table, that number or more GPUs must be used
|
||||
for fine-tuning.
|
||||
|
||||
|
||||
| **Attribute** | **CogVideoX-2B** | **CogVideoX-5B** |
|
||||
| ------------------------------------ | ---------------------------------------------------------------------- | ---------------------------------------------------------------------- |
|
||||
| **Model Name** | CogVideoX-2B | CogVideoX-5B |
|
||||
| **Inference Precision** | FP16* (Recommended), BF16, FP32, FP8*, INT8, Not supported INT4 | BF16 (Recommended), FP16, FP32, FP8*, INT8, Not supported INT4 |
|
||||
| **Single GPU Inference VRAM** | FP16: Using diffusers 12.5GB* INT8: Using diffusers with torchao 7.8GB* | BF16: Using diffusers 20.7GB* INT8: Using diffusers with torchao 11.4GB* |
|
||||
| **Multi GPU Inference VRAM** | FP16: Using diffusers 10GB* | BF16: Using diffusers 15GB* |
|
||||
| **Inference Speed** | Single A100: ~90 seconds, Single H100: ~45 seconds | Single A100: ~180 seconds, Single H100: ~90 seconds |
|
||||
| **Fine-tuning Precision** | FP16 | BF16 |
|
||||
| **Fine-tuning VRAM Consumption** | 47 GB (bs=1, LORA) 61 GB (bs=2, LORA) 62GB (bs=1, SFT) | 63 GB (bs=1, LORA) 80 GB (bs=2, LORA) 75GB (bs=1, SFT) |
|
||||
@@ -177,7 +177,7 @@ transformer = FluxTransformer2DModel.from_pretrained(
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. You can also try `print(transformer.hf_device_map)` to see how the transformer model is sharded across devices.
|
||||
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models.
|
||||
|
||||
Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.
|
||||
|
||||
|
||||
@@ -75,12 +75,6 @@ image
|
||||
|
||||

|
||||
|
||||
<Tip>
|
||||
|
||||
By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Merge adapters
|
||||
|
||||
You can also merge different adapter checkpoints for inference to blend their styles together.
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
# CogVideoX
|
||||
|
||||
CogVideoX is a text-to-video generation model focused on creating more coherent videos aligned with a prompt. It achieves this using several methods.
|
||||
|
||||
- a 3D variational autoencoder that compresses videos spatially and temporally, improving compression rate and video accuracy.
|
||||
|
||||
- an expert transformer block to help align text and video, and a 3D full attention module for capturing and creating spatially and temporally accurate videos.
|
||||
|
||||
|
||||
|
||||
## Load model checkpoints
|
||||
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
|
||||
|
||||
|
||||
```py
|
||||
from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline
|
||||
pipe = CogVideoXPipeline.from_pretrained(
|
||||
"THUDM/CogVideoX-2b",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
||||
"THUDM/CogVideoX-5b-I2V",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
## Text-to-Video
|
||||
For text-to-video, pass a text prompt. By default, CogVideoX generates a 720x480 video for the best results.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
prompt = "An elderly gentleman, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea."
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained(
|
||||
"THUDM/CogVideoX-5b",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.vae.enable_tiling()
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=1,
|
||||
num_inference_steps=50,
|
||||
num_frames=49,
|
||||
guidance_scale=6,
|
||||
generator=torch.Generator(device="cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
|
||||
```
|
||||
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_out.gif" alt="generated image of an astronaut in a jungle"/>
|
||||
</div>
|
||||
|
||||
|
||||
## Image-to-Video
|
||||
|
||||
|
||||
You'll use the [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V) checkpoint for this guide.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import CogVideoXImageToVideoPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
|
||||
prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
|
||||
image = load_image(image="cogvideox_rocket.png")
|
||||
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
||||
"THUDM/CogVideoX-5b-I2V",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
pipe.vae.enable_tiling()
|
||||
pipe.vae.enable_slicing()
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
num_videos_per_prompt=1,
|
||||
num_inference_steps=50,
|
||||
num_frames=49,
|
||||
guidance_scale=6,
|
||||
generator=torch.Generator(device="cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_rocket.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_outrocket.gif"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated video</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -23,59 +23,6 @@ This guide will show you how to generate videos, how to configure video model pa
|
||||
|
||||
[Stable Video Diffusions (SVD)](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid), [I2VGen-XL](https://huggingface.co/ali-vilab/i2vgen-xl/), [AnimateDiff](https://huggingface.co/guoyww/animatediff), and [ModelScopeT2V](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b) are popular models used for video diffusion. Each model is distinct. For example, AnimateDiff inserts a motion modeling module into a frozen text-to-image model to generate personalized animated images, whereas SVD is entirely pretrained from scratch with a three-stage training process to generate short high-quality videos.
|
||||
|
||||
[CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) is another popular video generation model. The model is a multidimensional transformer that integrates text, time, and space. It employs full attention in the attention module and includes an expert block at the layer level to spatially align text and video.
|
||||
|
||||
### CogVideoX
|
||||
|
||||
[CogVideoX](../api/pipelines/cogvideox) uses a 3D Variational Autoencoder (VAE) to compress videos along the spatial and temporal dimensions.
|
||||
|
||||
Begin by loading the [`CogVideoXPipeline`] and passing an initial text or image to generate a video.
|
||||
<Tip>
|
||||
|
||||
CogVideoX is available for image-to-video and text-to-video. [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V) uses the [`CogVideoXImageToVideoPipeline`] for image-to-video. [THUDM/CogVideoX-5b](https://huggingface.co/THUDM/CogVideoX-5b) and [THUDM/CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b) are available for text-to-video with the [`CogVideoXPipeline`].
|
||||
|
||||
</Tip>
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import CogVideoXImageToVideoPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
|
||||
prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
|
||||
image = load_image(image="cogvideox_rocket.png")
|
||||
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
||||
"THUDM/CogVideoX-5b-I2V",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
pipe.vae.enable_tiling()
|
||||
pipe.vae.enable_slicing()
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
num_videos_per_prompt=1,
|
||||
num_inference_steps=50,
|
||||
num_frames=49,
|
||||
guidance_scale=6,
|
||||
generator=torch.Generator(device="cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_rocket.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cogvideox/cogvideox_outrocket.gif"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated video</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
### Stable Video Diffusion
|
||||
|
||||
[SVD](../api/pipelines/svd) is based on the Stable Diffusion 2.1 model and it is trained on images, then low-resolution videos, and finally a smaller dataset of high-resolution videos. This model generates a short 2-4 second video from an initial image. You can learn more details about model, like micro-conditioning, in the [Stable Video Diffusion](../using-diffusers/svd) guide.
|
||||
|
||||
@@ -1,353 +0,0 @@
|
||||
# Advanced diffusion training examples
|
||||
|
||||
## Train Dreambooth LoRA with Flux.1 Dev
|
||||
> [!TIP]
|
||||
> 💡 This example follows some of the techniques and recommended practices covered in the community derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script).
|
||||
> As many of these are architecture agnostic & generally relevant to fine-tuning of diffusion models we suggest to take a look 🤗
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like flux, stable diffusion given just a few(3~5) images of a subject.
|
||||
|
||||
LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
|
||||
In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
|
||||
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
|
||||
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
|
||||
- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.
|
||||
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
|
||||
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
|
||||
|
||||
The `train_dreambooth_lora_flux_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_flux.py`, with
|
||||
advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
|
||||
[ostris](https://x.com/ostrisai):[ai-toolkit](https://github.com/ostris/ai-toolkit), [bghira](https://github.com/bghira):[SimpleTuner](https://github.com/bghira/SimpleTuner), [Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
|
||||
|
||||
> [!NOTE]
|
||||
> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
|
||||
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/advanced_diffusion_training` folder and run
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
|
||||
|
||||
### Pivotal Tuning (and more)
|
||||
**Training with text encoder(s)**
|
||||
|
||||
Alongside the Transformer, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization
|
||||
available with `train_dreambooth_lora_flux_advanced.py`, in the advanced script **pivotal tuning** is also supported.
|
||||
[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
|
||||
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
|
||||
We then optimize the newly-inserted token embeddings to represent the new concept.
|
||||
|
||||
To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
|
||||
Please keep the following points in mind:
|
||||
|
||||
* Flux uses two text encoders - [CLIP](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder) & [T5](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder_2) , by default `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only.
|
||||
To activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`.
|
||||
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
|
||||
* **pure textual inversion** - to support the full range from pivotal tuning to textual inversion we introduce `--train_transformer_frac` which controls the amount of epochs the transformer LoRA layers are trained. By default, `--train_transformer_frac==1`, to trigger a textual inversion run set `--train_transformer_frac==0`. Values between 0 and 1 are supported as well, and we welcome the community to experiment w/ different settings and share the results!
|
||||
* **token initializer** - similar to the original textual inversion work, you can specify a token of your choosing as the starting point for training. By default, when enabling `--train_text_encoder_ti`, the new inserted tokens are initialized randomly. You can specify a token in `--initializer_token` such that the starting point for the trained embeddings will be the embeddings associated with your chosen `--initializer_token`.
|
||||
|
||||
## Training examples
|
||||
|
||||
Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./3d_icon"
|
||||
snapshot_download(
|
||||
"LinoyTsaban/3d_icon",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
Let's review some of the advanced features we're going to be using for this example:
|
||||
- **custom captions**:
|
||||
To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by
|
||||
```bash
|
||||
pip install datasets
|
||||
```
|
||||
|
||||
Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
|
||||
|
||||
```
|
||||
--dataset_name=./3d_icon
|
||||
--caption_column=prompt
|
||||
```
|
||||
|
||||
You can also load a dataset straight from by specifying it's name in `dataset_name`.
|
||||
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loadin your own caption dataset.
|
||||
|
||||
- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
|
||||
- **pivotal tuning**
|
||||
|
||||
### Example #1: Pivotal tuning
|
||||
**Now, we can launch training:**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
|
||||
export DATASET_NAME="./3d_icon"
|
||||
export OUTPUT_DIR="3d-icon-Flux-LoRA"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux_advanced.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--instance_prompt="3d icon in the style of TOK" \
|
||||
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--caption_column="prompt" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--repeats=1 \
|
||||
--report_to="wandb"\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--optimizer="prodigy"\
|
||||
--train_text_encoder_ti\
|
||||
--train_text_encoder_ti_frac=0.5\
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--rank=8 \
|
||||
--max_train_steps=1000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
### Example #2: Pivotal tuning with T5
|
||||
Now let's try that with T5 as well, so instead of only optimizing the CLIP embeddings associated with newly inserted tokens, we'll optimize
|
||||
the T5 embeddings as well. We can do this by simply adding `--enable_t5_ti` to the previous configuration:
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
|
||||
export DATASET_NAME="./3d_icon"
|
||||
export OUTPUT_DIR="3d-icon-Flux-LoRA"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux_advanced.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--instance_prompt="3d icon in the style of TOK" \
|
||||
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--caption_column="prompt" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--repeats=1 \
|
||||
--report_to="wandb"\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--optimizer="prodigy"\
|
||||
--train_text_encoder_ti\
|
||||
--enable_t5_ti\
|
||||
--train_text_encoder_ti_frac=0.5\
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--rank=8 \
|
||||
--max_train_steps=1000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### Example #3: Textual Inversion
|
||||
To explore a pure textual inversion - i.e. only optimizing the text embeddings w/o training transformer LoRA layers, we
|
||||
can set the value for `--train_transformer_frac` - which is responsible for the percent of epochs in which the transformer is
|
||||
trained. By setting `--train_transformer_frac == 0` and enabling `--train_text_encoder_ti` we trigger a textual inversion train
|
||||
run.
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
|
||||
export DATASET_NAME="./3d_icon"
|
||||
export OUTPUT_DIR="3d-icon-Flux-LoRA"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux_advanced.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--instance_prompt="3d icon in the style of TOK" \
|
||||
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--caption_column="prompt" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--repeats=1 \
|
||||
--report_to="wandb"\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--optimizer="prodigy"\
|
||||
--train_text_encoder_ti\
|
||||
--enable_t5_ti\
|
||||
--train_text_encoder_ti_frac=0.5\
|
||||
--train_transformer_frac=0\
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--rank=8 \
|
||||
--max_train_steps=1000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
### Inference - pivotal tuning
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
1. starting with loading the transformer lora weights
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, upload_file
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
username = "linoyts"
|
||||
repo_id = f"{username}/3d-icon-Flux-LoRA"
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
|
||||
|
||||
|
||||
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
|
||||
```
|
||||
2. now we load the pivotal tuning embeddings
|
||||
💡note that if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder
|
||||
|
||||
```python
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
|
||||
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
|
||||
|
||||
state_dict = load_file(embedding_path)
|
||||
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
|
||||
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
|
||||
# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
|
||||
pipe.load_textual_inversion(state_dict["t5"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
|
||||
```
|
||||
|
||||
3. let's generate images
|
||||
|
||||
```python
|
||||
instance_token = "<s0><s1>"
|
||||
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
|
||||
|
||||
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image.save("llama.png")
|
||||
```
|
||||
|
||||
### Inference - pure textual inversion
|
||||
In this case, we don't load transformer layers as before, since we only optimize the text embeddings. The output of a textual inversion train run is a
|
||||
`.safetensors` file containing the trained embeddings for the new tokens either for the CLIP encoder, or for both encoders (CLIP and T5)
|
||||
|
||||
1. starting with loading the embeddings.
|
||||
💡note that here too, if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, upload_file
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
username = "linoyts"
|
||||
repo_id = f"{username}/3d-icon-Flux-LoRA"
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
|
||||
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
|
||||
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
|
||||
|
||||
state_dict = load_file(embedding_path)
|
||||
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
|
||||
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
|
||||
# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
|
||||
pipe.load_textual_inversion(state_dict["t5"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
|
||||
```
|
||||
2. let's generate images
|
||||
|
||||
```python
|
||||
instance_token = "<s0><s1>"
|
||||
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
|
||||
|
||||
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image.save("llama.png")
|
||||
```
|
||||
|
||||
### Comfy UI / AUTOMATIC1111 Inference
|
||||
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
|
||||
|
||||
**AUTOMATIC1111 / SD.Next** \
|
||||
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
|
||||
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
|
||||
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
|
||||
|
||||
You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
|
||||
|
||||
**ComfyUI** \
|
||||
In ComfyUI we will load a LoRA and a textual embedding at the same time.
|
||||
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
|
||||
@@ -1,8 +0,0 @@
|
||||
accelerate>=0.31.0
|
||||
torchvision
|
||||
transformers>=4.41.2
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft>=0.11.1
|
||||
sentencepiece
|
||||
@@ -1,283 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
|
||||
script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py"
|
||||
|
||||
def test_dreambooth_lora_flux(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_text_encoder_flux(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--train_text_encoder
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
starts_with_expected_prefix = all(
|
||||
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_expected_prefix)
|
||||
|
||||
def test_dreambooth_lora_pivotal_tuning_flux_clip(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--train_text_encoder_ti
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
# make sure embeddings were also saved
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
textual_inversion_state_dict = safetensors.torch.load_file(
|
||||
os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
|
||||
)
|
||||
is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys())
|
||||
self.assertTrue(is_clip)
|
||||
|
||||
# when performing pivotal tuning, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--train_text_encoder_ti
|
||||
--enable_t5_ti
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
# make sure embeddings were also saved
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
textual_inversion_state_dict = safetensors.torch.load_file(
|
||||
os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
|
||||
)
|
||||
is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys())
|
||||
self.assertTrue(is_te)
|
||||
|
||||
# when performing pivotal tuning, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,11 +10,6 @@ In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-de
|
||||
|
||||
At the moment, LoRA finetuning has only been tested for [CogVideoX-2b](https://huggingface.co/THUDM/CogVideoX-2b).
|
||||
|
||||
> [!NOTE]
|
||||
> The scripts for CogVideoX come with limited support and may not be fully compatible with different training techniques. They are not feature-rich either and simply serve as minimal examples of finetuning to take inspiration from and improve.
|
||||
>
|
||||
> A repository containing memory-optimized finetuning scripts with support for multiple resolutions, dataset preparation, captioning, etc. is available [here](https://github.com/a-r-r-o-w/cogvideox-factory), which will be maintained jointly by the CogVideoX and Diffusers team.
|
||||
|
||||
## Data Preparation
|
||||
|
||||
The training scripts accepts data in two formats.
|
||||
@@ -137,8 +132,6 @@ Assuming you are training on 50 videos of a similar concept, we have found 1500-
|
||||
- 1500 steps on 50 videos would correspond to `30` training epochs
|
||||
- 4000 steps on 100 videos would correspond to `40` training epochs
|
||||
|
||||
The following bash script launches training for text-to-video lora.
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
@@ -179,8 +172,6 @@ accelerate launch --gpu_ids $GPU_IDS examples/cogvideo/train_cogvideox_lora.py \
|
||||
--report_to wandb
|
||||
```
|
||||
|
||||
For launching image-to-video finetuning instead, run the `train_cogvideox_image_to_video_lora.py` file instead. Additionally, you will have to pass `--validation_images` as paths to initial images corresponding to `--validation_prompts` for I2V validation to work.
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
* `--report_to wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
@@ -189,7 +180,6 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen
|
||||
|
||||
> [!TIP]
|
||||
> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
|
||||
> You can pass `--video_reshape_mode` video cropping functionality, supporting options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The following settings have been tested at the time of adding CogVideoX LoRA training support:
|
||||
@@ -206,6 +196,8 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen
|
||||
>
|
||||
> Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.
|
||||
|
||||
<!-- TODO: Test finetuning with CogVideoX-5b and CogVideoX-5b-I2V and update scripts accordingly -->
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a lora model, the inference can be done simply loading the lora weights into the `CogVideoXPipeline`.
|
||||
@@ -234,5 +226,3 @@ prompt = (
|
||||
frames = pipe(prompt, guidance_scale=6, use_dynamic_cfg=True).frames[0]
|
||||
export_to_video(frames, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
If you've trained a LoRA for `CogVideoXImageToVideoPipeline` instead, everything in the above example remains the same except you must also pass an image as initial condition for generation.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -21,28 +21,27 @@ import shutil
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as TT
|
||||
import transformers
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from torchvision.transforms.functional import resize
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
|
||||
from diffusers.training_utils import cast_training_params, free_memory
|
||||
from diffusers.training_utils import (
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
)
|
||||
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
@@ -218,12 +217,6 @@ def get_args():
|
||||
default=720,
|
||||
help="All input videos are resized to this width.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video_reshape_mode",
|
||||
type=str,
|
||||
default="center",
|
||||
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
|
||||
)
|
||||
parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
|
||||
parser.add_argument(
|
||||
"--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
|
||||
@@ -423,7 +416,6 @@ class VideoDataset(Dataset):
|
||||
video_column: str = "video",
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
video_reshape_mode: str = "center",
|
||||
fps: int = 8,
|
||||
max_num_frames: int = 49,
|
||||
skip_frames_start: int = 0,
|
||||
@@ -440,7 +432,6 @@ class VideoDataset(Dataset):
|
||||
self.video_column = video_column
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.video_reshape_mode = video_reshape_mode
|
||||
self.fps = fps
|
||||
self.max_num_frames = max_num_frames
|
||||
self.skip_frames_start = skip_frames_start
|
||||
@@ -544,38 +535,6 @@ class VideoDataset(Dataset):
|
||||
|
||||
return instance_prompts, instance_videos
|
||||
|
||||
def _resize_for_rectangle_crop(self, arr):
|
||||
image_size = self.height, self.width
|
||||
reshape_mode = self.video_reshape_mode
|
||||
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
else:
|
||||
arr = resize(
|
||||
arr,
|
||||
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
)
|
||||
|
||||
h, w = arr.shape[2], arr.shape[3]
|
||||
arr = arr.squeeze(0)
|
||||
|
||||
delta_h = h - image_size[0]
|
||||
delta_w = w - image_size[1]
|
||||
|
||||
if reshape_mode == "random" or reshape_mode == "none":
|
||||
top = np.random.randint(0, delta_h + 1)
|
||||
left = np.random.randint(0, delta_w + 1)
|
||||
elif reshape_mode == "center":
|
||||
top, left = delta_h // 2, delta_w // 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
|
||||
return arr
|
||||
|
||||
def _preprocess_data(self):
|
||||
try:
|
||||
import decord
|
||||
@@ -586,14 +545,15 @@ class VideoDataset(Dataset):
|
||||
|
||||
decord.bridge.set_bridge("torch")
|
||||
|
||||
progress_dataset_bar = tqdm(
|
||||
range(0, len(self.instance_video_paths)),
|
||||
desc="Loading progress resize and crop videos",
|
||||
)
|
||||
videos = []
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
|
||||
]
|
||||
)
|
||||
|
||||
for filename in self.instance_video_paths:
|
||||
video_reader = decord.VideoReader(uri=filename.as_posix())
|
||||
video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
|
||||
video_num_frames = len(video_reader)
|
||||
|
||||
start_frame = min(self.skip_frames_start, video_num_frames)
|
||||
@@ -619,16 +579,10 @@ class VideoDataset(Dataset):
|
||||
assert (selected_num_frames - 1) % 4 == 0
|
||||
|
||||
# Training transforms
|
||||
frames = (frames - 127.5) / 127.5
|
||||
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
|
||||
progress_dataset_bar.set_description(
|
||||
f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
|
||||
)
|
||||
frames = self._resize_for_rectangle_crop(frames)
|
||||
videos.append(frames.contiguous()) # [F, C, H, W]
|
||||
progress_dataset_bar.update(1)
|
||||
frames = frames.float()
|
||||
frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
|
||||
videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W]
|
||||
|
||||
progress_dataset_bar.close()
|
||||
return videos
|
||||
|
||||
|
||||
@@ -743,13 +697,8 @@ def log_validation(
|
||||
|
||||
videos = []
|
||||
for _ in range(args.num_validation_videos):
|
||||
pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
|
||||
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
|
||||
|
||||
image_np = VaeImageProcessor.pt_to_numpy(pt_images)
|
||||
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
|
||||
|
||||
videos.append(image_pil)
|
||||
video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
|
||||
videos.append(video)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
phase_name = "test" if is_final_validation else "validation"
|
||||
@@ -777,8 +726,7 @@ def log_validation(
|
||||
}
|
||||
)
|
||||
|
||||
del pipe
|
||||
free_memory()
|
||||
clear_objs_and_retain_memory([pipe])
|
||||
|
||||
return videos
|
||||
|
||||
@@ -922,7 +870,7 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
|
||||
)
|
||||
args.optimizer = "adamw"
|
||||
|
||||
if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
|
||||
if args.use_8bit_adam and not (args.optimizer.lower() not in ["adam", "adamw"]):
|
||||
logger.warning(
|
||||
f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
|
||||
f"set to {args.optimizer.lower()}"
|
||||
@@ -1211,7 +1159,7 @@ def main(args):
|
||||
)
|
||||
use_deepspeed_scheduler = (
|
||||
accelerator.state.deepspeed_plugin is not None
|
||||
and "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
|
||||
)
|
||||
|
||||
optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
|
||||
@@ -1225,7 +1173,6 @@ def main(args):
|
||||
video_column=args.video_column,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
video_reshape_mode=args.video_reshape_mode,
|
||||
fps=args.fps,
|
||||
max_num_frames=args.max_num_frames,
|
||||
skip_frames_start=args.skip_frames_start,
|
||||
@@ -1234,28 +1181,19 @@ def main(args):
|
||||
id_token=args.id_token,
|
||||
)
|
||||
|
||||
def encode_video(video, bar):
|
||||
bar.update(1)
|
||||
def encode_video(video):
|
||||
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
|
||||
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
||||
latent_dist = vae.encode(video).latent_dist
|
||||
return latent_dist
|
||||
|
||||
progress_encode_bar = tqdm(
|
||||
range(0, len(train_dataset.instance_videos)),
|
||||
desc="Loading Encode videos",
|
||||
)
|
||||
train_dataset.instance_videos = [
|
||||
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
|
||||
]
|
||||
progress_encode_bar.close()
|
||||
train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
|
||||
|
||||
def collate_fn(examples):
|
||||
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
|
||||
prompts = [example["instance_prompt"] for example in examples]
|
||||
|
||||
videos = torch.cat(videos)
|
||||
videos = videos.permute(0, 2, 1, 3, 4)
|
||||
videos = videos.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
return {
|
||||
@@ -1377,7 +1315,7 @@ def main(args):
|
||||
models_to_accumulate = [transformer]
|
||||
|
||||
with accelerator.accumulate(models_to_accumulate):
|
||||
model_input = batch["videos"].to(dtype=weight_dtype) # [B, F, C, H, W]
|
||||
model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype) # [B, F, C, H, W]
|
||||
prompts = batch["prompts"]
|
||||
|
||||
# encode prompts
|
||||
@@ -1456,7 +1394,7 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
if accelerator.is_main_process:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
@@ -1496,6 +1434,7 @@ def main(args):
|
||||
args.pretrained_model_name_or_path,
|
||||
transformer=unwrap_model(transformer),
|
||||
text_encoder=unwrap_model(text_encoder),
|
||||
vae=unwrap_model(vae),
|
||||
scheduler=scheduler,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
@@ -1539,10 +1478,6 @@ def main(args):
|
||||
transformer_lora_layers=transformer_lora_layers,
|
||||
)
|
||||
|
||||
# Cleanup trained models to save memory
|
||||
del transformer
|
||||
free_memory()
|
||||
|
||||
# Final test inference
|
||||
pipe = CogVideoXPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
|
||||
@@ -73,8 +73,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
|
||||
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
|
||||
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
|
||||
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
|
||||
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
|
||||
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffsuion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
@@ -86,17 +85,17 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
|
||||
|
||||
### Flux with CFG
|
||||
|
||||
Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
|
||||
Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md).
|
||||
|
||||
Example usage:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
custom_pipeline="pipeline_flux_with_cfg"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
@@ -104,10 +103,10 @@ prompt = "a watercolor painting of a unicorn"
|
||||
negative_prompt = "pink"
|
||||
|
||||
img = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
true_cfg=1.5,
|
||||
guidance_scale=3.5,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
true_cfg=1.5,
|
||||
guidance_scale=3.5,
|
||||
num_images_per_prompt=1,
|
||||
generator=torch.manual_seed(0)
|
||||
).images[0]
|
||||
@@ -2657,7 +2656,7 @@ image with mask mech_painted.png
|
||||
|
||||
<img src=https://github.com/noskill/diffusers/assets/733626/c334466a-67fe-4377-9ff7-f46021b9c224 width="25%" >
|
||||
|
||||
result:
|
||||
result:
|
||||
|
||||
<img src=https://github.com/noskill/diffusers/assets/733626/5043fb57-a785-4606-a5ba-a36704f7cb42 width="25%" >
|
||||
|
||||
@@ -4325,51 +4324,6 @@ image = pipe(
|
||||
|
||||
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
|
||||
|
||||
### 🪆Matryoshka Diffusion Models
|
||||
|
||||

|
||||
|
||||
The Abstract of the paper:
|
||||
>Diffusion models are the _de-facto_ approach for generating high-quality images and videos but learning high-dimensional models remains a formidable task due to computational and optimization challenges. Existing methods often resort to training cascaded models in pixel space, or using a downsampled latent space of a separately trained auto-encoder. In this paper, we introduce Matryoshka Diffusion (MDM), **a novel framework for high-resolution image and video synthesis**. We propose a diffusion process that denoises inputs at multiple resolutions jointly and uses a **NestedUNet** architecture where features and parameters for small scale inputs are nested within those of the large scales. In addition, MDM enables a progressive training schedule from lower to higher resolutions which leads to significant improvements in optimization for high-resolution generation. We demonstrate the effectiveness of our approach on various benchmarks, including class-conditioned image generation, high-resolution text-to-image, and text-to-video applications. Remarkably, we can train a **_single pixel-space model_ at resolutions of up to 1024 × 1024 pixels**, demonstrating strong zero shot generalization using the **CC12M dataset, which contains only 12 million images**. Code and pre-trained checkpoints are released at https://github.com/apple/ml-mdm.
|
||||
|
||||
- `64×64, nesting_level=0`: 1.719 GiB. With `50` DDIM inference steps:
|
||||
|
||||
**64x64**
|
||||
:-------------------------:
|
||||
| <img src="https://github.com/user-attachments/assets/9e7bb2cd-45a0-4bd1-adb8-23e283baed39" width="222" height="222" alt="bird_64"> |
|
||||
|
||||
- `256×256, nesting_level=1`: 1.776 GiB. With `150` DDIM inference steps:
|
||||
|
||||
**64x64** | **256x256**
|
||||
:-------------------------:|:-------------------------:
|
||||
| <img src="https://github.com/user-attachments/assets/6b724c2e-5e6a-4b63-9b65-c1182cbb67e0" width="222" height="222" alt="64x64"> | <img src="https://github.com/user-attachments/assets/7dbab2ad-bf40-4a73-ab04-f178347cb7d5" width="222" height="222" alt="256x256"> |
|
||||
|
||||
- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible. With `250` DDIM inference steps:
|
||||
|
||||
**64x64** | **256x256** | **1024x1024**
|
||||
:-------------------------:|:-------------------------:|:-------------------------:
|
||||
| <img src="https://github.com/user-attachments/assets/4a9454e4-e20a-4736-a196-270e2ae796c0" width="222" height="222" alt="64x64"> | <img src="https://github.com/user-attachments/assets/4a96555d-0fda-4303-82b1-a4d886f770b9" width="222" height="222" alt="256x256"> | <img src="https://github.com/user-attachments/assets/e0239b7a-ab73-4d45-8f3e-b4e6b4b50abe" width="222" height="222" alt="1024x1024"> |
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import make_image_grid
|
||||
|
||||
# nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
|
||||
pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
|
||||
nesting_level=0,
|
||||
trust_remote_code=False, # One needs to give permission for this code to run
|
||||
).to("cuda")
|
||||
|
||||
prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
|
||||
prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
|
||||
negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
|
||||
image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
|
||||
make_image_grid(image, rows=1, cols=len(image))
|
||||
|
||||
# pipe.change_nesting_level(<int>) # 0, 1, or 2
|
||||
# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
|
||||
```
|
||||
|
||||
# Perturbed-Attention Guidance
|
||||
|
||||
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
|
||||
|
||||
@@ -898,16 +898,13 @@ class GaussianSmoothing(nn.Module):
|
||||
Apply gaussian smoothing on a
|
||||
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
||||
in the input using a depthwise convolution.
|
||||
|
||||
Args:
|
||||
channels (`int` or `sequence`):
|
||||
Number of channels of the input tensors. The output will have this number of channels as well.
|
||||
kernel_size (`int` or `sequence`):
|
||||
Size of the Gaussian kernel.
|
||||
sigma (`float` or `sequence`):
|
||||
Standard deviation of the Gaussian kernel.
|
||||
dim (`int`, *optional*, defaults to `2`):
|
||||
The number of dimensions of the data. Default is 2 (spatial dimensions).
|
||||
Arguments:
|
||||
channels (int, sequence): Number of channels of the input tensors. Output will
|
||||
have this number of channels as well.
|
||||
kernel_size (int, sequence): Size of the gaussian kernel.
|
||||
sigma (float, sequence): Standard deviation of the gaussian kernel.
|
||||
dim (int, optional): The number of dimensions of the data.
|
||||
Default value is 2 (spatial).
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size, sigma, dim=2):
|
||||
@@ -947,14 +944,10 @@ class GaussianSmoothing(nn.Module):
|
||||
def forward(self, input):
|
||||
"""
|
||||
Apply gaussian filter to input.
|
||||
|
||||
Args:
|
||||
input (`torch.Tensor` of shape `(N, C, H, W)`):
|
||||
Input to apply Gaussian filter on.
|
||||
|
||||
Arguments:
|
||||
input (torch.Tensor): Input to apply gaussian filter on.
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The filtered output tensor with the same shape as the input.
|
||||
filtered (torch.Tensor): Filtered output.
|
||||
"""
|
||||
return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding="same")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,430 +0,0 @@
|
||||
# ControlNet training example for FLUX
|
||||
|
||||
The `train_controlnet_flux.py` script shows how to implement the ControlNet training procedure and adapt it for [FLUX](https://github.com/black-forest-labs/flux).
|
||||
|
||||
Training script provided by LibAI, which is an institution dedicated to the progress and achievement of artificial general intelligence. LibAI is the developer of [cutout.pro](https://www.cutout.pro/) and [promeai.pro](https://www.promeai.pro/).
|
||||
> [!NOTE]
|
||||
> **Memory consumption**
|
||||
>
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.
|
||||
|
||||
> **Gated access**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: `huggingface-cli login`
|
||||
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/controlnet` folder and run
|
||||
```bash
|
||||
pip install -r requirements_flux.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
## Custom Datasets
|
||||
|
||||
We support dataset formats:
|
||||
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script. To use our example, add `--dataset_name=fusing/fill50k \` to the script and remove line `--jsonl_for_train` mentioned below.
|
||||
|
||||
|
||||
We also support importing data from jsonl(xxx.jsonl),using `--jsonl_for_train` to enable it, here is a brief example of jsonl files:
|
||||
```sh
|
||||
{"image": "xxx", "text": "xxx", "conditioning_image": "xxx"}
|
||||
{"image": "xxx", "text": "xxx", "conditioning_image": "xxx"}
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
Our training examples use two test conditioning images. They can be downloaded by running
|
||||
|
||||
```sh
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
|
||||
```
|
||||
|
||||
Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
|
||||
|
||||
we can define the num_layers, num_single_layers, which determines the size of the control(default values are num_layers=4, num_single_layers=10)
|
||||
|
||||
|
||||
```bash
|
||||
accelerate launch train_controlnet_flux.py \
|
||||
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
|
||||
--dataset_name=fusing/fill50k \
|
||||
--conditioning_image_column=conditioning_image \
|
||||
--image_column=image \
|
||||
--caption_column=text \
|
||||
--output_dir="path to save model" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=512 \
|
||||
--learning_rate=1e-5 \
|
||||
--max_train_steps=15000 \
|
||||
--validation_steps=100 \
|
||||
--checkpointing_steps=200 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--report_to="wandb" \
|
||||
--num_double_layers=4 \
|
||||
--num_single_layers=0 \
|
||||
--seed=42 \
|
||||
--push_to_hub \
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases.
|
||||
* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 80GB A100 GPU.
|
||||
|
||||
### Inference
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
|
||||
from diffusers.models.controlnet_flux import FluxControlNetModel
|
||||
|
||||
base_model = 'black-forest-labs/FLUX.1-dev'
|
||||
controlnet_model = 'promeai/FLUX.1-controlnet-lineart-promeai'
|
||||
controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
|
||||
pipe = FluxControlNetPipeline.from_pretrained(
|
||||
base_model,
|
||||
controlnet=controlnet,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
# enable memory optimizations
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
control_image = load_image("https://huggingface.co/promeai/FLUX.1-controlnet-lineart-promeai/resolve/main/images/example-control.jpg")resize((1024, 1024))
|
||||
prompt = "cute anime girl with massive fluffy fennec ears and a big fluffy tail blonde messy long hair blue eyes wearing a maid outfit with a long black gold leaf pattern dress and a white apron mouth open holding a fancy black forest cake with candles on top in the kitchen of an old dark Victorian mansion lit by candlelight with a bright window to the foggy forest and very expensive stuff everywhere"
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
control_image=control_image,
|
||||
controlnet_conditioning_scale=0.6,
|
||||
num_inference_steps=28,
|
||||
guidance_scale=3.5,
|
||||
).images[0]
|
||||
image.save("./output.png")
|
||||
```
|
||||
|
||||
## Apply Deepspeed Zero3
|
||||
|
||||
This is an experimental process, I am not sure if it is suitable for everyone, we used this process to successfully train 512 resolution on A100(40g) * 8.
|
||||
Please modify some of the code in the script.
|
||||
### 1.Customize zero3 settings
|
||||
|
||||
Copy the **accelerate_config_zero3.yaml**,modify `num_processes` according to the number of gpus you want to use:
|
||||
|
||||
```bash
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 8
|
||||
offload_optimizer_device: cpu
|
||||
offload_param_device: cpu
|
||||
zero3_init_flag: true
|
||||
zero3_save_16bit_model: true
|
||||
zero_stage: 3
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
### 2.Precompute all inputs (latent, embeddings)
|
||||
|
||||
In the train_controlnet_flux.py, We need to pre-calculate all parameters and put them into batches.So we first need to rewrite the `compute_embeddings` function.
|
||||
|
||||
```python
|
||||
def compute_embeddings(batch, proportion_empty_prompts, vae, flux_controlnet_pipeline, weight_dtype, is_train=True):
|
||||
|
||||
### compute text embeddings
|
||||
prompt_batch = batch[args.caption_column]
|
||||
captions = []
|
||||
for caption in prompt_batch:
|
||||
if random.random() < proportion_empty_prompts:
|
||||
captions.append("")
|
||||
elif isinstance(caption, str):
|
||||
captions.append(caption)
|
||||
elif isinstance(caption, (list, np.ndarray)):
|
||||
# take a random caption if there are multiple
|
||||
captions.append(random.choice(caption) if is_train else caption[0])
|
||||
prompt_batch = captions
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = flux_controlnet_pipeline.encode_prompt(
|
||||
prompt_batch, prompt_2=prompt_batch
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=weight_dtype)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)
|
||||
text_ids = text_ids.to(dtype=weight_dtype)
|
||||
|
||||
# text_ids [512,3] to [bs,512,3]
|
||||
text_ids = text_ids.unsqueeze(0).expand(prompt_embeds.shape[0], -1, -1)
|
||||
|
||||
### compute latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
return latents
|
||||
|
||||
# vae encode
|
||||
pixel_values = batch["pixel_values"]
|
||||
pixel_values = torch.stack([image for image in pixel_values]).to(dtype=weight_dtype).to(vae.device)
|
||||
pixel_latents_tmp = vae.encode(pixel_values).latent_dist.sample()
|
||||
pixel_latents_tmp = (pixel_latents_tmp - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
pixel_latents = _pack_latents(
|
||||
pixel_latents_tmp,
|
||||
pixel_values.shape[0],
|
||||
pixel_latents_tmp.shape[1],
|
||||
pixel_latents_tmp.shape[2],
|
||||
pixel_latents_tmp.shape[3],
|
||||
)
|
||||
|
||||
control_values = batch["conditioning_pixel_values"]
|
||||
control_values = torch.stack([image for image in control_values]).to(dtype=weight_dtype).to(vae.device)
|
||||
control_latents = vae.encode(control_values).latent_dist.sample()
|
||||
control_latents = (control_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
control_latents = _pack_latents(
|
||||
control_latents,
|
||||
control_values.shape[0],
|
||||
control_latents.shape[1],
|
||||
control_latents.shape[2],
|
||||
control_latents.shape[3],
|
||||
)
|
||||
|
||||
# copied from pipeline_flux_controlnet
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
latent_image_ids = _prepare_latent_image_ids(
|
||||
batch_size=pixel_latents_tmp.shape[0],
|
||||
height=pixel_latents_tmp.shape[2],
|
||||
width=pixel_latents_tmp.shape[3],
|
||||
device=pixel_values.device,
|
||||
dtype=pixel_values.dtype,
|
||||
)
|
||||
|
||||
# unet_added_cond_kwargs = {"pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
|
||||
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids, "pixel_latents": pixel_latents, "control_latents": control_latents, "latent_image_ids": latent_image_ids}
|
||||
```
|
||||
|
||||
Because we need images to pass through vae, we need to preprocess the images in the dataset first. At the same time, vae requires more gpu memory, so you may need to modify the `batch_size` below
|
||||
```diff
|
||||
+train_dataset = prepare_train_dataset(train_dataset, accelerator)
|
||||
with accelerator.main_process_first():
|
||||
from datasets.fingerprint import Hasher
|
||||
|
||||
# fingerprint used by the cache for the other processes to load the result
|
||||
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
|
||||
new_fingerprint = Hasher.hash(args)
|
||||
train_dataset = train_dataset.map(
|
||||
- compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=100
|
||||
+ compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=10
|
||||
)
|
||||
|
||||
del text_encoders, tokenizers
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Then get the training dataset ready to be passed to the dataloader.
|
||||
-train_dataset = prepare_train_dataset(train_dataset, accelerator)
|
||||
```
|
||||
### 3.Redefine the behavior of getting batchsize
|
||||
|
||||
Now that we have all the preprocessing done, we need to modify the `collate_fn` function.
|
||||
|
||||
```python
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
|
||||
conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
pixel_latents = torch.stack([torch.tensor(example["pixel_latents"]) for example in examples])
|
||||
pixel_latents = pixel_latents.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
control_latents = torch.stack([torch.tensor(example["control_latents"]) for example in examples])
|
||||
control_latents = control_latents.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
latent_image_ids= torch.stack([torch.tensor(example["latent_image_ids"]) for example in examples])
|
||||
|
||||
prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
|
||||
|
||||
pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
|
||||
text_ids = torch.stack([torch.tensor(example["text_ids"]) for example in examples])
|
||||
|
||||
return {
|
||||
"pixel_values": pixel_values,
|
||||
"conditioning_pixel_values": conditioning_pixel_values,
|
||||
"pixel_latents": pixel_latents,
|
||||
"control_latents": control_latents,
|
||||
"latent_image_ids": latent_image_ids,
|
||||
"prompt_ids": prompt_ids,
|
||||
"unet_added_conditions": {"pooled_prompt_embeds": pooled_prompt_embeds, "time_ids": text_ids},
|
||||
}
|
||||
```
|
||||
Finally, we just need to modify the way of obtaining various parameters during training.
|
||||
```python
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(flux_controlnet):
|
||||
# Convert images to latent space
|
||||
pixel_latents = batch["pixel_latents"].to(dtype=weight_dtype)
|
||||
control_image = batch["control_latents"].to(dtype=weight_dtype)
|
||||
latent_image_ids = batch["latent_image_ids"].to(dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(pixel_latents).to(accelerator.device).to(dtype=weight_dtype)
|
||||
bsz = pixel_latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
t = torch.sigmoid(torch.randn((bsz,), device=accelerator.device, dtype=weight_dtype))
|
||||
|
||||
# apply flow matching
|
||||
noisy_latents = (
|
||||
1 - t.unsqueeze(1).unsqueeze(2).repeat(1, pixel_latents.shape[1], pixel_latents.shape[2])
|
||||
) * pixel_latents + t.unsqueeze(1).unsqueeze(2).repeat(
|
||||
1, pixel_latents.shape[1], pixel_latents.shape[2]
|
||||
) * noise
|
||||
|
||||
guidance_vec = torch.full(
|
||||
(noisy_latents.shape[0],), 3.5, device=noisy_latents.device, dtype=weight_dtype
|
||||
)
|
||||
|
||||
controlnet_block_samples, controlnet_single_block_samples = flux_controlnet(
|
||||
hidden_states=noisy_latents,
|
||||
controlnet_cond=control_image,
|
||||
timestep=t,
|
||||
guidance=guidance_vec,
|
||||
pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
|
||||
encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
|
||||
txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
|
||||
img_ids=latent_image_ids[0],
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
noise_pred = flux_transformer(
|
||||
hidden_states=noisy_latents,
|
||||
timestep=t,
|
||||
guidance=guidance_vec,
|
||||
pooled_projections=batch["unet_added_conditions"]["pooled_prompt_embeds"].to(dtype=weight_dtype),
|
||||
encoder_hidden_states=batch["prompt_ids"].to(dtype=weight_dtype),
|
||||
controlnet_block_samples=[sample.to(dtype=weight_dtype) for sample in controlnet_block_samples]
|
||||
if controlnet_block_samples is not None
|
||||
else None,
|
||||
controlnet_single_block_samples=[
|
||||
sample.to(dtype=weight_dtype) for sample in controlnet_single_block_samples
|
||||
]
|
||||
if controlnet_single_block_samples is not None
|
||||
else None,
|
||||
txt_ids=batch["unet_added_conditions"]["time_ids"][0].to(dtype=weight_dtype),
|
||||
img_ids=latent_image_ids[0],
|
||||
return_dict=False,
|
||||
)[0]
|
||||
```
|
||||
Congratulations! You have completed all the required code modifications required for deepspeedzero3.
|
||||
|
||||
### 4.Training with deepspeedzero3
|
||||
|
||||
Start!!!
|
||||
|
||||
```bash
|
||||
export pretrained_model_name_or_path='flux-dev-model-path'
|
||||
export MODEL_TYPE='train_model_type'
|
||||
export TRAIN_JSON_FILE="your_json_file"
|
||||
export CONTROL_TYPE='control_preprocessor_type'
|
||||
export CAPTION_COLUMN='caption_column'
|
||||
|
||||
export CACHE_DIR="/data/train_csr/.cache/huggingface/"
|
||||
export OUTPUT_DIR='/data/train_csr/FLUX/MODEL_OUT/'$MODEL_TYPE
|
||||
# The first step is to use Python to precompute all caches.Replace the first line below with this line. (I am not sure why using acclerate would cause problems.)
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 train_controlnet_flux.py \
|
||||
|
||||
# The second step is to use the above accelerate config to train
|
||||
accelerate launch --config_file "./accelerate_config_zero3.yaml" train_controlnet_flux.py \
|
||||
--pretrained_model_name_or_path=$pretrained_model_name_or_path \
|
||||
--jsonl_for_train=$TRAIN_JSON_FILE \
|
||||
--conditioning_image_column=$CONTROL_TYPE \
|
||||
--image_column=image \
|
||||
--caption_column=$CAPTION_COLUMN\
|
||||
--cache_dir=$CACHE_DIR \
|
||||
--tracker_project_name=$MODEL_TYPE \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--max_train_steps=500000 \
|
||||
--mixed_precision bf16 \
|
||||
--checkpointing_steps=1000 \
|
||||
--gradient_accumulation_steps=8 \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--learning_rate=1e-5 \
|
||||
--num_double_layers=4 \
|
||||
--num_single_layers=0 \
|
||||
--gradient_checkpointing \
|
||||
--resume_from_checkpoint="latest" \
|
||||
# --use_adafactor \ dont use
|
||||
# --validation_steps=3 \ not support
|
||||
# --validation_image $VALIDATION_IMAGE \ not support
|
||||
# --validation_prompt "xxx" \ not support
|
||||
```
|
||||
@@ -1,9 +0,0 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
datasets
|
||||
wandb
|
||||
SentencePiece
|
||||
@@ -136,28 +136,3 @@ class ControlNetSD3(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
|
||||
class ControlNetflux(ExamplesTestsAccelerate):
|
||||
def test_controlnet_flux(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/controlnet/train_controlnet_flux.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-flux-pipe
|
||||
--output_dir={tmpdir}
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--conditioning_image_column=conditioning_image
|
||||
--image_column=image
|
||||
--caption_column=text
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--num_double_layers=1
|
||||
--num_single_layers=1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -49,7 +49,11 @@ from diffusers import (
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
|
||||
from diffusers.training_utils import (
|
||||
clear_objs_and_retain_memory,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
)
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
@@ -170,8 +174,7 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
|
||||
else:
|
||||
logger.warning(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
free_memory()
|
||||
clear_objs_and_retain_memory(pipeline)
|
||||
|
||||
if not is_final_validation:
|
||||
controlnet.to(accelerator.device)
|
||||
@@ -357,11 +360,6 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upcast_vae",
|
||||
action="store_true",
|
||||
help="Whether or not to upcast vae to fp32",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
@@ -1099,10 +1097,7 @@ def main(args):
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move vae, transformer and text_encoder to device and cast to weight_dtype
|
||||
if args.upcast_vae:
|
||||
vae.to(accelerator.device, dtype=torch.float32)
|
||||
else:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=torch.float32)
|
||||
transformer.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -1136,9 +1131,7 @@ def main(args):
|
||||
new_fingerprint = Hasher.hash(args)
|
||||
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
|
||||
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
del tokenizer_one, tokenizer_two, tokenizer_three
|
||||
free_memory()
|
||||
clear_objs_and_retain_memory(text_encoders + tokenizers)
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
|
||||
@@ -136,7 +136,7 @@ accelerate launch train_dreambooth_lora_sd3.py \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--learning_rate=4e-4 \
|
||||
--learning_rate=1e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
|
||||
@@ -103,39 +103,6 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
|
||||
)
|
||||
self.assertTrue(starts_with_expected_prefix)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
|
||||
@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
free_memory,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -985,6 +985,7 @@ def encode_prompt(
|
||||
text_input_ids_list=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
dtype = text_encoders[0].dtype
|
||||
|
||||
pooled_prompt_embeds = _encode_prompt_with_clip(
|
||||
@@ -1006,7 +1007,8 @@ def encode_prompt(
|
||||
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
|
||||
)
|
||||
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
@@ -1435,8 +1437,7 @@ def main(args):
|
||||
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
|
||||
free_memory()
|
||||
clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1479,8 +1480,7 @@ def main(args):
|
||||
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
|
||||
|
||||
if args.validation_prompt is None:
|
||||
del vae
|
||||
free_memory()
|
||||
clear_objs_and_retain_memory([vae])
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
@@ -1817,8 +1817,7 @@ def main(args):
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
if not args.train_text_encoder:
|
||||
del text_encoder_one, text_encoder_two
|
||||
free_memory()
|
||||
clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -55,9 +55,9 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
clear_objs_and_retain_memory,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
free_memory,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -140,7 +140,7 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="other",
|
||||
license="openrail++",
|
||||
base_model=base_model,
|
||||
prompt=instance_prompt,
|
||||
model_description=model_description,
|
||||
@@ -186,7 +186,7 @@ def log_validation(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
@@ -211,8 +211,7 @@ def log_validation(
|
||||
}
|
||||
)
|
||||
|
||||
del pipeline
|
||||
free_memory()
|
||||
clear_objs_and_retain_memory(objs=[pipeline])
|
||||
|
||||
return images
|
||||
|
||||
@@ -608,12 +607,6 @@ def parse_args(input_args=None):
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_latents",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Cache the VAE latents",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
@@ -634,15 +627,6 @@ def parse_args(input_args=None):
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upcast_before_saving",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
|
||||
"Defaults to precision dtype used for training to save memory"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prior_generation_precision",
|
||||
type=str,
|
||||
@@ -1122,8 +1106,7 @@ def main(args):
|
||||
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
||||
image.save(image_filename)
|
||||
|
||||
del pipeline
|
||||
free_memory()
|
||||
clear_objs_and_retain_memory(objs=[pipeline])
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -1409,16 +1392,6 @@ def main(args):
|
||||
logger.warning(
|
||||
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
|
||||
)
|
||||
if args.train_text_encoder and args.text_encoder_lr:
|
||||
logger.warning(
|
||||
f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
|
||||
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
|
||||
f"When using prodigy only learning_rate is used as the initial learning rate."
|
||||
)
|
||||
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
|
||||
# --learning_rate
|
||||
params_to_optimize[1]["lr"] = args.learning_rate
|
||||
params_to_optimize[2]["lr"] = args.learning_rate
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
@@ -1465,9 +1438,6 @@ def main(args):
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
|
||||
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
|
||||
# the redundant encoding.
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.instance_prompt, text_encoders, tokenizers
|
||||
@@ -1483,9 +1453,9 @@ def main(args):
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||
del tokenizers, text_encoders
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
free_memory()
|
||||
clear_objs_and_retain_memory(
|
||||
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
|
||||
)
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1512,21 +1482,6 @@ def main(args):
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
tokens_three = torch.cat([tokens_three, class_tokens_three], dim=0)
|
||||
|
||||
vae_config_shift_factor = vae.config.shift_factor
|
||||
vae_config_scaling_factor = vae.config.scaling_factor
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
with torch.no_grad():
|
||||
batch["pixel_values"] = batch["pixel_values"].to(
|
||||
accelerator.device, non_blocking=True, dtype=weight_dtype
|
||||
)
|
||||
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
|
||||
|
||||
if args.validation_prompt is None:
|
||||
del vae
|
||||
free_memory()
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -1543,6 +1498,7 @@ def main(args):
|
||||
power=args.lr_power,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
# Prepare everything with our `accelerator`.
|
||||
if args.train_text_encoder:
|
||||
(
|
||||
@@ -1649,9 +1605,8 @@ def main(args):
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
if args.train_text_encoder:
|
||||
models_to_accumulate.extend([text_encoder_one, text_encoder_two])
|
||||
with accelerator.accumulate(models_to_accumulate):
|
||||
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
||||
prompts = batch["prompts"]
|
||||
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
@@ -1682,13 +1637,8 @@ def main(args):
|
||||
)
|
||||
|
||||
# Convert images to latent space
|
||||
if args.cache_latents:
|
||||
model_input = latents_cache[step].sample()
|
||||
else:
|
||||
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
|
||||
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
@@ -1821,8 +1771,6 @@ def main(args):
|
||||
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
|
||||
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
|
||||
)
|
||||
text_encoder_one.to(weight_dtype)
|
||||
text_encoder_two.to(weight_dtype)
|
||||
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
@@ -1843,18 +1791,17 @@ def main(args):
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
objs = []
|
||||
if not args.train_text_encoder:
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
free_memory()
|
||||
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
|
||||
|
||||
clear_objs_and_retain_memory(objs=objs)
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer = transformer.to(torch.float32)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
if args.train_text_encoder:
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import gc
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
@@ -50,7 +51,7 @@ from diffusers import (
|
||||
StableDiffusion3Pipeline,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
|
||||
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
is_wandb_available,
|
||||
@@ -118,7 +119,7 @@ Please adhere to the licensing terms as described `[here](https://huggingface.co
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="other",
|
||||
license="openrail++",
|
||||
base_model=base_model,
|
||||
prompt=instance_prompt,
|
||||
model_description=model_description,
|
||||
@@ -163,7 +164,7 @@ def log_validation(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
@@ -189,7 +190,8 @@ def log_validation(
|
||||
)
|
||||
|
||||
del pipeline
|
||||
free_memory()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return images
|
||||
|
||||
@@ -1063,7 +1065,8 @@ def main(args):
|
||||
image.save(image_filename)
|
||||
|
||||
del pipeline
|
||||
free_memory()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
@@ -1383,7 +1386,9 @@ def main(args):
|
||||
del tokenizers, text_encoders
|
||||
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
free_memory()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1703,9 +1708,6 @@ def main(args):
|
||||
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
|
||||
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
|
||||
)
|
||||
text_encoder_one.to(weight_dtype)
|
||||
text_encoder_two.to(weight_dtype)
|
||||
text_encoder_three.to(weight_dtype)
|
||||
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
@@ -1728,7 +1730,8 @@ def main(args):
|
||||
)
|
||||
if not args.train_text_encoder:
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
free_memory()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -747,22 +747,17 @@ def main():
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
@@ -787,14 +782,8 @@ def main():
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!rm -rf dog/.cache"
|
||||
"!rm -rf dog/.huggingface"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -59,8 +59,6 @@ check_min_version("0.31.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
import torch_npu
|
||||
|
||||
torch.npu.config.allow_internal_format = False
|
||||
|
||||
DATASET_NAME_MAPPING = {
|
||||
@@ -542,9 +540,6 @@ def compute_vae_encodings(batch, vae):
|
||||
with torch.no_grad():
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
|
||||
# There might have slightly performance improvement
|
||||
# by changing model_input.cpu() to accelerator.gather(model_input)
|
||||
return {"model_input": model_input.cpu()}
|
||||
|
||||
|
||||
@@ -940,10 +935,7 @@ def main(args):
|
||||
del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
|
||||
del text_encoders, tokenizers, vae
|
||||
gc.collect()
|
||||
if is_torch_npu_available():
|
||||
torch_npu.npu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def collate_fn(examples):
|
||||
model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
|
||||
@@ -1099,7 +1091,8 @@ def main(args):
|
||||
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
||||
target_size = (args.resolution, args.resolution)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
||||
return add_time_ids
|
||||
|
||||
add_time_ids = torch.cat(
|
||||
@@ -1268,10 +1261,7 @@ def main(args):
|
||||
)
|
||||
|
||||
del pipeline
|
||||
if is_torch_npu_available():
|
||||
torch_npu.npu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if args.use_ema:
|
||||
# Switch back to the original UNet parameters.
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
"""
|
||||
Convert a CogView3 checkpoint to the Diffusers format.
|
||||
|
||||
This script converts a CogView3 checkpoint to the Diffusers format, which can then be used
|
||||
with the Diffusers library.
|
||||
|
||||
Example usage:
|
||||
python scripts/convert_cogview3_to_diffusers.py \
|
||||
--transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
|
||||
--vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \
|
||||
--output_path "/raid/yiyi/cogview3_diffusers" \
|
||||
--dtype "bf16"
|
||||
|
||||
Arguments:
|
||||
--transformer_checkpoint_path: Path to Transformer state dict.
|
||||
--vae_checkpoint_path: Path to VAE state dict.
|
||||
--output_path: The path to save the converted model.
|
||||
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
|
||||
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
|
||||
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
|
||||
|
||||
Default is "bf16" because CogView3 uses bfloat16 for Training.
|
||||
|
||||
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
|
||||
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
TOKENIZER_MAX_LENGTH = 224
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--output_path", required=True, type=str)
|
||||
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
|
||||
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
|
||||
parser.add_argument("--dtype", type=str, default="bf16")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# this is specific to `AdaLayerNormContinuous`:
|
||||
# diffusers implementation split the linear projection into the scale, shift while CogView3 split it tino shift, scale
|
||||
def swap_scale_shift(weight, dim):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
|
||||
original_state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
original_state_dict = original_state_dict["module"]
|
||||
original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
# Convert patch_embed
|
||||
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
|
||||
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
|
||||
new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
|
||||
new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
|
||||
|
||||
# Convert time_condition_embed
|
||||
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
|
||||
"time_embed.0.weight"
|
||||
)
|
||||
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
|
||||
"time_embed.0.bias"
|
||||
)
|
||||
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
|
||||
"time_embed.2.weight"
|
||||
)
|
||||
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
|
||||
"time_embed.2.bias"
|
||||
)
|
||||
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
|
||||
"label_emb.0.0.weight"
|
||||
)
|
||||
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
|
||||
"label_emb.0.0.bias"
|
||||
)
|
||||
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
|
||||
"label_emb.0.2.weight"
|
||||
)
|
||||
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
|
||||
"label_emb.0.2.bias"
|
||||
)
|
||||
|
||||
# Convert transformer blocks
|
||||
for i in range(30):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
old_prefix = f"transformer.layers.{i}."
|
||||
adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
|
||||
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
|
||||
|
||||
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
|
||||
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
|
||||
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
|
||||
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
|
||||
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
|
||||
old_prefix + "attention.dense.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
|
||||
old_prefix + "attention.dense.bias"
|
||||
)
|
||||
|
||||
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
|
||||
old_prefix + "mlp.dense_h_to_4h.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
|
||||
old_prefix + "mlp.dense_h_to_4h.bias"
|
||||
)
|
||||
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
|
||||
old_prefix + "mlp.dense_4h_to_h.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
|
||||
|
||||
# Convert final norm and projection
|
||||
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
||||
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
|
||||
)
|
||||
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
|
||||
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
|
||||
)
|
||||
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
|
||||
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
|
||||
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
||||
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.dtype == "fp16":
|
||||
dtype = torch.float16
|
||||
elif args.dtype == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
elif args.dtype == "fp32":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {args.dtype}")
|
||||
|
||||
transformer = None
|
||||
vae = None
|
||||
|
||||
if args.transformer_checkpoint_path is not None:
|
||||
converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
|
||||
args.transformer_checkpoint_path
|
||||
)
|
||||
transformer = CogView3PlusTransformer2DModel()
|
||||
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
|
||||
if dtype is not None:
|
||||
# Original checkpoint data type will be preserved
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
|
||||
if args.vae_checkpoint_path is not None:
|
||||
vae_config = {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"down_block_types": ("DownEncoderBlock2D",) * 4,
|
||||
"up_block_types": ("UpDecoderBlock2D",) * 4,
|
||||
"block_out_channels": (128, 512, 1024, 1024),
|
||||
"layers_per_block": 3,
|
||||
"act_fn": "silu",
|
||||
"latent_channels": 16,
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 1.0,
|
||||
"force_upcast": True,
|
||||
"use_quant_conv": False,
|
||||
"use_post_quant_conv": False,
|
||||
"mid_block_add_attention": False,
|
||||
}
|
||||
converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
||||
if dtype is not None:
|
||||
vae = vae.to(dtype=dtype)
|
||||
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
||||
|
||||
# Apparently, the conversion does not work anymore without this :shrug:
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
scheduler = CogVideoXDDIMScheduler.from_config(
|
||||
{
|
||||
"snr_shift_scale": 4.0,
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": False,
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "v_prediction",
|
||||
"rescale_betas_zero_snr": True,
|
||||
"set_alpha_to_one": True,
|
||||
"timestep_spacing": "trailing",
|
||||
}
|
||||
)
|
||||
|
||||
pipe = CogView3PlusPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
# This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
|
||||
# save some memory used for model loading.
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args)
|
||||
@@ -84,7 +84,6 @@ else:
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
"ConsistencyDecoderVAE",
|
||||
"ControlNetModel",
|
||||
"ControlNetXSAdapter",
|
||||
@@ -256,11 +255,9 @@ else:
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"CLIPImageProjection",
|
||||
"CogVideoXFunControlPipeline",
|
||||
"CogVideoXImageToVideoPipeline",
|
||||
"CogVideoXPipeline",
|
||||
"CogVideoXVideoToVideoPipeline",
|
||||
"CogView3PlusPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
"FluxControlNetInpaintPipeline",
|
||||
@@ -331,7 +328,6 @@ else:
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPAGInpaintPipeline",
|
||||
"StableDiffusionControlNetPAGPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetXSPipeline",
|
||||
@@ -347,7 +343,6 @@ else:
|
||||
"StableDiffusionLatentUpscalePipeline",
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
"StableDiffusionModelEditingPipeline",
|
||||
"StableDiffusionPAGImg2ImgPipeline",
|
||||
"StableDiffusionPAGPipeline",
|
||||
"StableDiffusionPanoramaPipeline",
|
||||
"StableDiffusionParadigmsPipeline",
|
||||
@@ -562,7 +557,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
ConsistencyDecoderVAE,
|
||||
ControlNetModel,
|
||||
ControlNetXSAdapter,
|
||||
@@ -712,11 +706,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDMPipeline,
|
||||
AuraFlowPipeline,
|
||||
CLIPImageProjection,
|
||||
CogVideoXFunControlPipeline,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
CogVideoXPipeline,
|
||||
CogVideoXVideoToVideoPipeline,
|
||||
CogView3PlusPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
FluxControlNetInpaintPipeline,
|
||||
@@ -786,7 +778,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPAGInpaintPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
@@ -802,7 +793,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
StableDiffusionPAGImg2ImgPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
|
||||
@@ -38,44 +38,16 @@ PipelineImageInput = Union[
|
||||
PipelineDepthInput = PipelineImageInput
|
||||
|
||||
|
||||
def is_valid_image(image) -> bool:
|
||||
r"""
|
||||
Checks if the input is a valid image.
|
||||
|
||||
A valid image can be:
|
||||
- A `PIL.Image.Image`.
|
||||
- A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).
|
||||
|
||||
Args:
|
||||
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
|
||||
The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.
|
||||
|
||||
Returns:
|
||||
`bool`:
|
||||
`True` if the input is a valid image, `False` otherwise.
|
||||
"""
|
||||
def is_valid_image(image):
|
||||
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
|
||||
|
||||
|
||||
def is_valid_image_imagelist(images):
|
||||
r"""
|
||||
Checks if the input is a valid image or list of images.
|
||||
|
||||
The input can be one of the following formats:
|
||||
- A 4D tensor or numpy array (batch of images).
|
||||
- A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
|
||||
`torch.Tensor`.
|
||||
- A list of valid images.
|
||||
|
||||
Args:
|
||||
images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
|
||||
The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
|
||||
images.
|
||||
|
||||
Returns:
|
||||
`bool`:
|
||||
`True` if the input is valid, `False` otherwise.
|
||||
"""
|
||||
# check if the image input is one of the supported formats for image and image list:
|
||||
# it can be either one of below 3
|
||||
# (1) a 4d pytorch tensor or numpy array,
|
||||
# (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
|
||||
# (3) a list of valid image
|
||||
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
|
||||
return True
|
||||
elif is_valid_image(images):
|
||||
@@ -131,16 +103,8 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
r"""
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
|
||||
Args:
|
||||
images (`np.ndarray`):
|
||||
The image array to convert to PIL format.
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`:
|
||||
A list of PIL images.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
@@ -155,16 +119,8 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
@staticmethod
|
||||
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
||||
r"""
|
||||
"""
|
||||
Convert a PIL image or a list of PIL images to NumPy arrays.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
|
||||
The PIL image or list of images to convert to NumPy format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`:
|
||||
A NumPy array representation of the images.
|
||||
"""
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
@@ -175,16 +131,8 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
|
||||
r"""
|
||||
"""
|
||||
Convert a NumPy image to a PyTorch tensor.
|
||||
|
||||
Args:
|
||||
images (`np.ndarray`):
|
||||
The NumPy image array to convert to PyTorch format.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
A PyTorch tensor representation of the images.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[..., None]
|
||||
@@ -194,62 +142,30 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
@staticmethod
|
||||
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
|
||||
r"""
|
||||
"""
|
||||
Convert a PyTorch tensor to a NumPy image.
|
||||
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The PyTorch tensor to convert to NumPy format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`:
|
||||
A NumPy array representation of the images.
|
||||
"""
|
||||
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return images
|
||||
|
||||
@staticmethod
|
||||
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
r"""
|
||||
"""
|
||||
Normalize an image array to [-1,1].
|
||||
|
||||
Args:
|
||||
images (`np.ndarray` or `torch.Tensor`):
|
||||
The image array to normalize.
|
||||
|
||||
Returns:
|
||||
`np.ndarray` or `torch.Tensor`:
|
||||
The normalized image array.
|
||||
"""
|
||||
return 2.0 * images - 1.0
|
||||
|
||||
@staticmethod
|
||||
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
r"""
|
||||
"""
|
||||
Denormalize an image array to [0,1].
|
||||
|
||||
Args:
|
||||
images (`np.ndarray` or `torch.Tensor`):
|
||||
The image array to denormalize.
|
||||
|
||||
Returns:
|
||||
`np.ndarray` or `torch.Tensor`:
|
||||
The denormalized image array.
|
||||
"""
|
||||
return (images / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
@staticmethod
|
||||
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
r"""
|
||||
"""
|
||||
Converts a PIL image to RGB format.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The PIL image to convert to RGB.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The RGB-converted PIL image.
|
||||
"""
|
||||
image = image.convert("RGB")
|
||||
|
||||
@@ -257,16 +173,8 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
@staticmethod
|
||||
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
r"""
|
||||
Converts a given PIL image to grayscale.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The input image to convert.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The image converted to grayscale.
|
||||
"""
|
||||
Converts a PIL image to grayscale format.
|
||||
"""
|
||||
image = image.convert("L")
|
||||
|
||||
@@ -274,16 +182,8 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
@staticmethod
|
||||
def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
|
||||
r"""
|
||||
"""
|
||||
Applies Gaussian blur to an image.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The PIL image to convert to grayscale.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The grayscale-converted PIL image.
|
||||
"""
|
||||
image = image.filter(ImageFilter.GaussianBlur(blur_factor))
|
||||
|
||||
@@ -291,7 +191,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
@staticmethod
|
||||
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
|
||||
r"""
|
||||
"""
|
||||
Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
|
||||
ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
|
||||
processing are 512x512, the region will be expanded to 128x128.
|
||||
@@ -385,21 +285,14 @@ class VaeImageProcessor(ConfigMixin):
|
||||
width: int,
|
||||
height: int,
|
||||
) -> PIL.Image.Image:
|
||||
r"""
|
||||
"""
|
||||
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
||||
the image within the dimensions, filling empty with data from image.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image to resize and fill.
|
||||
width (`int`):
|
||||
The width to resize the image to.
|
||||
height (`int`):
|
||||
The height to resize the image to.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The resized and filled image.
|
||||
image: The image to resize.
|
||||
width: The width to resize the image to.
|
||||
height: The height to resize the image to.
|
||||
"""
|
||||
|
||||
ratio = width / height
|
||||
@@ -437,21 +330,14 @@ class VaeImageProcessor(ConfigMixin):
|
||||
width: int,
|
||||
height: int,
|
||||
) -> PIL.Image.Image:
|
||||
r"""
|
||||
"""
|
||||
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
||||
the image within the dimensions, cropping the excess.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`):
|
||||
The image to resize and crop.
|
||||
width (`int`):
|
||||
The width to resize the image to.
|
||||
height (`int`):
|
||||
The height to resize the image to.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The resized and cropped image.
|
||||
image: The image to resize.
|
||||
width: The width to resize the image to.
|
||||
height: The height to resize the image to.
|
||||
"""
|
||||
ratio = width / height
|
||||
src_ratio = image.width / image.height
|
||||
@@ -543,23 +429,19 @@ class VaeImageProcessor(ConfigMixin):
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
) -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
|
||||
"""
|
||||
This function return the height and width that are downscaled to the next integer multiple of
|
||||
`vae_scale_factor`.
|
||||
|
||||
Args:
|
||||
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
|
||||
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
|
||||
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
|
||||
tensor, it should have shape `[batch, channels, height, width]`.
|
||||
height (`Optional[int]`, *optional*, defaults to `None`):
|
||||
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
|
||||
width (`Optional[int]`, *optional*, defaults to `None`):
|
||||
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`:
|
||||
A tuple containing the height and width, both resized to the nearest integer multiple of
|
||||
`vae_scale_factor`.
|
||||
image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
||||
The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have
|
||||
shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should
|
||||
have shape `[batch, channel, height, width]`.
|
||||
height (`int`, *optional*, defaults to `None`):
|
||||
The height in preprocessed image. If `None`, will use the height of `image` input.
|
||||
width (`int`, *optional*`, defaults to `None`):
|
||||
The width in preprocessed. If `None`, will use the width of the `image` input.
|
||||
"""
|
||||
|
||||
if height is None:
|
||||
@@ -596,13 +478,13 @@ class VaeImageProcessor(ConfigMixin):
|
||||
Preprocess the image input.
|
||||
|
||||
Args:
|
||||
image (`PipelineImageInput`):
|
||||
image (`pipeline_image_input`):
|
||||
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
|
||||
supported formats.
|
||||
height (`int`, *optional*):
|
||||
height (`int`, *optional*, defaults to `None`):
|
||||
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
|
||||
height.
|
||||
width (`int`, *optional*):
|
||||
width (`int`, *optional*`, defaults to `None`):
|
||||
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
|
||||
resize_mode (`str`, *optional*, defaults to `default`):
|
||||
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
|
||||
@@ -614,10 +496,6 @@ class VaeImageProcessor(ConfigMixin):
|
||||
supported for PIL image input.
|
||||
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
||||
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The preprocessed image.
|
||||
"""
|
||||
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
||||
|
||||
@@ -777,22 +655,8 @@ class VaeImageProcessor(ConfigMixin):
|
||||
image: PIL.Image.Image,
|
||||
crop_coords: Optional[Tuple[int, int, int, int]] = None,
|
||||
) -> PIL.Image.Image:
|
||||
r"""
|
||||
Applies an overlay of the mask and the inpainted image on the original image.
|
||||
|
||||
Args:
|
||||
mask (`PIL.Image.Image`):
|
||||
The mask image that highlights regions to overlay.
|
||||
init_image (`PIL.Image.Image`):
|
||||
The original image to which the overlay is applied.
|
||||
image (`PIL.Image.Image`):
|
||||
The image to overlay onto the original.
|
||||
crop_coords (`Tuple[int, int, int, int]`, *optional*):
|
||||
Coordinates to crop the image. If provided, the image will be cropped accordingly.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`:
|
||||
The final image with the overlay applied.
|
||||
"""
|
||||
overlay the inpaint output to the original image
|
||||
"""
|
||||
|
||||
width, height = image.width, image.height
|
||||
@@ -849,16 +713,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
|
||||
@staticmethod
|
||||
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
r"""
|
||||
Convert a NumPy image or a batch of images to a list of PIL images.
|
||||
|
||||
Args:
|
||||
images (`np.ndarray`):
|
||||
The input NumPy array of images, which can be a single image or a batch.
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`:
|
||||
A list of PIL images converted from the input NumPy array.
|
||||
"""
|
||||
Convert a NumPy image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
@@ -873,16 +729,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
|
||||
@staticmethod
|
||||
def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
||||
r"""
|
||||
"""
|
||||
Convert a PIL image or a list of PIL images to NumPy arrays.
|
||||
|
||||
Args:
|
||||
images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
|
||||
The input image or list of images to be converted.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`:
|
||||
A NumPy array of the converted images.
|
||||
"""
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
@@ -893,30 +741,18 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
|
||||
@staticmethod
|
||||
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
r"""
|
||||
Convert an RGB-like depth image to a depth map.
|
||||
|
||||
"""
|
||||
Args:
|
||||
image (`Union[np.ndarray, torch.Tensor]`):
|
||||
The RGB-like depth image to convert.
|
||||
image: RGB-like depth image
|
||||
|
||||
Returns: depth map
|
||||
|
||||
Returns:
|
||||
`Union[np.ndarray, torch.Tensor]`:
|
||||
The corresponding depth map.
|
||||
"""
|
||||
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
||||
|
||||
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
r"""
|
||||
Convert a NumPy depth image or a batch of images to a list of PIL images.
|
||||
|
||||
Args:
|
||||
images (`np.ndarray`):
|
||||
The input NumPy array of depth images, which can be a single image or a batch.
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`:
|
||||
A list of PIL images converted from the input NumPy depth images.
|
||||
"""
|
||||
Convert a NumPy depth image or a batch of images to a PIL image.
|
||||
"""
|
||||
if images.ndim == 3:
|
||||
images = images[None, ...]
|
||||
@@ -997,24 +833,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
width: Optional[int] = None,
|
||||
target_res: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Preprocess the image input. Accepted formats are PIL images, NumPy arrays, or PyTorch tensors.
|
||||
|
||||
Args:
|
||||
rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
|
||||
The RGB input image, which can be a single image or a batch.
|
||||
depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
|
||||
The depth input image, which can be a single image or a batch.
|
||||
height (`Optional[int]`, *optional*, defaults to `None`):
|
||||
The desired height of the processed image. If `None`, defaults to the height of the input image.
|
||||
width (`Optional[int]`, *optional*, defaults to `None`):
|
||||
The desired width of the processed image. If `None`, defaults to the width of the input image.
|
||||
target_res (`Optional[int]`, *optional*, defaults to `None`):
|
||||
Target resolution for resizing the images. If specified, overrides height and width.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
A tuple containing the processed RGB and depth images as PyTorch tensors.
|
||||
"""
|
||||
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
|
||||
"""
|
||||
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
||||
|
||||
@@ -1252,17 +1072,7 @@ class PixArtImageProcessor(VaeImageProcessor):
|
||||
|
||||
@staticmethod
|
||||
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the binned height and width based on the aspect ratio.
|
||||
|
||||
Args:
|
||||
height (`int`): The height of the image.
|
||||
width (`int`): The width of the image.
|
||||
ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
|
||||
|
||||
Returns:
|
||||
`Tuple[int, int]`: The closest binned height and width.
|
||||
"""
|
||||
"""Returns binned height and width."""
|
||||
ar = float(height / width)
|
||||
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
|
||||
default_hw = ratios[closest_ratio]
|
||||
@@ -1270,19 +1080,6 @@ class PixArtImageProcessor(VaeImageProcessor):
|
||||
|
||||
@staticmethod
|
||||
def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
|
||||
r"""
|
||||
Resizes and crops a tensor of images to the specified dimensions.
|
||||
|
||||
Args:
|
||||
samples (`torch.Tensor`):
|
||||
A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the height,
|
||||
and W is the width.
|
||||
new_width (`int`): The desired width of the output images.
|
||||
new_height (`int`): The desired height of the output images.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: A tensor containing the resized and cropped images.
|
||||
"""
|
||||
orig_height, orig_width = samples.shape[2], samples.shape[3]
|
||||
|
||||
# Check if resizing is needed
|
||||
|
||||
@@ -532,19 +532,13 @@ class LoraBaseMixin:
|
||||
)
|
||||
|
||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||
# eg ["adapter1", "adapter2"]
|
||||
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
||||
missing_adapters = set(adapter_names) - all_adapters
|
||||
if len(missing_adapters) > 0:
|
||||
raise ValueError(
|
||||
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
||||
)
|
||||
|
||||
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
all_adapters = {
|
||||
adapter for adapters in list_adapters.values() for adapter in adapters
|
||||
} # eg ["adapter1", "adapter2"]
|
||||
invert_list_adapters = {
|
||||
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
||||
for adapter in all_adapters
|
||||
}
|
||||
} # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
|
||||
# Decompose weights into weights for denoiser and text encoders.
|
||||
_component_adapter_weights = {}
|
||||
|
||||
@@ -516,47 +516,10 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
f"transformer.single_transformer_blocks.{i}.norm.linear",
|
||||
)
|
||||
|
||||
remaining_keys = list(sds_sd.keys())
|
||||
te_state_dict = {}
|
||||
if remaining_keys:
|
||||
if not all(k.startswith("lora_te1") for k in remaining_keys):
|
||||
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
|
||||
for key in remaining_keys:
|
||||
if not key.endswith("lora_down.weight"):
|
||||
continue
|
||||
|
||||
lora_name = key.split(".")[0]
|
||||
lora_name_up = f"{lora_name}.lora_up.weight"
|
||||
lora_name_alpha = f"{lora_name}.alpha"
|
||||
diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
|
||||
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
down_weight = sds_sd.pop(key)
|
||||
sd_lora_rank = down_weight.shape[0]
|
||||
te_state_dict[diffusers_name] = down_weight
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] = sds_sd.pop(lora_name_up)
|
||||
|
||||
if lora_name_alpha in sds_sd:
|
||||
alpha = sds_sd.pop(lora_name_alpha).item()
|
||||
scale = alpha / sd_lora_rank
|
||||
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
|
||||
te_state_dict[diffusers_name] *= scale_down
|
||||
te_state_dict[diffusers_name.replace(".down.", ".up.")] *= scale_up
|
||||
|
||||
if len(sds_sd) > 0:
|
||||
logger.warning(f"Unsupported keys for ai-toolkit: {sds_sd.keys()}")
|
||||
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
|
||||
|
||||
if te_state_dict:
|
||||
te_state_dict = {f"text_encoder.{module_name}": params for module_name, params in te_state_dict.items()}
|
||||
|
||||
new_state_dict = {**ait_sd, **te_state_dict}
|
||||
return new_state_dict
|
||||
return ait_sd
|
||||
|
||||
return _convert_sd_scripts_to_ai_toolkit(state_dict)
|
||||
|
||||
@@ -632,7 +595,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
new_key += ".lora_B.weight"
|
||||
|
||||
# Handle single_blocks
|
||||
elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
|
||||
elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"):
|
||||
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
|
||||
new_key = f"transformer.single_transformer_blocks.{block_num}"
|
||||
|
||||
|
||||
@@ -25,11 +25,8 @@ from ..utils import (
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
)
|
||||
@@ -42,17 +39,6 @@ from .lora_conversion_utils import (
|
||||
)
|
||||
|
||||
|
||||
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
|
||||
if is_torch_version(">=", "1.9.0"):
|
||||
if (
|
||||
is_peft_available()
|
||||
and is_peft_version(">=", "0.13.1")
|
||||
and is_transformers_available()
|
||||
and is_transformers_version(">", "4.45.2")
|
||||
):
|
||||
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = True
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
@@ -97,24 +83,15 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
@@ -122,7 +99,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
@@ -132,7 +109,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
@@ -143,7 +119,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -236,11 +211,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
network_alphas = None
|
||||
# TODO: replace it with a method from `state_dict_utils`
|
||||
@@ -262,9 +232,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
return state_dict, network_alphas
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_unet(
|
||||
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`.
|
||||
|
||||
@@ -282,16 +250,10 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
@@ -301,11 +263,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_attn_procs(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -318,7 +276,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
lora_scale=1.0,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -341,25 +298,10 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
@@ -430,7 +372,6 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
@@ -601,19 +542,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||
# pipeline.
|
||||
@@ -628,18 +562,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
unet_config=self.unet.config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
unet=self.unet,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
|
||||
)
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
@@ -651,7 +579,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
||||
@@ -664,7 +591,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -758,11 +684,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
network_alphas = None
|
||||
# TODO: replace it with a method from `state_dict_utils`
|
||||
@@ -785,9 +706,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
|
||||
def load_lora_into_unet(
|
||||
cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`.
|
||||
|
||||
@@ -805,16 +724,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
@@ -824,11 +737,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_attn_procs(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -842,7 +751,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
lora_scale=1.0,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -865,25 +773,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
@@ -954,7 +847,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
@@ -1197,12 +1089,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
@@ -1223,22 +1109,15 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
@@ -1246,7 +1125,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
@@ -1255,7 +1134,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
@@ -1268,7 +1146,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
|
||||
@@ -1281,13 +1158,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
@@ -1301,13 +1175,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
@@ -1351,37 +1219,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_unexpected_keys:
|
||||
warn_msg = (
|
||||
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_missing_keys:
|
||||
warn_msg += (
|
||||
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
||||
f" {', '.join(lora_missing_keys)}."
|
||||
)
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
@@ -1401,7 +1249,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
lora_scale=1.0,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -1424,25 +1271,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
@@ -1513,7 +1345,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
@@ -1756,13 +1587,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
|
||||
|
||||
is_kohya = any(".lora_down.weight" in k for k in state_dict)
|
||||
if is_kohya:
|
||||
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
|
||||
@@ -1819,17 +1646,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
@@ -1839,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
|
||||
)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
@@ -1849,7 +1669,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
@@ -1862,13 +1681,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
@@ -1886,13 +1702,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
@@ -1941,37 +1751,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_unexpected_keys:
|
||||
warn_msg = (
|
||||
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_missing_keys:
|
||||
warn_msg += (
|
||||
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
||||
f" {', '.join(lora_missing_keys)}."
|
||||
)
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
@@ -1991,7 +1781,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
lora_scale=1.0,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -2014,25 +1803,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
@@ -2103,7 +1877,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
@@ -2311,30 +2084,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_unexpected_keys:
|
||||
warn_msg = (
|
||||
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_missing_keys:
|
||||
warn_msg += (
|
||||
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
||||
f" {', '.join(lora_missing_keys)}."
|
||||
)
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
@@ -2354,7 +2111,6 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
lora_scale=1.0,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -2377,25 +2133,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
@@ -2466,7 +2207,6 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
@@ -2634,12 +2374,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
@@ -2655,22 +2389,15 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
@@ -2678,7 +2405,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
@@ -2687,14 +2414,11 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
@@ -2708,13 +2432,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.:
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
@@ -2758,37 +2476,17 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_unexpected_keys:
|
||||
warn_msg = (
|
||||
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_missing_keys:
|
||||
warn_msg += (
|
||||
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
||||
f" {', '.join(lora_missing_keys)}."
|
||||
)
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
|
||||
@@ -1513,7 +1513,7 @@ def _legacy_load_scheduler(
|
||||
)
|
||||
deprecate("prediction_type", "1.0.0", deprecation_message)
|
||||
|
||||
scheduler_config = SCHEDULER_DEFAULT_CONFIG
|
||||
scheduler_config = copy.deepcopy(SCHEDULER_DEFAULT_CONFIG)
|
||||
model_type = infer_diffusers_model_type(checkpoint=checkpoint)
|
||||
|
||||
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
|
||||
|
||||
@@ -115,9 +115,6 @@ class UNet2DConditionLoadersMixin:
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
weight_name (`str`, *optional*, defaults to None):
|
||||
Name of the serialized state dict file.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
|
||||
Example:
|
||||
|
||||
@@ -145,14 +142,8 @@ class UNet2DConditionLoadersMixin:
|
||||
adapter_name = kwargs.pop("adapter_name", None)
|
||||
_pipeline = kwargs.pop("_pipeline", None)
|
||||
network_alphas = kwargs.pop("network_alphas", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
|
||||
allow_pickle = False
|
||||
|
||||
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
@@ -218,7 +209,6 @@ class UNet2DConditionLoadersMixin:
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -278,9 +268,7 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
return attn_processors
|
||||
|
||||
def _process_lora(
|
||||
self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
|
||||
):
|
||||
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
|
||||
# This method does the following things:
|
||||
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
|
||||
# format. For legacy format no filtering is applied.
|
||||
@@ -347,37 +335,18 @@ class UNet2DConditionLoadersMixin:
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
# Check only for unexpected keys.
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_unexpected_keys:
|
||||
warn_msg = (
|
||||
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
|
||||
f" {', '.join(lora_unexpected_keys)}. "
|
||||
)
|
||||
|
||||
# Filter missing keys specific to the current adapter.
|
||||
missing_keys = getattr(incompatible_keys, "missing_keys", None)
|
||||
if missing_keys:
|
||||
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
|
||||
if lora_missing_keys:
|
||||
warn_msg += (
|
||||
f"Loading adapter weights from state_dict led to missing keys in the model:"
|
||||
f" {', '.join(lora_missing_keys)}."
|
||||
)
|
||||
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
return is_model_cpu_offload, is_sequential_cpu_offload
|
||||
|
||||
|
||||
@@ -54,7 +54,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
@@ -99,7 +98,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .transformers import (
|
||||
AuraFlowTransformer2DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
|
||||
@@ -30,10 +30,10 @@ class MultiAdapter(ModelMixin):
|
||||
MultiAdapter is a wrapper model that contains multiple adapter models and merges their outputs according to
|
||||
user-assigned weighting.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for common methods such as downloading
|
||||
or saving.
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the model (such as downloading or saving, etc.)
|
||||
|
||||
Args:
|
||||
Parameters:
|
||||
adapters (`List[T2IAdapter]`, *optional*, defaults to None):
|
||||
A list of `T2IAdapter` model instances.
|
||||
"""
|
||||
@@ -77,13 +77,11 @@ class MultiAdapter(ModelMixin):
|
||||
r"""
|
||||
Args:
|
||||
xs (`torch.Tensor`):
|
||||
A tensor of shape (batch, channel, height, width) representing input images for multiple adapter
|
||||
models, concatenated along dimension 1(channel dimension). The `channel` dimension should be equal to
|
||||
`num_adapter` * number of channel per image.
|
||||
|
||||
(batch, channel, height, width) input images for multiple adapter models concated along dimension 1,
|
||||
`channel` should equal to `num_adapter` * "number of channel of image".
|
||||
adapter_weights (`List[float]`, *optional*, defaults to None):
|
||||
A list of floats representing the weights which will be multiplied by each adapter's output before
|
||||
summing them together. If `None`, equal weights will be used for all adapters.
|
||||
List of floats representing the weight which will be multiply to each adapter's output before adding
|
||||
them together.
|
||||
"""
|
||||
if adapter_weights is None:
|
||||
adapter_weights = torch.tensor([1 / self.num_adapter] * self.num_adapter)
|
||||
@@ -111,24 +109,24 @@ class MultiAdapter(ModelMixin):
|
||||
variant: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Save a model and its configuration file to a specified directory, allowing it to be re-loaded with the
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
`[`~models.adapter.MultiAdapter.from_pretrained`]` class method.
|
||||
|
||||
Args:
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
The directory where the model will be saved. If the directory does not exist, it will be created.
|
||||
is_main_process (`bool`, optional, defaults=True):
|
||||
Indicates whether current process is the main process or not. Useful for distributed training (e.g.,
|
||||
TPUs) and need to call this function on all processes. In this case, set `is_main_process=True` only
|
||||
for the main process to avoid race conditions.
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful when in distributed training like
|
||||
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
||||
the main process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
Function used to save the state dictionary. Useful for distributed training (e.g., TPUs) to replace
|
||||
`torch.save` with another method. Can also be configured using`DIFFUSERS_SAVE_MODE` environment
|
||||
variable.
|
||||
safe_serialization (`bool`, optional, defaults=True):
|
||||
If `True`, save the model using `safetensors`. If `False`, save the model with `pickle`.
|
||||
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
||||
need to replace `torch.save` by another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
||||
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
||||
"""
|
||||
idx = 0
|
||||
model_path_to_save = save_directory
|
||||
@@ -147,17 +145,19 @@ class MultiAdapter(ModelMixin):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained `MultiAdapter` model from multiple pre-trained adapter models.
|
||||
Instantiate a pretrained MultiAdapter model from multiple pre-trained adapter models.
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
||||
the model, set it back to training mode using `model.train()`.
|
||||
the model, you should first set it back in training mode with `model.train()`.
|
||||
|
||||
Warnings:
|
||||
*Weights from XXX not initialized from pretrained model* means that the weights of XXX are not pretrained
|
||||
with the rest of the model. It is up to you to train those weights with a downstream fine-tuning. *Weights
|
||||
from XXX not used in YYY* means that the layer XXX is not used by YYY, so those weights are discarded.
|
||||
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
||||
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
||||
task.
|
||||
|
||||
Args:
|
||||
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
||||
weights are discarded.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_path (`os.PathLike`):
|
||||
A path to a *directory* containing model weights saved using
|
||||
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
|
||||
@@ -175,20 +175,20 @@ class MultiAdapter(ModelMixin):
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary mapping device identifiers to their maximum memory. Default to the maximum memory
|
||||
available for each GPU and the available CPU RAM if unset.
|
||||
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
|
||||
GPU and the available CPU RAM if unset.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
variant (`str`, *optional*):
|
||||
If specified, load weights from a `variant` file (*e.g.* pytorch_model.<variant>.bin). `variant` will
|
||||
be ignored when using `from_flax`.
|
||||
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
||||
ignored when using `from_flax`.
|
||||
use_safetensors (`bool`, *optional*, defaults to `None`):
|
||||
If `None`, the `safetensors` weights will be downloaded if available **and** if`safetensors` library is
|
||||
installed. If `True`, the model will be forcibly loaded from`safetensors` weights. If `False`,
|
||||
`safetensors` is not used.
|
||||
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
|
||||
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
|
||||
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
|
||||
"""
|
||||
idx = 0
|
||||
adapters = []
|
||||
@@ -223,22 +223,22 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
and
|
||||
[AdapterLight](https://github.com/TencentARC/T2I-Adapter/blob/686de4681515662c0ac2ffa07bf5dda83af1038a/ldm/modules/encoders/adapter.py#L235).
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the common methods, such as
|
||||
downloading or saving.
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the model (such as downloading or saving, etc.)
|
||||
|
||||
Args:
|
||||
in_channels (`int`, *optional*, defaults to `3`):
|
||||
The number of channels in the adapter's input (*control image*). Set it to 1 if you're using a gray scale
|
||||
image.
|
||||
Parameters:
|
||||
in_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels of Aapter's input(*control image*). Set this parameter to 1 if you're using gray scale
|
||||
image as *control image*.
|
||||
channels (`List[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
The number of channels in each downsample block's output hidden state. The `len(block_out_channels)`
|
||||
determines the number of downsample blocks in the adapter.
|
||||
num_res_blocks (`int`, *optional*, defaults to `2`):
|
||||
The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
|
||||
also determine the number of downsample blocks in the Adapter.
|
||||
num_res_blocks (`int`, *optional*, defaults to 2):
|
||||
Number of ResNet blocks in each downsample block.
|
||||
downscale_factor (`int`, *optional*, defaults to `8`):
|
||||
downscale_factor (`int`, *optional*, defaults to 8):
|
||||
A factor that determines the total downscale factor of the Adapter.
|
||||
adapter_type (`str`, *optional*, defaults to `full_adapter`):
|
||||
Adapter type (`full_adapter` or `full_adapter_xl` or `light_adapter`) to use.
|
||||
The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -393,7 +393,7 @@ class AdapterBlock(nn.Module):
|
||||
An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
|
||||
`FullAdapterXL` models.
|
||||
|
||||
Args:
|
||||
Parameters:
|
||||
in_channels (`int`):
|
||||
Number of channels of AdapterBlock's input.
|
||||
out_channels (`int`):
|
||||
@@ -401,7 +401,7 @@ class AdapterBlock(nn.Module):
|
||||
num_res_blocks (`int`):
|
||||
Number of ResNet blocks in the AdapterBlock.
|
||||
down (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, perform downsampling on AdapterBlock's input.
|
||||
Whether to perform downsampling on AdapterBlock's input.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
|
||||
@@ -440,7 +440,7 @@ class AdapterResnetBlock(nn.Module):
|
||||
r"""
|
||||
An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
|
||||
|
||||
Args:
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
Number of channels of AdapterResnetBlock's input and output.
|
||||
"""
|
||||
@@ -518,7 +518,7 @@ class LightAdapterBlock(nn.Module):
|
||||
A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
|
||||
`LightAdapter` model.
|
||||
|
||||
Args:
|
||||
Parameters:
|
||||
in_channels (`int`):
|
||||
Number of channels of LightAdapterBlock's input.
|
||||
out_channels (`int`):
|
||||
@@ -526,7 +526,7 @@ class LightAdapterBlock(nn.Module):
|
||||
num_res_blocks (`int`):
|
||||
Number of LightAdapterResnetBlocks in the LightAdapterBlock.
|
||||
down (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, perform downsampling on LightAdapterBlock's input.
|
||||
Whether to perform downsampling on LightAdapterBlock's input.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
|
||||
@@ -561,7 +561,7 @@ class LightAdapterResnetBlock(nn.Module):
|
||||
A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
|
||||
architecture than `AdapterResnetBlock`.
|
||||
|
||||
Args:
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
Number of channels of LightAdapterResnetBlock's input and output.
|
||||
"""
|
||||
|
||||
@@ -122,7 +122,6 @@ class Attention(nn.Module):
|
||||
out_dim: int = None,
|
||||
context_pre_only=None,
|
||||
pre_only=False,
|
||||
elementwise_affine: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -180,8 +179,8 @@ class Attention(nn.Module):
|
||||
self.norm_q = None
|
||||
self.norm_k = None
|
||||
elif qk_norm == "layer_norm":
|
||||
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
|
||||
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
|
||||
elif qk_norm == "fp32_layer_norm":
|
||||
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
|
||||
@@ -18,7 +18,6 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import deprecate
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -246,18 +245,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
|
||||
return self._tiled_encode(x)
|
||||
|
||||
enc = self.encoder(x)
|
||||
if self.quant_conv is not None:
|
||||
enc = self.quant_conv(enc)
|
||||
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
@@ -274,13 +261,21 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
||||
return self.tiled_encode(x, return_dict=return_dict)
|
||||
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
h = self.encoder(x)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
if self.quant_conv is not None:
|
||||
moments = self.quant_conv(h)
|
||||
else:
|
||||
moments = h
|
||||
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
@@ -342,54 +337,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||
output, but they should be much less noticeable.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded videos.
|
||||
"""
|
||||
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
# Split the image into 512x512 tiles and encode them separately.
|
||||
rows = []
|
||||
for i in range(0, x.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
if self.config.use_quant_conv:
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
enc = torch.cat(result_rows, dim=2)
|
||||
return enc
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
@@ -409,13 +356,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
||||
`tuple` is returned.
|
||||
"""
|
||||
deprecation_message = (
|
||||
"The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
|
||||
"implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
|
||||
"to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
|
||||
)
|
||||
deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -41,9 +41,7 @@ class CogVideoXSafeConv3d(nn.Conv3d):
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
memory_count = (
|
||||
(input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
|
||||
)
|
||||
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
|
||||
|
||||
# Set to 2GB, suitable for CuDNN
|
||||
if memory_count > 2:
|
||||
@@ -117,24 +115,34 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
dilation=dilation,
|
||||
)
|
||||
|
||||
def fake_context_parallel_forward(
|
||||
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
self.conv_cache = None
|
||||
|
||||
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
kernel_size = self.time_kernel_size
|
||||
if kernel_size > 1:
|
||||
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
||||
cached_inputs = (
|
||||
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
||||
)
|
||||
inputs = torch.cat(cached_inputs + [inputs], dim=2)
|
||||
return inputs
|
||||
|
||||
def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
|
||||
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||
def _clear_fake_context_parallel_cache(self):
|
||||
del self.conv_cache
|
||||
self.conv_cache = None
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = self.fake_context_parallel_forward(inputs)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
# Note: we could move these to the cpu for a lower maximum memory usage but its only a few
|
||||
# hundred megabytes and so let's not do it for now
|
||||
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
||||
|
||||
output = self.conv(inputs)
|
||||
return output, conv_cache
|
||||
return output
|
||||
|
||||
|
||||
class CogVideoXSpatialNorm3D(nn.Module):
|
||||
@@ -164,12 +172,7 @@ class CogVideoXSpatialNorm3D(nn.Module):
|
||||
self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(
|
||||
self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> torch.Tensor:
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
|
||||
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
@@ -180,12 +183,9 @@ class CogVideoXSpatialNorm3D(nn.Module):
|
||||
else:
|
||||
zq = F.interpolate(zq, size=f.shape[-3:])
|
||||
|
||||
conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
|
||||
conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
|
||||
|
||||
norm_f = self.norm_layer(f)
|
||||
new_f = norm_f * conv_y + conv_b
|
||||
return new_f, new_conv_cache
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
class CogVideoXResnetBlock3D(nn.Module):
|
||||
@@ -236,7 +236,6 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
self.out_channels = out_channels
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.spatial_norm_dim = spatial_norm_dim
|
||||
|
||||
if spatial_norm_dim is None:
|
||||
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
@@ -280,43 +279,34 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
inputs: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
hidden_states = inputs
|
||||
|
||||
if zq is not None:
|
||||
hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
|
||||
hidden_states = self.norm1(hidden_states, zq)
|
||||
else:
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if temb is not None:
|
||||
hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
if zq is not None:
|
||||
hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
|
||||
hidden_states = self.norm2(hidden_states, zq)
|
||||
else:
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
|
||||
inputs, conv_cache=conv_cache.get("conv_shortcut")
|
||||
)
|
||||
else:
|
||||
inputs = self.conv_shortcut(inputs)
|
||||
inputs = self.conv_shortcut(inputs)
|
||||
|
||||
hidden_states = hidden_states + inputs
|
||||
return hidden_states, new_conv_cache
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXDownBlock3D(nn.Module):
|
||||
@@ -402,16 +392,8 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXDownBlock3D` class."""
|
||||
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
@@ -420,23 +402,17 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
zq,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states, new_conv_cache
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXMidBlock3D(nn.Module):
|
||||
@@ -504,16 +480,8 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXMidBlock3D` class."""
|
||||
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
@@ -522,15 +490,13 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
|
||||
return hidden_states, new_conv_cache
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXUpBlock3D(nn.Module):
|
||||
@@ -618,16 +584,9 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
zq: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""Forward method of the `CogVideoXUpBlock3D` class."""
|
||||
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
@@ -636,23 +595,17 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
zq,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
hidden_states = resnet(hidden_states, temb, zq)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
|
||||
return hidden_states, new_conv_cache
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXEncoder3D(nn.Module):
|
||||
@@ -752,18 +705,9 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXEncoder3D` class."""
|
||||
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -774,44 +718,28 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
return custom_forward
|
||||
|
||||
# 1. Down
|
||||
for i, down_block in enumerate(self.down_blocks):
|
||||
conv_cache_key = f"down_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
None,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block), hidden_states, temb, None
|
||||
)
|
||||
|
||||
# 2. Mid
|
||||
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
None,
|
||||
conv_cache=conv_cache.get("mid_block"),
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, None
|
||||
)
|
||||
else:
|
||||
# 1. Down
|
||||
for i, down_block in enumerate(self.down_blocks):
|
||||
conv_cache_key = f"down_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = down_block(
|
||||
hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states, temb, None)
|
||||
|
||||
# 2. Mid
|
||||
hidden_states, new_conv_cache["mid_block"] = self.mid_block(
|
||||
hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
|
||||
)
|
||||
hidden_states = self.mid_block(hidden_states, temb, None)
|
||||
|
||||
# 3. Post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
|
||||
hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return hidden_states, new_conv_cache
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXDecoder3D(nn.Module):
|
||||
@@ -918,18 +846,9 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
conv_cache: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
r"""The forward method of the `CogVideoXDecoder3D` class."""
|
||||
|
||||
new_conv_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -940,45 +859,28 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
return custom_forward
|
||||
|
||||
# 1. Mid
|
||||
hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
sample,
|
||||
conv_cache=conv_cache.get("mid_block"),
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block), hidden_states, temb, sample
|
||||
)
|
||||
|
||||
# 2. Up
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
conv_cache_key = f"up_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
hidden_states,
|
||||
temb,
|
||||
sample,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block), hidden_states, temb, sample
|
||||
)
|
||||
else:
|
||||
# 1. Mid
|
||||
hidden_states, new_conv_cache["mid_block"] = self.mid_block(
|
||||
hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
|
||||
)
|
||||
hidden_states = self.mid_block(hidden_states, temb, sample)
|
||||
|
||||
# 2. Up
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
conv_cache_key = f"up_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = up_block(
|
||||
hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
|
||||
)
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states, temb, sample)
|
||||
|
||||
# 3. Post-process
|
||||
hidden_states, new_conv_cache["norm_out"] = self.norm_out(
|
||||
hidden_states, sample, conv_cache=conv_cache.get("norm_out")
|
||||
)
|
||||
hidden_states = self.norm_out(hidden_states, sample)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return hidden_states, new_conv_cache
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
@@ -1117,6 +1019,12 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def _clear_fake_context_parallel_cache(self):
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CogVideoXCausalConv3d):
|
||||
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
|
||||
module._clear_fake_context_parallel_cache()
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
@@ -1182,22 +1090,21 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
|
||||
frame_batch_size = self.num_sample_frames_batch_size
|
||||
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
|
||||
# As the extra single frame is handled inside the loop, it is not required to round up here.
|
||||
num_batches = max(num_frames // frame_batch_size, 1)
|
||||
conv_cache = None
|
||||
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
|
||||
enc = []
|
||||
|
||||
for i in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
||||
x_intermediate = x[:, :, start_frame:end_frame]
|
||||
x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
|
||||
x_intermediate = self.encoder(x_intermediate)
|
||||
if self.quant_conv is not None:
|
||||
x_intermediate = self.quant_conv(x_intermediate)
|
||||
enc.append(x_intermediate)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
enc = torch.cat(enc, dim=2)
|
||||
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
@@ -1235,10 +1142,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
frame_batch_size = self.num_latent_frames_batch_size
|
||||
num_batches = max(num_frames // frame_batch_size, 1)
|
||||
conv_cache = None
|
||||
num_batches = num_frames // frame_batch_size
|
||||
dec = []
|
||||
|
||||
for i in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||
@@ -1246,9 +1151,10 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
z_intermediate = z[:, :, start_frame:end_frame]
|
||||
if self.post_quant_conv is not None:
|
||||
z_intermediate = self.post_quant_conv(z_intermediate)
|
||||
z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
|
||||
z_intermediate = self.decoder(z_intermediate)
|
||||
dec.append(z_intermediate)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
dec = torch.cat(dec, dim=2)
|
||||
|
||||
if not return_dict:
|
||||
@@ -1331,11 +1237,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
|
||||
# As the extra single frame is handled inside the loop, it is not required to round up here.
|
||||
num_batches = max(num_frames // frame_batch_size, 1)
|
||||
conv_cache = None
|
||||
num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
|
||||
time = []
|
||||
|
||||
for k in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
||||
@@ -1347,11 +1250,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
i : i + self.tile_sample_min_height,
|
||||
j : j + self.tile_sample_min_width,
|
||||
]
|
||||
tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
|
||||
tile = self.encoder(tile)
|
||||
if self.quant_conv is not None:
|
||||
tile = self.quant_conv(tile)
|
||||
time.append(tile)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
|
||||
@@ -1411,10 +1314,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
num_batches = max(num_frames // frame_batch_size, 1)
|
||||
conv_cache = None
|
||||
num_batches = num_frames // frame_batch_size
|
||||
time = []
|
||||
|
||||
for k in range(num_batches):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
||||
@@ -1428,9 +1329,9 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
]
|
||||
if self.post_quant_conv is not None:
|
||||
tile = self.post_quant_conv(tile)
|
||||
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
|
||||
tile = self.decoder(tile)
|
||||
time.append(tile)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from ..loaders import PeftAdapterMixin
|
||||
from ..models.attention_processor import AttentionProcessor
|
||||
from ..models.modeling_utils import ModelMixin
|
||||
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
|
||||
from .controlnet import BaseOutput, zero_module
|
||||
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from .modeling_outputs import Transformer2DModelOutput
|
||||
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
@@ -55,7 +55,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
@@ -107,14 +106,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
if self.union:
|
||||
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
|
||||
|
||||
if conditioning_embedding_channels is not None:
|
||||
self.input_hint_block = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
|
||||
)
|
||||
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
||||
else:
|
||||
self.input_hint_block = None
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@@ -277,16 +269,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
if self.input_hint_block is not None:
|
||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
|
||||
height = height_pw // self.config.patch_size
|
||||
width = width_pw // self.config.patch_size
|
||||
controlnet_cond = controlnet_cond.reshape(
|
||||
batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
|
||||
)
|
||||
controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
|
||||
controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
|
||||
# add
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
|
||||
|
||||
@@ -520,17 +502,16 @@ class FluxMultiControlNetModel(ModelMixin):
|
||||
control_block_samples = block_samples
|
||||
control_single_block_samples = single_block_samples
|
||||
else:
|
||||
if block_samples is not None and control_block_samples is not None:
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
if single_block_samples is not None and control_single_block_samples is not None:
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
control_block_samples = [
|
||||
control_block_sample + block_sample
|
||||
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
|
||||
]
|
||||
|
||||
control_single_block_samples = [
|
||||
control_single_block_sample + block_sample
|
||||
for control_single_block_sample, block_sample in zip(
|
||||
control_single_block_samples, single_block_samples
|
||||
)
|
||||
]
|
||||
|
||||
return control_block_samples, control_single_block_samples
|
||||
|
||||
@@ -442,60 +442,6 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
return embeds
|
||||
|
||||
|
||||
class CogView3PlusPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
hidden_size: int = 2560,
|
||||
patch_size: int = 2,
|
||||
text_hidden_size: int = 4096,
|
||||
pos_embed_max_size: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.patch_size = patch_size
|
||||
self.text_hidden_size = text_hidden_size
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
# Linear projection for image patches
|
||||
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
|
||||
|
||||
# Linear projection for text embeddings
|
||||
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
|
||||
|
||||
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
|
||||
pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
|
||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
|
||||
if height % self.patch_size != 0 or width % self.patch_size != 0:
|
||||
raise ValueError("Height and width must be divisible by patch size")
|
||||
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
|
||||
hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
|
||||
|
||||
# Project the patches
|
||||
hidden_states = self.proj(hidden_states)
|
||||
encoder_hidden_states = self.text_proj(encoder_hidden_states)
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# Calculate text_length
|
||||
text_length = encoder_hidden_states.shape[1]
|
||||
|
||||
image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
|
||||
text_pos_embed = torch.zeros(
|
||||
(text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
|
||||
)
|
||||
pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
|
||||
|
||||
return (hidden_states + pos_embed).to(hidden_states.dtype)
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
@@ -1134,39 +1080,6 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
|
||||
return conditioning
|
||||
|
||||
|
||||
class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
|
||||
self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
original_size: torch.Tensor,
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
hidden_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
|
||||
original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
|
||||
crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
|
||||
target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
|
||||
|
||||
# (B, 3 * condition_dim)
|
||||
condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
|
||||
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
|
||||
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
|
||||
|
||||
conditioning = timesteps_emb + condition_emb
|
||||
return conditioning
|
||||
|
||||
|
||||
class HunyuanDiTAttentionPool(nn.Module):
|
||||
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
|
||||
|
||||
|
||||
@@ -29,21 +29,11 @@ def get_sinusoidal_embeddings(
|
||||
"""Returns the positional encoding (same as Tensor2Tensor).
|
||||
|
||||
Args:
|
||||
timesteps (`jnp.ndarray` of shape `(N,)`):
|
||||
A 1-D array of N indices, one per batch element. These may be fractional.
|
||||
embedding_dim (`int`):
|
||||
The number of output channels.
|
||||
freq_shift (`float`, *optional*, defaults to `1`):
|
||||
Shift applied to the frequency scaling of the embeddings.
|
||||
min_timescale (`float`, *optional*, defaults to `1`):
|
||||
The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
|
||||
max_timescale (`float`, *optional*, defaults to `1.0e4`):
|
||||
The largest time unit used in the sinusoidal calculation.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip the order of sinusoidal components to cosine first.
|
||||
scale (`float`, *optional*, defaults to `1.0`):
|
||||
A scaling factor applied to the positional embeddings.
|
||||
|
||||
timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
embedding_dim: The number of output channels.
|
||||
min_timescale: The smallest time unit (should probably be 0.0).
|
||||
max_timescale: The largest time unit.
|
||||
Returns:
|
||||
a Tensor of timing signals [N, num_channels]
|
||||
"""
|
||||
@@ -71,9 +61,9 @@ class FlaxTimestepEmbedding(nn.Module):
|
||||
|
||||
Args:
|
||||
time_embed_dim (`int`, *optional*, defaults to `32`):
|
||||
Time step embedding dimension.
|
||||
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
|
||||
The data type for the embedding parameters.
|
||||
Time step embedding dimension
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
|
||||
time_embed_dim: int = 32
|
||||
@@ -93,11 +83,7 @@ class FlaxTimesteps(nn.Module):
|
||||
|
||||
Args:
|
||||
dim (`int`, *optional*, defaults to `32`):
|
||||
Time step embedding dimension.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
Whether to flip the sinusoidal function from sine to cosine.
|
||||
freq_shift (`float`, *optional*, defaults to `1`):
|
||||
Frequency shift applied to the sinusoidal embeddings.
|
||||
Time step embedding dimension
|
||||
"""
|
||||
|
||||
dim: int = 32
|
||||
|
||||
@@ -31,7 +31,6 @@ from ..utils import (
|
||||
WEIGHTS_INDEX_NAME,
|
||||
_add_variant,
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
@@ -229,67 +228,3 @@ def _fetch_index_file(
|
||||
index_file = None
|
||||
|
||||
return index_file
|
||||
|
||||
|
||||
def _fetch_index_file_legacy(
|
||||
is_local,
|
||||
pretrained_model_name_or_path,
|
||||
subfolder,
|
||||
use_safetensors,
|
||||
cache_dir,
|
||||
variant,
|
||||
force_download,
|
||||
proxies,
|
||||
local_files_only,
|
||||
token,
|
||||
revision,
|
||||
user_agent,
|
||||
commit_hash,
|
||||
):
|
||||
if is_local:
|
||||
index_file = Path(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder or "",
|
||||
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
|
||||
).as_posix()
|
||||
splits = index_file.split(".")
|
||||
split_index = -3 if ".cache" in index_file else -2
|
||||
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
||||
index_file = ".".join(splits)
|
||||
if os.path.exists(index_file):
|
||||
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|
||||
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
|
||||
index_file = Path(index_file)
|
||||
else:
|
||||
index_file = None
|
||||
else:
|
||||
if variant is not None:
|
||||
index_file_in_repo = Path(
|
||||
subfolder or "",
|
||||
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
|
||||
).as_posix()
|
||||
splits = index_file_in_repo.split(".")
|
||||
split_index = -2
|
||||
splits = splits[:-split_index] + [variant] + splits[-split_index:]
|
||||
index_file_in_repo = ".".join(splits)
|
||||
try:
|
||||
index_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=index_file_in_repo,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=None,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
index_file = Path(index_file)
|
||||
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|
||||
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
|
||||
except (EntryNotFoundError, EnvironmentError):
|
||||
index_file = None
|
||||
|
||||
return index_file
|
||||
|
||||
@@ -54,7 +54,6 @@ from ..utils.hub_utils import (
|
||||
from .model_loading_utils import (
|
||||
_determine_device_map,
|
||||
_fetch_index_file,
|
||||
_fetch_index_file_legacy,
|
||||
_load_state_dict_into_model,
|
||||
load_model_dict_into_meta,
|
||||
load_state_dict,
|
||||
@@ -310,9 +309,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||
weights_name = _add_variant(weights_name, variant)
|
||||
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
weight_name_split = weights_name.split(".")
|
||||
if len(weight_name_split) in [2, 3]:
|
||||
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
|
||||
else:
|
||||
raise ValueError(f"Invalid {weights_name} provided.")
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
@@ -623,26 +624,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
is_sharded = False
|
||||
index_file = None
|
||||
is_local = os.path.isdir(pretrained_model_name_or_path)
|
||||
index_file_kwargs = {
|
||||
"is_local": is_local,
|
||||
"pretrained_model_name_or_path": pretrained_model_name_or_path,
|
||||
"subfolder": subfolder or "",
|
||||
"use_safetensors": use_safetensors,
|
||||
"cache_dir": cache_dir,
|
||||
"variant": variant,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"local_files_only": local_files_only,
|
||||
"token": token,
|
||||
"revision": revision,
|
||||
"user_agent": user_agent,
|
||||
"commit_hash": commit_hash,
|
||||
}
|
||||
index_file = _fetch_index_file(**index_file_kwargs)
|
||||
# In case the index file was not found we still have to consider the legacy format.
|
||||
# this becomes applicable when the variant is not None.
|
||||
if variant is not None and (index_file is None or not os.path.exists(index_file)):
|
||||
index_file = _fetch_index_file_legacy(**index_file_kwargs)
|
||||
index_file = _fetch_index_file(
|
||||
is_local=is_local,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
subfolder=subfolder or "",
|
||||
use_safetensors=use_safetensors,
|
||||
cache_dir=cache_dir,
|
||||
variant=variant,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
if index_file is not None and index_file.is_file():
|
||||
is_sharded = True
|
||||
|
||||
|
||||
@@ -355,51 +355,6 @@ class LuminaLayerNormContinuous(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm zero (adaLN-Zero).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`): The size of the embeddings dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
|
||||
self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
c_shift_msa,
|
||||
c_scale_msa,
|
||||
c_gate_msa,
|
||||
c_shift_mlp,
|
||||
c_scale_mlp,
|
||||
c_gate_mlp,
|
||||
) = emb.chunk(12, dim=1)
|
||||
normed_x = self.norm_x(x)
|
||||
normed_context = self.norm_c(context)
|
||||
x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
|
||||
|
||||
|
||||
class CogVideoXLayerNormZero(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -14,7 +14,6 @@ if is_torch_available():
|
||||
from .stable_audio_transformer import StableAudioDiTModel
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
|
||||
@@ -1,386 +0,0 @@
|
||||
# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
CogVideoXAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CogView3PlusTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.
|
||||
|
||||
Args:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
time_embed_dim (`int`):
|
||||
The number of channels in timestep embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 2560,
|
||||
num_attention_heads: int = 64,
|
||||
attention_head_dim: int = 40,
|
||||
time_embed_dim: int = 512,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
qk_norm="layer_norm",
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
processor=CogVideoXAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
(
|
||||
norm_hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
norm_encoder_hidden_states,
|
||||
c_gate_msa,
|
||||
c_shift_mlp,
|
||||
c_scale_mlp,
|
||||
c_gate_mlp,
|
||||
) = self.norm1(hidden_states, encoder_hidden_states, emb)
|
||||
|
||||
# attention
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# feed-forward
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
|
||||
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
if encoder_hidden_states.dtype == torch.float16:
|
||||
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
|
||||
Diffusion](https://huggingface.co/papers/2403.05121).
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
attention_head_dim (`int`, defaults to `40`):
|
||||
The number of channels in each head.
|
||||
num_attention_heads (`int`, defaults to `64`):
|
||||
The number of heads to use for multi-head attention.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
condition_dim (`int`, defaults to `256`):
|
||||
The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
|
||||
crop_coords).
|
||||
pos_embed_max_size (`int`, defaults to `128`):
|
||||
The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
|
||||
to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
|
||||
means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
|
||||
patch_size => 128 * 8 * 2 => 2048`.
|
||||
sample_size (`int`, defaults to `128`):
|
||||
The base resolution of input latents. If height/width is not provided during generation, this value is used
|
||||
to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
num_layers: int = 30,
|
||||
attention_head_dim: int = 40,
|
||||
num_attention_heads: int = 64,
|
||||
out_channels: int = 16,
|
||||
text_embed_dim: int = 4096,
|
||||
time_embed_dim: int = 512,
|
||||
condition_dim: int = 256,
|
||||
pos_embed_max_size: int = 128,
|
||||
sample_size: int = 128,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
|
||||
# Each of these are sincos embeddings of shape 2 * condition_dim
|
||||
self.pooled_projection_dim = 3 * 2 * condition_dim
|
||||
|
||||
self.patch_embed = CogView3PlusPatchEmbed(
|
||||
in_channels=in_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
patch_size=patch_size,
|
||||
text_hidden_size=text_embed_dim,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
)
|
||||
|
||||
self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
|
||||
embedding_dim=time_embed_dim,
|
||||
condition_dim=condition_dim,
|
||||
pooled_projection_dim=self.pooled_projection_dim,
|
||||
timesteps_dim=self.inner_dim,
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
CogView3PlusTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(
|
||||
embedding_dim=self.inner_dim,
|
||||
conditioning_embedding_dim=time_embed_dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
original_size: torch.Tensor,
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`CogView3PlusTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
Input `hidden_states` of shape `(batch size, channel, height, width)`.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
|
||||
`(batch_size, sequence_len, text_embed_dim)`
|
||||
timestep (`torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
original_size (`torch.Tensor`):
|
||||
CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
target_size (`torch.Tensor`):
|
||||
CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
crop_coords (`torch.Tensor`):
|
||||
CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
|
||||
The denoised latents using provided inputs as conditioning.
|
||||
"""
|
||||
height, width = hidden_states.shape[-2:]
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
|
||||
hidden_states = self.patch_embed(
|
||||
hidden_states, encoder_hidden_states
|
||||
) # takes care of adding positional embeddings too.
|
||||
emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
emb=emb,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, emb)
|
||||
hidden_states = self.proj_out(hidden_states) # (batch_size, height*width, patch_size*patch_size*out_channels)
|
||||
|
||||
# unpatchify
|
||||
patch_size = self.config.patch_size
|
||||
height = height // patch_size
|
||||
width = width // patch_size
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -83,16 +83,14 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
image_rotary_emb=None,
|
||||
joint_attention_kwargs=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
@@ -163,20 +161,18 @@ class FluxTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
image_rotary_emb=None,
|
||||
joint_attention_kwargs=None,
|
||||
):
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
@@ -402,7 +398,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
return_dict: bool = True,
|
||||
controlnet_blocks_repeat: bool = False,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
@@ -502,20 +497,13 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_block_samples is not None:
|
||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
# For Xlabs ControlNet.
|
||||
if controlnet_blocks_repeat:
|
||||
hidden_states = (
|
||||
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
@@ -545,7 +533,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
|
||||
@@ -19,7 +19,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import deprecate
|
||||
from ..utils.import_utils import is_torch_version
|
||||
from .normalization import RMSNorm
|
||||
|
||||
|
||||
@@ -152,10 +151,11 @@ class Upsample2D(nn.Module):
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(hidden_states)
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
|
||||
# https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
||||
# https://github.com/pytorch/pytorch/issues/86679
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
@@ -170,8 +170,8 @@ class Upsample2D(nn.Module):
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
# Cast back to original dtype
|
||||
if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
|
||||
@@ -144,9 +144,7 @@ else:
|
||||
"CogVideoXPipeline",
|
||||
"CogVideoXImageToVideoPipeline",
|
||||
"CogVideoXVideoToVideoPipeline",
|
||||
"CogVideoXFunControlPipeline",
|
||||
]
|
||||
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
|
||||
_import_structure["controlnet"].extend(
|
||||
[
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
@@ -160,13 +158,11 @@ else:
|
||||
)
|
||||
_import_structure["pag"].extend(
|
||||
[
|
||||
"StableDiffusionControlNetPAGInpaintPipeline",
|
||||
"AnimateDiffPAGPipeline",
|
||||
"KolorsPAGPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"StableDiffusion3PAGPipeline",
|
||||
"StableDiffusionPAGPipeline",
|
||||
"StableDiffusionPAGImg2ImgPipeline",
|
||||
"StableDiffusionControlNetPAGPipeline",
|
||||
"StableDiffusionXLPAGPipeline",
|
||||
"StableDiffusionXLPAGInpaintPipeline",
|
||||
@@ -471,13 +467,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .cogvideo import (
|
||||
CogVideoXFunControlPipeline,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
CogVideoXPipeline,
|
||||
CogVideoXVideoToVideoPipeline,
|
||||
)
|
||||
from .cogview3 import CogView3PlusPipeline
|
||||
from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline
|
||||
from .controlnet import (
|
||||
BlipDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
@@ -576,9 +566,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KolorsPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGInpaintPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGImg2ImgPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
|
||||
@@ -20,7 +20,6 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import is_sentencepiece_available
|
||||
from .aura_flow import AuraFlowPipeline
|
||||
from .cogview3 import CogView3PlusPipeline
|
||||
from .controlnet import (
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
@@ -62,9 +61,7 @@ from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGInpaintPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGImg2ImgPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
@@ -120,7 +117,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("flux", FluxPipeline),
|
||||
("flux-controlnet", FluxControlNetPipeline),
|
||||
("lumina", LuminaText2ImgPipeline),
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -134,7 +130,6 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
|
||||
("kandinsky3", Kandinsky3Img2ImgPipeline),
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
|
||||
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
|
||||
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
|
||||
@@ -153,7 +148,6 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("kandinsky", KandinskyInpaintCombinedPipeline),
|
||||
("kandinsky22", KandinskyV22InpaintCombinedPipeline),
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
|
||||
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
|
||||
("flux", FluxInpaintPipeline),
|
||||
|
||||
@@ -23,7 +23,6 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
|
||||
_import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"]
|
||||
_import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
|
||||
_import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]
|
||||
|
||||
@@ -36,7 +35,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_cogvideox import CogVideoXPipeline
|
||||
from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
|
||||
from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
|
||||
from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
|
||||
|
||||
|
||||
@@ -1,794 +0,0 @@
|
||||
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI, Alibaba-PAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import CogVideoXLoraLoaderMixin
|
||||
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
||||
from ...models.embeddings import get_3d_rotary_pos_embed
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from .pipeline_output import CogVideoXPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogVideoXFunControlPipeline, DDIMScheduler
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> pipe = CogVideoXFunControlPipeline.from_pretrained(
|
||||
... "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> control_video = load_video(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
|
||||
... )
|
||||
>>> prompt = (
|
||||
... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and "
|
||||
... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in "
|
||||
... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, "
|
||||
... "moons, but the remainder of the scene is mostly realistic."
|
||||
... )
|
||||
|
||||
>>> video = pipe(prompt=prompt, control_video=control_video).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for controlled text-to-video generation using CogVideoX Fun.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. CogVideoX uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CogVideoXTransformer3DModel`]):
|
||||
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->vae->transformer->vae"
|
||||
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKLCogVideoX,
|
||||
transformer: CogVideoXTransformer3DModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
self.vae_scale_factor_spatial = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.vae_scale_factor_temporal = (
|
||||
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
|
||||
)
|
||||
self.vae_scaling_factor_image = (
|
||||
self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
|
||||
)
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor_spatial,
|
||||
width // self.vae_scale_factor_spatial,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Adapted from https://github.com/aigc-apps/CogVideoX-Fun/blob/2a93e5c14e02b2b5921d533fd59fc8c0ed69fb24/cogvideox/pipeline/pipeline_cogvideox_control.py#L366
|
||||
def prepare_control_latents(
|
||||
self, mask: Optional[torch.Tensor] = None, masked_image: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if mask is not None:
|
||||
masks = []
|
||||
for i in range(mask.size(0)):
|
||||
current_mask = mask[i].unsqueeze(0)
|
||||
current_mask = self.vae.encode(current_mask)[0]
|
||||
current_mask = current_mask.mode()
|
||||
masks.append(current_mask)
|
||||
mask = torch.cat(masks, dim=0)
|
||||
mask = mask * self.vae.config.scaling_factor
|
||||
|
||||
if masked_image is not None:
|
||||
mask_pixel_values = []
|
||||
for i in range(masked_image.size(0)):
|
||||
mask_pixel_value = masked_image[i].unsqueeze(0)
|
||||
mask_pixel_value = self.vae.encode(mask_pixel_value)[0]
|
||||
mask_pixel_value = mask_pixel_value.mode()
|
||||
mask_pixel_values.append(mask_pixel_value)
|
||||
masked_image_latents = torch.cat(mask_pixel_values, dim=0)
|
||||
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
|
||||
else:
|
||||
masked_image_latents = None
|
||||
|
||||
return mask, masked_image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
|
||||
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
||||
latents = 1 / self.vae_scaling_factor_image * latents
|
||||
|
||||
frames = self.vae.decode(latents).sample
|
||||
return frames
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
control_video=None,
|
||||
control_video_latents=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if control_video is not None and control_video_latents is not None:
|
||||
raise ValueError(
|
||||
"Cannot pass both `control_video` and `control_video_latents`. Please make sure to pass only one of these parameters."
|
||||
)
|
||||
|
||||
def fuse_qkv_projections(self) -> None:
|
||||
r"""Enables fused QKV projections."""
|
||||
self.fusing_transformer = True
|
||||
self.transformer.fuse_qkv_projections()
|
||||
|
||||
def unfuse_qkv_projections(self) -> None:
|
||||
r"""Disable QKV projection fusion if enabled."""
|
||||
if not self.fusing_transformer:
|
||||
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.transformer.unfuse_qkv_projections()
|
||||
self.fusing_transformer = False
|
||||
|
||||
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
|
||||
def _prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
|
||||
grid_crops_coords = get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
)
|
||||
|
||||
freqs_cos = freqs_cos.to(device=device)
|
||||
freqs_sin = freqs_sin.to(device=device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
control_video: Optional[List[Image.Image]] = None,
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 6,
|
||||
use_dynamic_cfg: bool = False,
|
||||
num_videos_per_prompt: int = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
control_video_latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 226,
|
||||
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
control_video (`List[PIL.Image.Image]`):
|
||||
The control video to condition the generation on. Must be a list of images/frames of the video. If not
|
||||
provided, `control_video_latents` must be provided.
|
||||
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
||||
The width in pixels of the generated image. This is set to 720 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 6.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of videos to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
control_video_latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for
|
||||
controlled video generation. If not provided, `control_video` must be provided.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `226`):
|
||||
Maximum sequence length in encoded prompt. Must be consistent with
|
||||
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
num_videos_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
control_video,
|
||||
control_video_latents,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if control_video is not None and isinstance(control_video[0], Image.Image):
|
||||
control_video = [control_video]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels // 2
|
||||
num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
latent_channels,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if control_video_latents is None:
|
||||
control_video = self.video_processor.preprocess_video(control_video, height=height, width=width)
|
||||
control_video = control_video.to(device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
_, control_video_latents = self.prepare_control_latents(None, control_video)
|
||||
control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Create rotary embeds if required
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
||||
if self.transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
# for DPM-solver++
|
||||
old_pred_original_sample = None
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
latent_control_input = (
|
||||
torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
|
||||
)
|
||||
latent_model_input = torch.cat([latent_model_input, latent_control_input], dim=2)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
# perform guidance
|
||||
if use_dynamic_cfg:
|
||||
self._guidance_scale = 1 + guidance_scale * (
|
||||
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return CogVideoXPipelineOutput(frames=video)
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
@@ -23,7 +23,6 @@ from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import CogVideoXLoraLoaderMixin
|
||||
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
||||
from ...models.embeddings import get_3d_rotary_pos_embed
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
@@ -153,7 +152,7 @@ def retrieve_latents(
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
class CogVideoXImageToVideoPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for image-to-video generation using CogVideoX.
|
||||
|
||||
@@ -547,10 +546,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
@@ -577,7 +572,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
@@ -641,10 +635,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
@@ -689,7 +679,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -779,7 +768,6 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["CogView3PlusPipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_cogview3plus"] = ["CogView3PlusPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_cogview3plus import CogView3PlusPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,674 +0,0 @@
|
||||
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||
from ...utils import logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import CogView3PipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import CogView3PlusPipeline
|
||||
|
||||
>>> pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3B", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A photo of an astronaut riding a horse on mars"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
>>> image.save("output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class CogView3PlusPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using CogView3Plus.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. CogView3Plus uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CogView3PlusTransformer2DModel`]):
|
||||
A text conditioned `CogView3PlusTransformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKL,
|
||||
transformer: CogView3PlusTransformer2DModel,
|
||||
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds with num_videos_per_prompt->num_images_per_prompt
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 224,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of images that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
max_sequence_length (`int`, defaults to `224`):
|
||||
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt is None:
|
||||
negative_prompt_embeds = prompt_embeds.new_zeros(prompt_embeds.shape)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 224,
|
||||
) -> Union[CogView3PipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. If not provided, it is set to 1024.
|
||||
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. If not provided it is set to 1024.
|
||||
num_inference_steps (`int`, *optional*, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to `1`):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
||||
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
||||
explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
||||
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
||||
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, defaults to `224`):
|
||||
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] or `tuple`:
|
||||
[`~pipelines.cogview3.pipeline_cogview3plus.CogView3PipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = (height, width)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
latent_channels,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare additional timestep conditions
|
||||
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype)
|
||||
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
|
||||
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
original_size = torch.cat([original_size, original_size])
|
||||
target_size = torch.cat([target_size, target_size])
|
||||
crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])
|
||||
|
||||
original_size = original_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
target_size = target_size.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
crops_coords_top_left = crops_coords_top_left.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
# for DPM-solver++
|
||||
old_pred_original_sample = None
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
else:
|
||||
latents, old_pred_original_sample = self.scheduler.step(
|
||||
noise_pred,
|
||||
old_pred_original_sample,
|
||||
t,
|
||||
timesteps[i - 1] if i > 0 else None,
|
||||
latents,
|
||||
**extra_step_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
]
|
||||
else:
|
||||
image = latents
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return CogView3PipelineOutput(images=image)
|
||||
@@ -1,21 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class CogView3PipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for CogView3 pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -893,10 +893,6 @@ class StableDiffusionControlNetPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1093,7 +1089,6 @@ class StableDiffusionControlNetPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1240,9 +1235,6 @@ class StableDiffusionControlNetPipeline(
|
||||
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# Relevant thread:
|
||||
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
||||
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
||||
|
||||
@@ -891,10 +891,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1085,7 +1081,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1216,9 +1211,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -976,10 +976,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1195,7 +1191,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1380,9 +1375,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -1145,10 +1145,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1431,7 +1427,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1700,9 +1695,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
|
||||
@@ -990,10 +990,6 @@ class StableDiffusionXLControlNetPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1249,7 +1245,6 @@ class StableDiffusionXLControlNetPipeline(
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1447,9 +1442,6 @@ class StableDiffusionXLControlNetPipeline(
|
||||
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# Relevant thread:
|
||||
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
||||
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
||||
|
||||
@@ -1070,10 +1070,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -1342,7 +1338,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
@@ -1515,9 +1510,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -225,8 +225,6 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = HunyuanDiT2DMultiControlNetModel(controlnet)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
|
||||
@@ -192,8 +192,6 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
|
||||
],
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = SD3MultiControlNetModel(controlnet)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
|
||||
+2
-11
@@ -251,9 +251,6 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
)
|
||||
self.patch_size = (
|
||||
self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
@@ -580,14 +577,8 @@ class StableDiffusion3ControlNetInpaintingPipeline(DiffusionPipeline, SD3LoraLoa
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if (
|
||||
height % (self.vae_scale_factor * self.patch_size) != 0
|
||||
or width % (self.vae_scale_factor * self.patch_size) != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
|
||||
f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
|
||||
)
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
|
||||
@@ -9,17 +9,16 @@ from ...utils import BaseOutput
|
||||
|
||||
@dataclass
|
||||
class IFPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Stable Diffusion pipelines.
|
||||
|
||||
"""
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`):
|
||||
Output class for Stable Diffusion pipelines.
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
nsfw_detected (`List[bool]`):
|
||||
nsfw_detected (`List[bool]`)
|
||||
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content or a watermark. `None` if safety checking could not be performed.
|
||||
watermark_detected (`List[bool]`):
|
||||
watermark_detected (`List[bool]`)
|
||||
List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
|
||||
checking could not be performed.
|
||||
"""
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -137,12 +137,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class FluxPipeline(
|
||||
DiffusionPipeline,
|
||||
FluxLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
):
|
||||
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
The Flux pipeline for text-to-image generation.
|
||||
|
||||
@@ -217,9 +212,6 @@ class FluxPipeline(
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -263,9 +255,6 @@ class FluxPipeline(
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
|
||||
@@ -25,7 +25,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
@@ -72,9 +72,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... control_image=control_image,
|
||||
... control_guidance_start=0.2,
|
||||
... control_guidance_end=0.8,
|
||||
... controlnet_conditioning_scale=1.0,
|
||||
... controlnet_conditioning_scale=0.6,
|
||||
... num_inference_steps=28,
|
||||
... guidance_scale=3.5,
|
||||
... ).images[0]
|
||||
@@ -202,8 +200,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
],
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = FluxMultiControlNetModel(controlnet)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
@@ -238,9 +234,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -284,9 +277,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -582,8 +572,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
control_image: PipelineImageInput = None,
|
||||
control_mode: Optional[Union[int, List[int]]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
@@ -626,10 +614,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
||||
The percentage of total steps at which the ControlNet starts applying.
|
||||
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The percentage of total steps at which the ControlNet stops applying.
|
||||
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
||||
@@ -690,17 +674,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
@@ -760,35 +733,29 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
# xlab controlnet has a input_hint_block and instantx controlnet does not
|
||||
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
|
||||
if self.controlnet.input_hint_block is None:
|
||||
# vae encode
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
# vae encode
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
# Here we ensure that `control_mode` has the same length as the control_image.
|
||||
# set control mode
|
||||
if control_mode is not None:
|
||||
if not isinstance(control_mode, int):
|
||||
raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
|
||||
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
|
||||
control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
|
||||
control_mode = control_mode.reshape([-1, 1])
|
||||
|
||||
elif isinstance(self.controlnet, FluxMultiControlNetModel):
|
||||
control_images = []
|
||||
# xlab controlnet has a input_hint_block and instantx controlnet does not
|
||||
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
|
||||
for i, control_image_ in enumerate(control_image):
|
||||
|
||||
for control_image_ in control_image:
|
||||
control_image_ = self.prepare_image(
|
||||
image=control_image_,
|
||||
width=width,
|
||||
@@ -800,40 +767,34 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
)
|
||||
height, width = control_image_.shape[-2:]
|
||||
|
||||
if self.controlnet.nets[0].input_hint_block is None:
|
||||
# vae encode
|
||||
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
|
||||
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
# vae encode
|
||||
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
|
||||
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image_.shape[2:]
|
||||
control_image_ = self._pack_latents(
|
||||
control_image_,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image_.shape[2:]
|
||||
control_image_ = self._pack_latents(
|
||||
control_image_,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
control_images.append(control_image_)
|
||||
|
||||
control_image = control_images
|
||||
|
||||
# Here we ensure that `control_mode` has the same length as the control_image.
|
||||
if isinstance(control_mode, list) and len(control_mode) != len(control_image):
|
||||
raise ValueError(
|
||||
"For Multi-ControlNet, `control_mode` must be a list of the same "
|
||||
+ " length as the number of controlnets (control images) specified"
|
||||
)
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode] * len(control_image)
|
||||
# set control mode
|
||||
control_modes = []
|
||||
for cmode in control_mode:
|
||||
if cmode is None:
|
||||
cmode = -1
|
||||
control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
|
||||
control_modes.append(control_mode)
|
||||
control_mode = control_modes
|
||||
control_mode_ = []
|
||||
if isinstance(control_mode, list):
|
||||
for cmode in control_mode:
|
||||
if cmode is None:
|
||||
control_mode_.append(-1)
|
||||
else:
|
||||
control_mode_.append(cmode)
|
||||
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
|
||||
control_mode = control_mode.reshape([-1, 1])
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
@@ -870,16 +831,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
|
||||
|
||||
# 7. Denoising loop
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
@@ -888,28 +840,17 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if isinstance(self.controlnet, FluxMultiControlNetModel):
|
||||
use_guidance = self.controlnet.nets[0].config.guidance_embeds
|
||||
else:
|
||||
use_guidance = self.controlnet.config.guidance_embeds
|
||||
|
||||
guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
|
||||
guidance = (
|
||||
torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
|
||||
)
|
||||
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
|
||||
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
||||
else:
|
||||
controlnet_cond_scale = controlnet_conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
||||
|
||||
# controlnet
|
||||
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
||||
hidden_states=latents,
|
||||
controlnet_cond=control_image,
|
||||
controlnet_mode=control_mode,
|
||||
conditioning_scale=cond_scale,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
@@ -937,7 +878,6 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
controlnet_blocks_repeat=controlnet_blocks_repeat,
|
||||
)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
@@ -69,9 +69,7 @@ EXAMPLE_DOC_STRING = """
|
||||
... prompt,
|
||||
... image=init_image,
|
||||
... control_image=control_image,
|
||||
... control_guidance_start=0.2,
|
||||
... control_guidance_end=0.8,
|
||||
... controlnet_conditioning_scale=1.0,
|
||||
... controlnet_conditioning_scale=0.6,
|
||||
... strength=0.7,
|
||||
... num_inference_steps=2,
|
||||
... guidance_scale=3.5,
|
||||
@@ -214,8 +212,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
],
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = FluxMultiControlNetModel(controlnet)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
@@ -251,9 +247,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -298,9 +291,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -641,8 +631,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
control_mode: Optional[Union[int, List[int]]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
@@ -722,17 +710,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
@@ -885,14 +862,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
latents,
|
||||
)
|
||||
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
@@ -908,19 +877,11 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
)
|
||||
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
|
||||
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
||||
else:
|
||||
controlnet_cond_scale = controlnet_conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
||||
|
||||
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
||||
hidden_states=latents,
|
||||
controlnet_cond=control_image,
|
||||
controlnet_mode=control_mode,
|
||||
conditioning_scale=cond_scale,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
|
||||
@@ -12,7 +12,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
@@ -71,8 +71,6 @@ EXAMPLE_DOC_STRING = """
|
||||
... image=init_image,
|
||||
... mask_image=mask_image,
|
||||
... control_image=control_image,
|
||||
... control_guidance_start=0.2,
|
||||
... control_guidance_end=0.8,
|
||||
... controlnet_conditioning_scale=0.7,
|
||||
... strength=0.7,
|
||||
... num_inference_steps=28,
|
||||
@@ -216,8 +214,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
],
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = FluxMultiControlNetModel(controlnet)
|
||||
|
||||
self.register_modules(
|
||||
scheduler=scheduler,
|
||||
@@ -261,9 +257,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -308,9 +301,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -747,8 +737,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
timesteps: List[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
guidance_scale: float = 7.0,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
control_mode: Optional[Union[int, List[int]]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
@@ -795,10 +783,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
Custom timesteps to use for the denoising process.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
||||
The percentage of total steps at which the ControlNet starts applying.
|
||||
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The percentage of total steps at which the ControlNet stops applying.
|
||||
control_mode (`int` or `List[int]`, *optional*):
|
||||
The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
|
||||
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
@@ -842,17 +826,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
global_height = height
|
||||
global_width = width
|
||||
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
@@ -1058,14 +1031,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
generator,
|
||||
)
|
||||
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
@@ -1084,19 +1049,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
||||
else:
|
||||
controlnet_cond_scale = controlnet_conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
||||
|
||||
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
||||
hidden_states=latents,
|
||||
controlnet_cond=control_image,
|
||||
controlnet_mode=control_mode,
|
||||
conditioning_scale=cond_scale,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -235,9 +235,6 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -282,9 +279,6 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -239,9 +239,6 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
@@ -286,9 +283,6 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
|
||||
@@ -277,7 +277,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
padding_side: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
@@ -299,9 +298,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
||||
`>= 7.5` (Volta).
|
||||
padding_side (`str`, *optional*):
|
||||
The side on which the model should have padding applied. Should be selected between ['right', 'left'].
|
||||
Default value is picked from the class attribute of the same name.
|
||||
return_attention_mask:
|
||||
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
"""
|
||||
|
||||
@@ -23,7 +23,6 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_inpaint"] = ["StableDiffusionControlNetPAGInpaintPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetPAGImg2ImgPipeline"]
|
||||
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
|
||||
@@ -32,7 +31,6 @@ else:
|
||||
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"]
|
||||
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
|
||||
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
|
||||
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
|
||||
@@ -46,7 +44,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
|
||||
from .pipeline_pag_controlnet_sd_inpaint import StableDiffusionControlNetPAGInpaintPipeline
|
||||
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
|
||||
from .pipeline_pag_controlnet_sd_xl_img2img import StableDiffusionXLControlNetPAGImg2ImgPipeline
|
||||
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
|
||||
@@ -55,7 +52,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_pag_sd import StableDiffusionPAGPipeline
|
||||
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
|
||||
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
|
||||
from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline
|
||||
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
|
||||
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
|
||||
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
|
||||
|
||||
@@ -98,9 +98,7 @@ class PAGMixin:
|
||||
else:
|
||||
return self.pag_scale
|
||||
|
||||
def _apply_perturbed_attention_guidance(
|
||||
self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
|
||||
):
|
||||
def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
|
||||
r"""
|
||||
Apply perturbed attention guidance to the noise prediction.
|
||||
|
||||
@@ -109,11 +107,9 @@ class PAGMixin:
|
||||
do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
|
||||
guidance_scale (float): The scale factor for the guidance term.
|
||||
t (int): The current time step.
|
||||
return_pred_text (bool): Whether to return the text noise prediction.
|
||||
|
||||
Returns:
|
||||
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
|
||||
perturbed attention guidance and the text noise prediction.
|
||||
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
|
||||
"""
|
||||
pag_scale = self._get_pag_scale(t)
|
||||
if do_classifier_free_guidance:
|
||||
@@ -126,8 +122,6 @@ class PAGMixin:
|
||||
else:
|
||||
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
|
||||
if return_pred_text:
|
||||
return noise_pred, noise_pred_text
|
||||
return noise_pred
|
||||
|
||||
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user