Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4133545a15 | |||
| 37b8edfb86 | |||
| d34dbbd05a | |||
| fbf6b856cc | |||
| e031caf4ea | |||
| 08f74a8b92 | |||
| 24c062aaa1 | |||
| a74f02fb40 | |||
| 66bf7ea5be | |||
| b8215b1c06 | |||
| 3ee899fa0c | |||
| dcd77ce222 | |||
| 11d8e3ce2c | |||
| 97fda1b75c | |||
| cc22058324 | |||
| 7855ac597e | |||
| 30cef6bff3 | |||
| 8f15be169f | |||
| f92e599c70 | |||
| 982f9b38d6 | |||
| c9a219b323 | |||
| 9e910c4633 | |||
| 5e3b7d2d8a | |||
| 7513162b8b | |||
| 4aaa0d21ba | |||
| 54043c3e2e | |||
| fc4229a0c3 | |||
| 694f9658c1 | |||
| 2d8a41cae8 | |||
| 7007febae5 | |||
| d230ecc570 | |||
| 37a5f1b3b6 | |||
| 501d9de701 | |||
| e5c43b8af7 |
@@ -9,139 +9,43 @@ permissions:
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
run-style-bot:
|
||||
if: >
|
||||
contains(github.event.comment.body, '@bot /style') &&
|
||||
github.event.issue.pull_request != null
|
||||
runs-on: ubuntu-latest
|
||||
style:
|
||||
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
|
||||
with:
|
||||
python_quality_dependencies: "[quality]"
|
||||
pre_commit_script_name: "Download and Compare files from the main branch"
|
||||
pre_commit_script: |
|
||||
echo "Downloading the files from the main branch"
|
||||
|
||||
steps:
|
||||
- name: Extract PR details
|
||||
id: pr_info
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = context.payload.issue.number;
|
||||
const { data: pr } = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber
|
||||
});
|
||||
|
||||
// We capture both the branch ref and the "full_name" of the head repo
|
||||
// so that we can check out the correct repository & branch (including forks).
|
||||
core.setOutput("prNumber", prNumber);
|
||||
core.setOutput("headRef", pr.head.ref);
|
||||
core.setOutput("headRepoFullName", pr.head.repo.full_name);
|
||||
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
|
||||
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
|
||||
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
|
||||
|
||||
- name: Check out PR branch
|
||||
uses: actions/checkout@v3
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
with:
|
||||
# Instead of checking out the base repo, use the contributor's repo name
|
||||
repository: ${{ env.HEADREPOFULLNAME }}
|
||||
ref: ${{ env.HEADREF }}
|
||||
# You may need fetch-depth: 0 for being able to push
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Debug
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
run: |
|
||||
echo "PR number: $PRNUMBER"
|
||||
echo "Head Ref: $HEADREF"
|
||||
echo "Head Repo Full Name: $HEADREPOFULLNAME"
|
||||
echo "Compare the files and raise error if needed"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
diff_failed=0
|
||||
if ! diff -q main_Makefile Makefile; then
|
||||
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install .[quality]
|
||||
if ! diff -q main_setup.py setup.py; then
|
||||
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
- name: Download necessary files from main branch of Diffusers
|
||||
run: |
|
||||
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
|
||||
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
|
||||
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
|
||||
|
||||
- name: Compare the files and raise error if needed
|
||||
run: |
|
||||
diff_failed=0
|
||||
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
|
||||
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if ! diff -q main_Makefile Makefile; then
|
||||
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
if [ $diff_failed -eq 1 ]; then
|
||||
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! diff -q main_setup.py setup.py; then
|
||||
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
|
||||
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if [ $diff_failed -eq 1 ]; then
|
||||
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "No changes in the files. Proceeding..."
|
||||
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
|
||||
|
||||
- name: Run make style and make quality
|
||||
run: |
|
||||
make style && make quality
|
||||
|
||||
- name: Commit and push changes
|
||||
id: commit_and_push
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
echo "HEADREPOFULLNAME: $HEADREPOFULLNAME, HEADREF: $HEADREF"
|
||||
# Configure git with the Actions bot user
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Make sure your 'origin' remote is set to the contributor's fork
|
||||
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/$HEADREPOFULLNAME.git"
|
||||
|
||||
# If there are changes after running style/quality, commit them
|
||||
if [ -n "$(git status --porcelain)" ]; then
|
||||
git add .
|
||||
git commit -m "Apply style fixes"
|
||||
# Push to the original contributor's forked branch
|
||||
git push origin HEAD:$HEADREF
|
||||
echo "changes_pushed=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No changes to commit."
|
||||
echo "changes_pushed=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Comment on PR with workflow run link
|
||||
if: steps.commit_and_push.outputs.changes_pushed == 'true'
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = parseInt(process.env.prNumber, 10);
|
||||
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
|
||||
});
|
||||
env:
|
||||
prNumber: ${{ steps.pr_info.outputs.prNumber }}
|
||||
echo "No changes in the files. Proceeding..."
|
||||
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
|
||||
style_command: "make style && make quality"
|
||||
secrets:
|
||||
bot_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -11,6 +11,8 @@ on:
|
||||
- "src/diffusers/loaders/lora_base.py"
|
||||
- "src/diffusers/loaders/lora_pipeline.py"
|
||||
- "src/diffusers/loaders/peft.py"
|
||||
- "tests/pipelines/test_pipelines_common.py"
|
||||
- "tests/models/test_modeling_common.py"
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
@@ -104,11 +106,18 @@ jobs:
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and $pattern" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
else
|
||||
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and $pattern" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
fi
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -76,6 +76,14 @@
|
||||
- local: advanced_inference/outpaint
|
||||
title: Outpainting
|
||||
title: Advanced inference
|
||||
- sections:
|
||||
- local: hybrid_inference/overview
|
||||
title: Overview
|
||||
- local: hybrid_inference/vae_decode
|
||||
title: VAE Decode
|
||||
- local: hybrid_inference/api_reference
|
||||
title: API Reference
|
||||
title: Hybrid Inference
|
||||
- sections:
|
||||
- local: using-diffusers/cogvideox
|
||||
title: CogVideoX
|
||||
@@ -282,6 +290,8 @@
|
||||
title: CogView4Transformer2DModel
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/easyanimate_transformer3d
|
||||
title: EasyAnimateTransformer3DModel
|
||||
- local: api/models/flux_transformer
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
@@ -314,6 +324,8 @@
|
||||
title: Transformer2DModel
|
||||
- local: api/models/transformer_temporal
|
||||
title: TransformerTemporalModel
|
||||
- local: api/models/wan_transformer_3d
|
||||
title: WanTransformer3DModel
|
||||
title: Transformers
|
||||
- sections:
|
||||
- local: api/models/stable_cascade_unet
|
||||
@@ -342,8 +354,12 @@
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoderkl_ltx_video
|
||||
title: AutoencoderKLLTXVideo
|
||||
- local: api/models/autoencoderkl_magvit
|
||||
title: AutoencoderKLMagvit
|
||||
- local: api/models/autoencoderkl_mochi
|
||||
title: AutoencoderKLMochi
|
||||
- local: api/models/autoencoder_kl_wan
|
||||
title: AutoencoderKLWan
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/autoencoder_dc
|
||||
@@ -418,6 +434,8 @@
|
||||
title: DiffEdit
|
||||
- local: api/pipelines/dit
|
||||
title: DiT
|
||||
- local: api/pipelines/easyanimate
|
||||
title: EasyAnimate
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
@@ -534,6 +552,8 @@
|
||||
title: UniDiffuser
|
||||
- local: api/pipelines/value_guided_sampling
|
||||
title: Value-guided sampling
|
||||
- local: api/pipelines/wan
|
||||
title: Wan
|
||||
- local: api/pipelines/wuerstchen
|
||||
title: Wuerstchen
|
||||
title: Pipelines
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLWan
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
|
||||
```
|
||||
|
||||
## AutoencoderKLWan
|
||||
|
||||
[[autodoc]] AutoencoderKLWan
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -0,0 +1,37 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLMagvit
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLMagvit
|
||||
|
||||
vae = AutoencoderKLMagvit.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLMagvit
|
||||
|
||||
[[autodoc]] AutoencoderKLMagvit
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -0,0 +1,30 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# EasyAnimateTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data from [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import EasyAnimateTransformer3DModel
|
||||
|
||||
transformer = EasyAnimateTransformer3DModel.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
## EasyAnimateTransformer3DModel
|
||||
|
||||
[[autodoc]] EasyAnimateTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# WanTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data was introduced in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import WanTransformer3DModel
|
||||
|
||||
transformer = WanTransformer3DModel.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## WanTransformer3DModel
|
||||
|
||||
[[autodoc]] WanTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,88 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# EasyAnimate
|
||||
[EasyAnimate](https://github.com/aigc-apps/EasyAnimate) by Alibaba PAI.
|
||||
|
||||
The description from it's GitHub page:
|
||||
*EasyAnimate is a pipeline based on the transformer architecture, designed for generating AI images and videos, and for training baseline models and Lora models for Diffusion Transformer. We support direct prediction from pre-trained EasyAnimate models, allowing for the generation of videos with various resolutions, approximately 6 seconds in length, at 8fps (EasyAnimateV5.1, 1 to 49 frames). Additionally, users can train their own baseline and Lora models for specific style transformations.*
|
||||
|
||||
This pipeline was contributed by [bubbliiiing](https://github.com/bubbliiiing). The original codebase can be found [here](https://huggingface.co/alibaba-pai). The original weights can be found under [hf.co/alibaba-pai](https://huggingface.co/alibaba-pai).
|
||||
|
||||
There are two official EasyAnimate checkpoints for text-to-video and video-to-video.
|
||||
|
||||
| checkpoints | recommended inference dtype |
|
||||
|:---:|:---:|
|
||||
| [`alibaba-pai/EasyAnimateV5.1-12b-zh`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh) | torch.float16 |
|
||||
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |
|
||||
|
||||
There is one official EasyAnimate checkpoints available for image-to-video and video-to-video.
|
||||
|
||||
| checkpoints | recommended inference dtype |
|
||||
|:---:|:---:|
|
||||
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |
|
||||
|
||||
There are two official EasyAnimate checkpoints available for control-to-video.
|
||||
|
||||
| checkpoints | recommended inference dtype |
|
||||
|:---:|:---:|
|
||||
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control) | torch.float16 |
|
||||
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera) | torch.float16 |
|
||||
|
||||
For the EasyAnimateV5.1 series:
|
||||
- Text-to-video (T2V) and Image-to-video (I2V) works for multiple resolutions. The width and height can vary from 256 to 1024.
|
||||
- Both T2V and I2V models support generation with 1~49 frames and work best at this value. Exporting videos at 8 FPS is recommended.
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`EasyAnimatePipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel, EasyAnimatePipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = EasyAnimateTransformer3DModel.from_pretrained(
|
||||
"alibaba-pai/EasyAnimateV5.1-12b-zh",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = EasyAnimatePipeline.from_pretrained(
|
||||
"alibaba-pai/EasyAnimateV5.1-12b-zh",
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "A cat walks on the grass, realistic style."
|
||||
negative_prompt = "bad detailed"
|
||||
video = pipeline(prompt=prompt, negative_prompt=negative_prompt, num_frames=49, num_inference_steps=30).frames[0]
|
||||
export_to_video(video, "cat.mp4", fps=8)
|
||||
```
|
||||
|
||||
## EasyAnimatePipeline
|
||||
|
||||
[[autodoc]] EasyAnimatePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## EasyAnimatePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput
|
||||
@@ -0,0 +1,62 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# Wan
|
||||
|
||||
[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
|
||||
|
||||
<!-- TODO(aryan): update abstract once paper is out -->
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
Recommendations for inference:
|
||||
- VAE in `torch.float32` for better decoding quality.
|
||||
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
|
||||
|
||||
### Using a custom scheduler
|
||||
|
||||
Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
|
||||
|
||||
```python
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, WanPipeline
|
||||
|
||||
scheduler_a = FlowMatchEulerDiscreteScheduler(shift=5.0)
|
||||
scheduler_b = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=4.0)
|
||||
|
||||
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler=<CUSTOM_SCHEDULER_HERE>)
|
||||
|
||||
# or,
|
||||
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
|
||||
```
|
||||
|
||||
## WanPipeline
|
||||
|
||||
[[autodoc]] WanPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanImageToVideoPipeline
|
||||
|
||||
[[autodoc]] WanImageToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
|
||||
@@ -16,6 +16,11 @@ specific language governing permissions and limitations under the License.
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||
</a>
|
||||
|
||||
> [!TIP]
|
||||
> This document has now grown outdated given the emergence of existing evaluation frameworks for diffusion models for image generation. Please check
|
||||
> out works like [HEIM](https://crfm.stanford.edu/helm/heim/latest/), [T2I-Compbench](https://arxiv.org/abs/2307.06350),
|
||||
> [GenEval](https://arxiv.org/abs/2310.11513).
|
||||
|
||||
Evaluation of generative models like [Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion) is subjective in nature. But as practitioners and researchers, we often have to make careful choices amongst many different possibilities. So, when working with different generative models (like GANs, Diffusion, etc.), how do we choose one over the other?
|
||||
|
||||
Qualitative evaluation of such models can be error-prone and might incorrectly influence a decision.
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
# Hybrid Inference API Reference
|
||||
|
||||
## Remote Decode
|
||||
|
||||
[[autodoc]] utils.remote_utils.remote_decode
|
||||
@@ -0,0 +1,54 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Hybrid Inference
|
||||
|
||||
**Empowering local AI builders with Hybrid Inference**
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae).
|
||||
> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
|
||||
## Why use Hybrid Inference?
|
||||
|
||||
Hybrid Inference offers a fast and simple way to offload local generation requirements.
|
||||
|
||||
- 🚀 **Reduced Requirements:** Access powerful models without expensive hardware.
|
||||
- 💎 **Without Compromise:** Achieve the highest quality without sacrificing performance.
|
||||
- 💰 **Cost Effective:** It's free! 🤑
|
||||
- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community.
|
||||
- 🔧 **Developer-Friendly:** Simple requests, fast responses.
|
||||
|
||||
---
|
||||
|
||||
## Available Models
|
||||
|
||||
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
|
||||
* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
|
||||
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
|
||||
|
||||
---
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
|
||||
## Contents
|
||||
|
||||
The documentation is organized into two sections:
|
||||
|
||||
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
|
||||
* **API Reference** Dive into task-specific settings and parameters.
|
||||
@@ -0,0 +1,345 @@
|
||||
# Getting Started: VAE Decode with Hybrid Inference
|
||||
|
||||
VAE decode is an essential component of diffusion models - turning latent representations into images or videos.
|
||||
|
||||
## Memory
|
||||
|
||||
These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs.
|
||||
|
||||
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality.
|
||||
|
||||
<details><summary>SD v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
|
||||
|
||||
</details>
|
||||
|
||||
## Available VAEs
|
||||
|
||||
| | **Endpoint** | **Model** |
|
||||
|:-:|:-----------:|:--------:|
|
||||
| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
|
||||
| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
|
||||
| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
|
||||
| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
## Code
|
||||
|
||||
> [!TIP]
|
||||
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
|
||||
|
||||
|
||||
A helper method simplifies interacting with Hybrid Inference.
|
||||
|
||||
```python
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
```
|
||||
|
||||
### Basic example
|
||||
|
||||
Here, we show how to use the remote VAE on random tensors.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/output.png"/>
|
||||
</figure>
|
||||
|
||||
Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
image = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
|
||||
height=1024,
|
||||
width=1024,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/flux_random_latent.png"/>
|
||||
</figure>
|
||||
|
||||
Finally, an example for HunyuanVideo.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
video = remote_decode(
|
||||
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16),
|
||||
output_type="mp4",
|
||||
)
|
||||
with open("video.mp4", "wb") as f:
|
||||
f.write(video)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video_1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
|
||||
### Generation
|
||||
|
||||
But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
|
||||
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
image.save("test.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test.jpg"/>
|
||||
</figure>
|
||||
|
||||
Here’s another example with Flux.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16,
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
|
||||
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
guidance_scale=0.0,
|
||||
num_inference_steps=4,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
height=1024,
|
||||
width=1024,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
image.save("test.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test_1.jpg"/>
|
||||
</figure>
|
||||
|
||||
Here’s an example with HunyuanVideo.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
|
||||
model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = HunyuanVideoPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
latent = pipe(
|
||||
prompt="A cat walks on the grass, realistic",
|
||||
height=320,
|
||||
width=512,
|
||||
num_frames=61,
|
||||
num_inference_steps=30,
|
||||
output_type="latent",
|
||||
).frames
|
||||
|
||||
video = remote_decode(
|
||||
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
output_type="mp4",
|
||||
)
|
||||
|
||||
if isinstance(video, bytes):
|
||||
with open("video.mp4", "wb") as f:
|
||||
f.write(video)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
|
||||
### Queueing
|
||||
|
||||
One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency.
|
||||
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
import queue
|
||||
import threading
|
||||
from IPython.display import display
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
def decode_worker(q: queue.Queue):
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is None:
|
||||
break
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=item,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
display(image)
|
||||
q.task_done()
|
||||
|
||||
q = queue.Queue()
|
||||
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
def decode(latent: torch.Tensor):
|
||||
q.put(latent)
|
||||
|
||||
prompts = [
|
||||
"Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious",
|
||||
"Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore",
|
||||
"Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.",
|
||||
"Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP",
|
||||
"A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting",
|
||||
"Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,",
|
||||
]
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"Lykon/dreamshaper-8",
|
||||
torch_dtype=torch.float16,
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
_ = pipe(
|
||||
prompt=prompts[0],
|
||||
output_type="latent",
|
||||
)
|
||||
|
||||
for prompt in prompts:
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
output_type="latent",
|
||||
).images
|
||||
decode(latent)
|
||||
|
||||
q.put(None)
|
||||
thread.join()
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/queue.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
@@ -157,6 +157,84 @@ pipeline(
|
||||
)
|
||||
```
|
||||
|
||||
## IP Adapter Cutoff
|
||||
|
||||
IP Adapter is an image prompt adapter that can be used for diffusion models without any changes to the underlying model. We can use the IP Adapter Cutoff Callback to disable the IP Adapter after a certain number of steps. To set up the callback, you need to specify the number of denoising steps after which the callback comes into effect. You can do so by using either one of these two arguments:
|
||||
|
||||
- `cutoff_step_ratio`: Float number with the ratio of the steps.
|
||||
- `cutoff_step_index`: Integer number with the exact number of the step.
|
||||
|
||||
We need to download the diffusion model and load the ip_adapter for it as follows:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
pipeline.set_ip_adapter_scale(0.6)
|
||||
```
|
||||
The setup for the callback should look something like this:
|
||||
|
||||
```py
|
||||
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.callbacks import IPAdapterScaleCutoffCallback
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
|
||||
pipeline.load_ip_adapter(
|
||||
"h94/IP-Adapter",
|
||||
subfolder="sdxl_models",
|
||||
weight_name="ip-adapter_sdxl.bin"
|
||||
)
|
||||
|
||||
pipeline.set_ip_adapter_scale(0.6)
|
||||
|
||||
|
||||
callback = IPAdapterScaleCutoffCallback(
|
||||
cutoff_step_ratio=None,
|
||||
cutoff_step_index=5
|
||||
)
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png"
|
||||
)
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(2628670641)
|
||||
|
||||
images = pipeline(
|
||||
prompt="a tiger sitting in a chair drinking orange juice",
|
||||
ip_adapter_image=image,
|
||||
negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
|
||||
generator=generator,
|
||||
num_inference_steps=50,
|
||||
callback_on_step_end=callback,
|
||||
).images
|
||||
|
||||
images[0].save("custom_callback_img.png")
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/without_callback.png" alt="generated image of a tiger sitting in a chair drinking orange juice" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">without IPAdapterScaleCutoffCallback</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_callback2.png" alt="generated image of a tiger sitting in a chair drinking orange juice with ip adapter callback" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">with IPAdapterScaleCutoffCallback</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
## Display image after each generation step
|
||||
|
||||
> [!TIP]
|
||||
|
||||
@@ -227,7 +227,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
with autocast_ctx:
|
||||
@@ -880,9 +880,7 @@ class TokenEmbeddingsHandler:
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
|
||||
embeds = (
|
||||
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
|
||||
)
|
||||
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
|
||||
new_token_embeddings = embeds.weight.data[train_ids]
|
||||
|
||||
@@ -904,9 +902,7 @@ class TokenEmbeddingsHandler:
|
||||
@torch.no_grad()
|
||||
def retract_embeddings(self):
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
embeds = (
|
||||
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
|
||||
)
|
||||
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
|
||||
embeds.weight.data[index_no_updates] = (
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
|
||||
@@ -1749,7 +1745,7 @@ def main(args):
|
||||
if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well
|
||||
text_lora_parameters_two = []
|
||||
for name, param in text_encoder_two.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
if "shared" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
|
||||
@@ -1883,7 +1883,11 @@ def main(args):
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = (
|
||||
torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
if args.seed is not None
|
||||
else None
|
||||
)
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
@@ -1987,7 +1991,9 @@ def main(args):
|
||||
)
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = (
|
||||
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
)
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -269,7 +269,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
|
||||
@@ -722,7 +722,7 @@ def log_validation(
|
||||
# pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
|
||||
videos = []
|
||||
for _ in range(args.num_validation_videos):
|
||||
|
||||
@@ -739,7 +739,7 @@ def log_validation(
|
||||
# pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
|
||||
videos = []
|
||||
for _ in range(args.num_validation_videos):
|
||||
|
||||
@@ -53,6 +53,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Stable Diffusion Mixture Tiling Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SD 1.5](#stable-diffusion-mixture-tiling-pipeline-sd-15) | [](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
|
||||
| Stable Diffusion Mixture Canvas Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending. Works by defining a list of Text2Image region objects that detail the region of influence of each diffuser. | [Stable Diffusion Mixture Canvas Pipeline SD 1.5](#stable-diffusion-mixture-canvas-pipeline-sd-15) | [](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
|
||||
| Stable Diffusion Mixture Tiling Pipeline SDXL | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SDXL](#stable-diffusion-mixture-tiling-pipeline-sdxl) | [](https://huggingface.co/spaces/elismasilva/mixture-of-diffusers-sdxl-tiling) | [Eliseu Silva](https://github.com/DEVAIEXP/) |
|
||||
| Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL | This is an advanced pipeline that leverages ControlNet Tile and Mixture-of-Diffusers techniques, integrating tile diffusion directly into the latent space denoising process. Designed to overcome the limitations of conventional pixel-space tile processing, this pipeline delivers Super Resolution (SR) upscaling for higher-quality images, reduced processing time, and greater adaptability. | [Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL](#stable-diffusion-mod-controlnet-tile-sr-pipeline-sdxl) | [](https://huggingface.co/spaces/elismasilva/mod-control-tile-upscaler-sdxl) | [Eliseu Silva](https://github.com/DEVAIEXP/) |
|
||||
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
@@ -2630,6 +2631,103 @@ image = pipe(
|
||||
|
||||

|
||||
|
||||
### Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL
|
||||
|
||||
This pipeline implements the [MoD (Mixture-of-Diffusers)]("https://arxiv.org/pdf/2408.06072") tiled diffusion technique and combines it with SDXL's ControlNet Tile process to generate SR images.
|
||||
|
||||
This works better with 4x scales, but you can try adjusts parameters to higher scales.
|
||||
|
||||
````python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, ControlNetUnionModel, AutoencoderKL, UniPCMultistepScheduler, UNet2DConditionModel
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
|
||||
device = "cuda"
|
||||
|
||||
# Initialize the models and pipeline
|
||||
controlnet = ControlNetUnionModel.from_pretrained(
|
||||
"brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
|
||||
).to(device=device)
|
||||
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device=device)
|
||||
|
||||
model_id = "SG161222/RealVisXL_V5.0"
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.float16,
|
||||
vae=vae,
|
||||
controlnet=controlnet,
|
||||
custom_pipeline="mod_controlnet_tile_sr_sdxl",
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
).to(device)
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
|
||||
|
||||
#pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
|
||||
pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
|
||||
pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
|
||||
|
||||
# Set selected scheduler
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
# Load image
|
||||
control_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/1.jpg")
|
||||
original_height = control_image.height
|
||||
original_width = control_image.width
|
||||
print(f"Current resolution: H:{original_height} x W:{original_width}")
|
||||
|
||||
# Pre-upscale image for tiling
|
||||
resolution = 4096
|
||||
tile_gaussian_sigma = 0.3
|
||||
max_tile_size = 1024 # or 1280
|
||||
|
||||
current_size = max(control_image.size)
|
||||
scale_factor = max(2, resolution / current_size)
|
||||
new_size = (int(control_image.width * scale_factor), int(control_image.height * scale_factor))
|
||||
image = control_image.resize(new_size, Image.LANCZOS)
|
||||
|
||||
# Update target height and width
|
||||
target_height = image.height
|
||||
target_width = image.width
|
||||
print(f"Target resolution: H:{target_height} x W:{target_width}")
|
||||
|
||||
# Calculate overlap size
|
||||
normal_tile_overlap, border_tile_overlap = pipe.calculate_overlap(target_width, target_height)
|
||||
|
||||
# Set other params
|
||||
tile_weighting_method = pipe.TileWeightingMethod.COSINE.value
|
||||
guidance_scale = 4
|
||||
num_inference_steps = 35
|
||||
denoising_strenght = 0.65
|
||||
controlnet_strength = 1.0
|
||||
prompt = "high-quality, noise-free edges, high quality, 4k, hd, 8k"
|
||||
negative_prompt = "blurry, pixelated, noisy, low resolution, artifacts, poor details"
|
||||
|
||||
# Image generation
|
||||
generated_image = pipe(
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
control_mode=[6],
|
||||
controlnet_conditioning_scale=float(controlnet_strength),
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
normal_tile_overlap=normal_tile_overlap,
|
||||
border_tile_overlap=border_tile_overlap,
|
||||
height=target_height,
|
||||
width=target_width,
|
||||
original_size=(original_width, original_height),
|
||||
target_size=(target_width, target_height),
|
||||
guidance_scale=guidance_scale,
|
||||
strength=float(denoising_strenght),
|
||||
tile_weighting_method=tile_weighting_method,
|
||||
max_tile_size=max_tile_size,
|
||||
tile_gaussian_sigma=float(tile_gaussian_sigma),
|
||||
num_inference_steps=num_inference_steps,
|
||||
)["images"][0]
|
||||
````
|
||||

|
||||
|
||||
### TensorRT Inpainting Stable Diffusion Pipeline
|
||||
|
||||
The TensorRT Pipeline can be used to accelerate the Inpainting Stable Diffusion Inference run.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1334,7 +1334,9 @@ def main(args):
|
||||
|
||||
# run inference
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = (
|
||||
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
)
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -172,7 +172,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
|
||||
if args.validation_images is None:
|
||||
images = []
|
||||
@@ -1119,17 +1119,22 @@ def main(args):
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1146,8 +1151,15 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
|
||||
with autocast_ctx:
|
||||
|
||||
@@ -170,7 +170,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
|
||||
|
||||
@@ -199,7 +199,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
|
||||
@@ -207,7 +207,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
|
||||
@@ -175,7 +175,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
|
||||
@@ -137,7 +137,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
@@ -1241,7 +1241,11 @@ def main(args):
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = (
|
||||
torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
if args.seed is not None
|
||||
else None
|
||||
)
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
with autocast_ctx:
|
||||
@@ -1305,7 +1309,9 @@ def main(args):
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
generator = (
|
||||
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
)
|
||||
|
||||
with autocast_ctx:
|
||||
images = [
|
||||
|
||||
@@ -0,0 +1,423 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
UniPCMultistepScheduler,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanTransformer3DModel,
|
||||
)
|
||||
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
||||
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
||||
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
||||
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
||||
"time_projection.1": "condition_embedder.time_proj",
|
||||
"head.modulation": "scale_shift_table",
|
||||
"head.head": "proj_out",
|
||||
"modulation": "scale_shift_table",
|
||||
"ffn.0": "ffn.net.0.proj",
|
||||
"ffn.2": "ffn.net.2",
|
||||
# Hack to swap the layer names
|
||||
# The original model calls the norms in following order: norm1, norm3, norm2
|
||||
# We convert it to: norm1, norm2, norm3
|
||||
"norm2": "norm__placeholder",
|
||||
"norm3": "norm2",
|
||||
"norm__placeholder": "norm3",
|
||||
# For the I2V model
|
||||
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
||||
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
||||
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
||||
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def load_sharded_safetensors(dir: pathlib.Path):
|
||||
file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
|
||||
state_dict = {}
|
||||
for path in file_paths:
|
||||
state_dict.update(load_file(path))
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
if model_type == "Wan-T2V-1.3B":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 12,
|
||||
"num_layers": 30,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
elif model_type == "Wan-T2V-14B":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
elif model_type == "Wan-I2V-14B-480p":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
|
||||
"diffusers_config": {
|
||||
"image_dim": 1280,
|
||||
"added_kv_proj_dim": 5120,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 36,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
elif model_type == "Wan-I2V-14B-720p":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
|
||||
"diffusers_config": {
|
||||
"image_dim": 1280,
|
||||
"added_kv_proj_dim": 5120,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 36,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def convert_transformer(model_type: str):
|
||||
config = get_transformer_config(model_type)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
model_id = config["model_id"]
|
||||
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
|
||||
|
||||
original_state_dict = load_sharded_safetensors(model_dir)
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = WanTransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae():
|
||||
vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
|
||||
old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
|
||||
new_state_dict = {}
|
||||
|
||||
# Create mappings for specific components
|
||||
middle_key_mapping = {
|
||||
# Encoder middle block
|
||||
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
|
||||
"encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
|
||||
"encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
|
||||
"encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
|
||||
"encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
|
||||
"encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
|
||||
"encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
|
||||
"encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
|
||||
"encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
|
||||
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
|
||||
"encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
|
||||
"encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
|
||||
# Decoder middle block
|
||||
"decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
|
||||
"decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
|
||||
"decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
|
||||
"decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
|
||||
"decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
|
||||
"decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
|
||||
"decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
|
||||
"decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
|
||||
"decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
|
||||
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
|
||||
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
|
||||
"decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
|
||||
}
|
||||
|
||||
# Create a mapping for attention blocks
|
||||
attention_mapping = {
|
||||
# Encoder middle attention
|
||||
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
|
||||
"encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
|
||||
"encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
|
||||
"encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
|
||||
"encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
|
||||
# Decoder middle attention
|
||||
"decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
|
||||
"decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
|
||||
"decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
|
||||
"decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
|
||||
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
|
||||
}
|
||||
|
||||
# Create a mapping for the head components
|
||||
head_mapping = {
|
||||
# Encoder head
|
||||
"encoder.head.0.gamma": "encoder.norm_out.gamma",
|
||||
"encoder.head.2.bias": "encoder.conv_out.bias",
|
||||
"encoder.head.2.weight": "encoder.conv_out.weight",
|
||||
# Decoder head
|
||||
"decoder.head.0.gamma": "decoder.norm_out.gamma",
|
||||
"decoder.head.2.bias": "decoder.conv_out.bias",
|
||||
"decoder.head.2.weight": "decoder.conv_out.weight",
|
||||
}
|
||||
|
||||
# Create a mapping for the quant components
|
||||
quant_mapping = {
|
||||
"conv1.weight": "quant_conv.weight",
|
||||
"conv1.bias": "quant_conv.bias",
|
||||
"conv2.weight": "post_quant_conv.weight",
|
||||
"conv2.bias": "post_quant_conv.bias",
|
||||
}
|
||||
|
||||
# Process each key in the state dict
|
||||
for key, value in old_state_dict.items():
|
||||
# Handle middle block keys using the mapping
|
||||
if key in middle_key_mapping:
|
||||
new_key = middle_key_mapping[key]
|
||||
new_state_dict[new_key] = value
|
||||
# Handle attention blocks using the mapping
|
||||
elif key in attention_mapping:
|
||||
new_key = attention_mapping[key]
|
||||
new_state_dict[new_key] = value
|
||||
# Handle head keys using the mapping
|
||||
elif key in head_mapping:
|
||||
new_key = head_mapping[key]
|
||||
new_state_dict[new_key] = value
|
||||
# Handle quant keys using the mapping
|
||||
elif key in quant_mapping:
|
||||
new_key = quant_mapping[key]
|
||||
new_state_dict[new_key] = value
|
||||
# Handle encoder conv1
|
||||
elif key == "encoder.conv1.weight":
|
||||
new_state_dict["encoder.conv_in.weight"] = value
|
||||
elif key == "encoder.conv1.bias":
|
||||
new_state_dict["encoder.conv_in.bias"] = value
|
||||
# Handle decoder conv1
|
||||
elif key == "decoder.conv1.weight":
|
||||
new_state_dict["decoder.conv_in.weight"] = value
|
||||
elif key == "decoder.conv1.bias":
|
||||
new_state_dict["decoder.conv_in.bias"] = value
|
||||
# Handle encoder downsamples
|
||||
elif key.startswith("encoder.downsamples."):
|
||||
# Convert to down_blocks
|
||||
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
|
||||
|
||||
# Convert residual block naming but keep the original structure
|
||||
if ".residual.0.gamma" in new_key:
|
||||
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
|
||||
elif ".residual.2.bias" in new_key:
|
||||
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
|
||||
elif ".residual.2.weight" in new_key:
|
||||
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
|
||||
elif ".residual.3.gamma" in new_key:
|
||||
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
|
||||
elif ".residual.6.bias" in new_key:
|
||||
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
|
||||
elif ".residual.6.weight" in new_key:
|
||||
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
|
||||
elif ".shortcut.bias" in new_key:
|
||||
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
|
||||
elif ".shortcut.weight" in new_key:
|
||||
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
|
||||
|
||||
new_state_dict[new_key] = value
|
||||
|
||||
# Handle decoder upsamples
|
||||
elif key.startswith("decoder.upsamples."):
|
||||
# Convert to up_blocks
|
||||
parts = key.split(".")
|
||||
block_idx = int(parts[2])
|
||||
|
||||
# Group residual blocks
|
||||
if "residual" in key:
|
||||
if block_idx in [0, 1, 2]:
|
||||
new_block_idx = 0
|
||||
resnet_idx = block_idx
|
||||
elif block_idx in [4, 5, 6]:
|
||||
new_block_idx = 1
|
||||
resnet_idx = block_idx - 4
|
||||
elif block_idx in [8, 9, 10]:
|
||||
new_block_idx = 2
|
||||
resnet_idx = block_idx - 8
|
||||
elif block_idx in [12, 13, 14]:
|
||||
new_block_idx = 3
|
||||
resnet_idx = block_idx - 12
|
||||
else:
|
||||
# Keep as is for other blocks
|
||||
new_state_dict[key] = value
|
||||
continue
|
||||
|
||||
# Convert residual block naming
|
||||
if ".residual.0.gamma" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
|
||||
elif ".residual.2.bias" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
|
||||
elif ".residual.2.weight" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
|
||||
elif ".residual.3.gamma" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
|
||||
elif ".residual.6.bias" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
|
||||
elif ".residual.6.weight" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
|
||||
else:
|
||||
new_key = key
|
||||
|
||||
new_state_dict[new_key] = value
|
||||
|
||||
# Handle shortcut connections
|
||||
elif ".shortcut." in key:
|
||||
if block_idx == 4:
|
||||
new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
|
||||
new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
|
||||
|
||||
new_state_dict[new_key] = value
|
||||
|
||||
# Handle upsamplers
|
||||
elif ".resample." in key or ".time_conv." in key:
|
||||
if block_idx == 3:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
|
||||
elif block_idx == 7:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
|
||||
elif block_idx == 11:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
|
||||
new_state_dict[new_key] = value
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
new_state_dict[new_key] = value
|
||||
else:
|
||||
# Keep other keys unchanged
|
||||
new_state_dict[key] = value
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLWan()
|
||||
vae.load_state_dict(new_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument("--dtype", default="fp32")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
transformer = convert_transformer(args.model_type).to(dtype=dtype)
|
||||
vae = convert_vae()
|
||||
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
|
||||
)
|
||||
|
||||
if "I2V" in args.model_type:
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
|
||||
)
|
||||
image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
||||
pipe = WanImageToVideoPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
else:
|
||||
pipe = WanPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
@@ -94,8 +94,10 @@ else:
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
"AutoencoderKLMochi",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderKLWan",
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"CacheMixin",
|
||||
@@ -108,6 +110,7 @@ else:
|
||||
"ControlNetUnionModel",
|
||||
"ControlNetXSAdapter",
|
||||
"DiTTransformer2DModel",
|
||||
"EasyAnimateTransformer3DModel",
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
@@ -148,6 +151,7 @@ else:
|
||||
"UNetSpatioTemporalConditionModel",
|
||||
"UVit2DModel",
|
||||
"VQModel",
|
||||
"WanTransformer3DModel",
|
||||
]
|
||||
)
|
||||
_import_structure["optimization"] = [
|
||||
@@ -291,6 +295,9 @@ else:
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimatePipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
@@ -438,6 +445,8 @@ else:
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
"VideoToVideoSDPipeline",
|
||||
"VQDiffusionPipeline",
|
||||
"WanImageToVideoPipeline",
|
||||
"WanPipeline",
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
@@ -616,8 +625,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
CacheMixin,
|
||||
@@ -630,6 +641,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
DiTTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
@@ -669,6 +681,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
VQModel,
|
||||
WanTransformer3DModel,
|
||||
)
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
@@ -792,6 +805,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
@@ -938,6 +954,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VideoToVideoSDPipeline,
|
||||
VQDiffusionPipeline,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
|
||||
@@ -74,6 +74,7 @@ if is_torch_available():
|
||||
"HunyuanVideoLoraLoaderMixin",
|
||||
"SanaLoraLoaderMixin",
|
||||
"Lumina2LoraLoaderMixin",
|
||||
"WanLoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
@@ -112,6 +113,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SD3LoraLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
WanLoraLoaderMixin,
|
||||
)
|
||||
from .single_file import FromSingleFileMixin
|
||||
from .textual_inversion import TextualInversionLoaderMixin
|
||||
|
||||
@@ -215,7 +215,8 @@ class IPAdapterMixin:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
).to(self.device, dtype=self.dtype)
|
||||
torch_dtype=self.dtype,
|
||||
).to(self.device)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -526,8 +527,9 @@ class FluxIPAdapterMixin:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
dtype=image_encoder_dtype,
|
||||
)
|
||||
.to(self.device, dtype=image_encoder_dtype)
|
||||
.to(self.device)
|
||||
.eval()
|
||||
)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
@@ -805,9 +807,9 @@ class SD3IPAdapterMixin:
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(
|
||||
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
|
||||
).to(self.device),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -1276,3 +1276,74 @@ def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
|
||||
# Remove "diffusion_model." prefix from keys.
|
||||
state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
|
||||
converted_state_dict = {}
|
||||
|
||||
def get_num_layers(keys, pattern):
|
||||
layers = set()
|
||||
for key in keys:
|
||||
match = re.search(pattern, key)
|
||||
if match:
|
||||
layers.add(int(match.group(1)))
|
||||
return len(layers)
|
||||
|
||||
def process_block(prefix, index, convert_norm):
|
||||
# Process attention qkv: pop lora_A and lora_B weights.
|
||||
lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
|
||||
lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
|
||||
for attn_key in ["to_q", "to_k", "to_v"]:
|
||||
converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
|
||||
for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
|
||||
converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
|
||||
|
||||
# Process attention out weights.
|
||||
converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.attention.out.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.attention.out.lora_B.weight"
|
||||
)
|
||||
|
||||
# Process feed-forward weights for layers 1, 2, and 3.
|
||||
for layer in range(1, 4):
|
||||
converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
|
||||
)
|
||||
|
||||
if convert_norm:
|
||||
converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
|
||||
)
|
||||
|
||||
noise_refiner_pattern = r"noise_refiner\.(\d+)\."
|
||||
num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
|
||||
for i in range(num_noise_refiner_layers):
|
||||
process_block("noise_refiner", i, convert_norm=True)
|
||||
|
||||
context_refiner_pattern = r"context_refiner\.(\d+)\."
|
||||
num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
|
||||
for i in range(num_context_refiner_layers):
|
||||
process_block("context_refiner", i, convert_norm=False)
|
||||
|
||||
core_transformer_pattern = r"layers\.(\d+)\."
|
||||
num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
|
||||
for i in range(num_core_transformer_layers):
|
||||
process_block("layers", i, convert_norm=True)
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -41,6 +41,7 @@ from .lora_conversion_utils import (
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_lumina2_lora_to_diffusers,
|
||||
_convert_xlabs_flux_lora_to_diffusers,
|
||||
_maybe_map_sgm_blocks_to_diffusers,
|
||||
)
|
||||
@@ -3815,7 +3816,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
@@ -3909,6 +3909,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
# conversion.
|
||||
non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if non_diffusers:
|
||||
state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
@@ -4110,6 +4115,311 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
||||
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`WanTransformer3DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the UNet and text encoder.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -53,6 +53,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1448,8 +1448,8 @@ def convert_open_clip_checkpoint(
|
||||
|
||||
if text_proj_key in checkpoint:
|
||||
text_proj_dim = int(checkpoint[text_proj_key].shape[0])
|
||||
elif hasattr(text_model.config, "projection_dim"):
|
||||
text_proj_dim = text_model.config.projection_dim
|
||||
elif hasattr(text_model.config, "hidden_size"):
|
||||
text_proj_dim = text_model.config.hidden_size
|
||||
else:
|
||||
text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
|
||||
|
||||
|
||||
Regular → Executable
+8
@@ -33,8 +33,10 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
|
||||
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
@@ -71,6 +73,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
@@ -79,6 +82,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
@@ -107,8 +111,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
@@ -141,6 +147,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ConsisIDTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
FluxTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
@@ -158,6 +165,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
TransformerTemporalModel,
|
||||
WanTransformer3DModel,
|
||||
)
|
||||
from .unets import (
|
||||
I2VGenXLUNet,
|
||||
|
||||
Regular → Executable
+8
-1
@@ -274,12 +274,19 @@ class Attention(nn.Module):
|
||||
self.to_add_out = None
|
||||
|
||||
if qk_norm is not None and added_kv_proj_dim is not None:
|
||||
if qk_norm == "fp32_layer_norm":
|
||||
if qk_norm == "layer_norm":
|
||||
self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
elif qk_norm == "fp32_layer_norm":
|
||||
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
elif qk_norm == "rms_norm":
|
||||
self.norm_added_q = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = RMSNorm(dim_head, eps=eps)
|
||||
elif qk_norm == "rms_norm_across_heads":
|
||||
# Wanx applies qk norm across all heads
|
||||
self.norm_added_q = RMSNorm(dim_head * heads, eps=eps)
|
||||
self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
|
||||
|
||||
@@ -5,8 +5,10 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_kl_wan import AutoencoderKLWan
|
||||
from .autoencoder_oobleck import AutoencoderOobleck
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,865 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
|
||||
class WanCausalConv3d(nn.Conv3d):
|
||||
r"""
|
||||
A custom 3D causal convolution layer with feature caching support.
|
||||
|
||||
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
|
||||
caching for efficient inference.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
||||
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
# Set up causal padding
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
x = F.pad(x, padding)
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class WanRMS_norm(nn.Module):
|
||||
r"""
|
||||
A custom RMS normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The number of dimensions to normalize over.
|
||||
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
|
||||
Default is True.
|
||||
images (bool, optional): Whether the input represents image data. Default is True.
|
||||
bias (bool, optional): Whether to include a learnable bias term. Default is False.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class WanUpsample(nn.Upsample):
|
||||
r"""
|
||||
Perform upsampling while ensuring the output tensor has the same data type as the input.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor to be upsampled.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Upsampled tensor with the same data type as the input.
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class WanResample(nn.Module):
|
||||
r"""
|
||||
A custom resampling module for 2D and 3D data.
|
||||
|
||||
Args:
|
||||
dim (int): The number of input/output channels.
|
||||
mode (str): The resampling mode. Must be one of:
|
||||
- 'none': No resampling (identity operation).
|
||||
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
|
||||
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
|
||||
- 'downsample2d': 2D downsampling with zero-padding and convolution.
|
||||
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, mode: str) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == "upsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
elif mode == "upsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
elif mode == "downsample2d":
|
||||
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
elif mode == "downsample3d":
|
||||
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == "upsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = "Rep"
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat(
|
||||
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
||||
)
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
|
||||
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
|
||||
if feat_cache[idx] == "Rep":
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.resample(x)
|
||||
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
|
||||
|
||||
if self.mode == "downsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
|
||||
class WanResidualBlock(nn.Module):
|
||||
r"""
|
||||
A custom residual block module.
|
||||
|
||||
Args:
|
||||
in_dim (int): Number of input channels.
|
||||
out_dim (int): Number of output channels.
|
||||
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
|
||||
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
dropout: float = 0.0,
|
||||
non_linearity: str = "silu",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
# layers
|
||||
self.norm1 = WanRMS_norm(in_dim, images=False)
|
||||
self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
|
||||
self.norm2 = WanRMS_norm(out_dim, images=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||
self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Apply shortcut connection
|
||||
h = self.conv_shortcut(x)
|
||||
|
||||
# First normalization and activation
|
||||
x = self.norm1(x)
|
||||
x = self.nonlinearity(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
# Second normalization and activation
|
||||
x = self.norm2(x)
|
||||
x = self.nonlinearity(x)
|
||||
|
||||
# Dropout
|
||||
x = self.dropout(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.conv2(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv2(x)
|
||||
|
||||
# Add residual connection
|
||||
return x + h
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
r"""
|
||||
Causal self-attention with a single head.
|
||||
|
||||
Args:
|
||||
dim (int): The number of channels in the input tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = WanRMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
batch_size, channels, time, height, width = x.size()
|
||||
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
|
||||
x = self.norm(x)
|
||||
|
||||
# compute query, key, value
|
||||
qkv = self.to_qkv(x)
|
||||
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
|
||||
qkv = qkv.permute(0, 1, 3, 2).contiguous()
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
|
||||
# apply attention
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
|
||||
|
||||
# output projection
|
||||
x = self.proj(x)
|
||||
|
||||
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
|
||||
x = x.view(batch_size, time, channels, height, width)
|
||||
x = x.permute(0, 2, 1, 3, 4)
|
||||
|
||||
return x + identity
|
||||
|
||||
|
||||
class WanMidBlock(nn.Module):
|
||||
"""
|
||||
Middle block for WanVAE encoder and decoder.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input/output channels.
|
||||
dropout (float): Dropout rate.
|
||||
non_linearity (str): Type of non-linearity to use.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# Create the components
|
||||
resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
|
||||
attentions = []
|
||||
for _ in range(num_layers):
|
||||
attentions.append(WanAttentionBlock(dim))
|
||||
resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# First residual block
|
||||
x = self.resnets[0](x, feat_cache, feat_idx)
|
||||
|
||||
# Process through attention and residual blocks
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
x = attn(x)
|
||||
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanEncoder3d(nn.Module):
|
||||
r"""
|
||||
A 3D encoder module.
|
||||
|
||||
Args:
|
||||
dim (int): The base number of channels in the first layer.
|
||||
z_dim (int): The dimensionality of the latent space.
|
||||
dim_mult (list of int): Multipliers for the number of channels in each block.
|
||||
num_res_blocks (int): Number of residual blocks in each block.
|
||||
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
||||
temperal_downsample (list of bool): Whether to downsample temporally in each block.
|
||||
dropout (float): Dropout rate for the dropout layers.
|
||||
non_linearity (str): Type of non-linearity to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
self.down_blocks.append(WanAttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
||||
self.down_blocks.append(WanResample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
|
||||
# middle blocks
|
||||
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
|
||||
|
||||
# output blocks
|
||||
self.norm_out = WanRMS_norm(out_dim, images=False)
|
||||
self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_in(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_in(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.down_blocks:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
x = self.nonlinearity(x)
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_out(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanUpBlock(nn.Module):
|
||||
"""
|
||||
A block that handles upsampling for the WanVAE decoder.
|
||||
|
||||
Args:
|
||||
in_dim (int): Input dimension
|
||||
out_dim (int): Output dimension
|
||||
num_res_blocks (int): Number of residual blocks
|
||||
dropout (float): Dropout rate
|
||||
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
|
||||
non_linearity (str): Type of non-linearity to use
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
num_res_blocks: int,
|
||||
dropout: float = 0.0,
|
||||
upsample_mode: Optional[str] = None,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# Create layers list
|
||||
resnets = []
|
||||
# Add residual blocks and attention if needed
|
||||
current_dim = in_dim
|
||||
for _ in range(num_res_blocks + 1):
|
||||
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
|
||||
current_dim = out_dim
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
# Add upsampling layer if needed
|
||||
self.upsamplers = None
|
||||
if upsample_mode is not None:
|
||||
self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
"""
|
||||
Forward pass through the upsampling block.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor
|
||||
feat_cache (list, optional): Feature cache for causal convolutions
|
||||
feat_idx (list, optional): Feature index for cache management
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor
|
||||
"""
|
||||
for resnet in self.resnets:
|
||||
if feat_cache is not None:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = resnet(x)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
if feat_cache is not None:
|
||||
x = self.upsamplers[0](x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = self.upsamplers[0](x)
|
||||
return x
|
||||
|
||||
|
||||
class WanDecoder3d(nn.Module):
|
||||
r"""
|
||||
A 3D decoder module.
|
||||
|
||||
Args:
|
||||
dim (int): The base number of channels in the first layer.
|
||||
z_dim (int): The dimensionality of the latent space.
|
||||
dim_mult (list of int): Multipliers for the number of channels in each block.
|
||||
num_res_blocks (int): Number of residual blocks in each block.
|
||||
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
||||
temperal_upsample (list of bool): Whether to upsample temporally in each block.
|
||||
dropout (float): Dropout rate for the dropout layers.
|
||||
non_linearity (str): Type of non-linearity to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
|
||||
|
||||
# upsample blocks
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i > 0:
|
||||
in_dim = in_dim // 2
|
||||
|
||||
# Determine if we need upsampling
|
||||
upsample_mode = None
|
||||
if i != len(dim_mult) - 1:
|
||||
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
||||
|
||||
# Create and add the upsampling block
|
||||
up_block = WanUpBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks,
|
||||
dropout=dropout,
|
||||
upsample_mode=upsample_mode,
|
||||
non_linearity=non_linearity,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
# Update scale for next iteration
|
||||
if upsample_mode is not None:
|
||||
scale *= 2.0
|
||||
|
||||
# output blocks
|
||||
self.norm_out = WanRMS_norm(out_dim, images=False)
|
||||
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_in(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_in(x)
|
||||
|
||||
## middle
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
|
||||
## upsamples
|
||||
for up_block in self.up_blocks:
|
||||
x = up_block(x, feat_cache, feat_idx)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
x = self.nonlinearity(x)
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_out(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class AutoencoderKLWan(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
||||
Introduced in [Wan 2.1].
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
base_dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
dropout: float = 0.0,
|
||||
latents_mean: List[float] = [
|
||||
-0.7571,
|
||||
-0.7089,
|
||||
-0.9113,
|
||||
0.1075,
|
||||
-0.1745,
|
||||
0.9653,
|
||||
-0.1517,
|
||||
1.5508,
|
||||
0.4134,
|
||||
-0.0715,
|
||||
0.5517,
|
||||
-0.3632,
|
||||
-0.1922,
|
||||
-0.9497,
|
||||
0.2503,
|
||||
-0.2921,
|
||||
],
|
||||
latents_std: List[float] = [
|
||||
2.8184,
|
||||
1.4541,
|
||||
2.3275,
|
||||
2.6558,
|
||||
1.2196,
|
||||
1.7708,
|
||||
2.6052,
|
||||
2.0743,
|
||||
3.2687,
|
||||
2.1526,
|
||||
2.8652,
|
||||
1.5579,
|
||||
1.6382,
|
||||
1.1253,
|
||||
2.8251,
|
||||
1.9160,
|
||||
],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Store normalization parameters as tensors
|
||||
self.mean = torch.tensor(latents_mean)
|
||||
self.std = torch.tensor(latents_std)
|
||||
self.scale = torch.stack([self.mean, 1.0 / self.std]) # Shape: [2, C]
|
||||
|
||||
self.z_dim = z_dim
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
self.encoder = WanEncoder3d(
|
||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
||||
)
|
||||
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
|
||||
|
||||
self.decoder = WanDecoder3d(
|
||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
||||
)
|
||||
|
||||
def clear_cache(self):
|
||||
def _count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, WanCausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
self._conv_num = _count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# cache encode
|
||||
self._enc_conv_num = _count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
scale = self.scale.type_as(x)
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx,
|
||||
)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
enc = self.quant_conv(out)
|
||||
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
||||
logvar = (logvar - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
||||
enc = torch.cat([mu, logvar], dim=1)
|
||||
self.clear_cache()
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded videos. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
|
||||
|
||||
iter_ = z.shape[2]
|
||||
x = self.post_quant_conv(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
out = torch.clamp(out, min=-1.0, max=1.0)
|
||||
self.clear_cache()
|
||||
if not return_dict:
|
||||
return (out,)
|
||||
|
||||
return DecoderOutput(sample=out)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
scale = self.scale.type_as(z)
|
||||
decoded = self._decode(z, scale).sample
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, return_dict=return_dict)
|
||||
return dec
|
||||
@@ -605,12 +605,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
controlnet_cond: List[torch.Tensor],
|
||||
control_type: torch.Tensor,
|
||||
control_type_idx: List[int],
|
||||
conditioning_scale: float = 1.0,
|
||||
conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
from_multi: bool = False,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
@@ -647,6 +648,8 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
from_multi (`bool`, defaults to `False`):
|
||||
Use standard scaling when called from `MultiControlNetUnionModel`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
@@ -658,6 +661,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
if isinstance(conditioning_scale, float):
|
||||
conditioning_scale = [conditioning_scale] * len(controlnet_cond)
|
||||
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
@@ -742,12 +748,16 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
inputs = []
|
||||
condition_list = []
|
||||
|
||||
for cond, control_idx in zip(controlnet_cond, control_type_idx):
|
||||
for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
|
||||
condition = self.controlnet_cond_embedding(cond)
|
||||
feat_seq = torch.mean(condition, dim=(2, 3))
|
||||
feat_seq = feat_seq + self.task_embedding[control_idx]
|
||||
inputs.append(feat_seq.unsqueeze(1))
|
||||
condition_list.append(condition)
|
||||
if from_multi:
|
||||
inputs.append(feat_seq.unsqueeze(1))
|
||||
condition_list.append(condition)
|
||||
else:
|
||||
inputs.append(feat_seq.unsqueeze(1) * scale)
|
||||
condition_list.append(condition * scale)
|
||||
|
||||
condition = sample
|
||||
feat_seq = torch.mean(condition, dim=(2, 3))
|
||||
@@ -759,10 +769,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
x = layer(x)
|
||||
|
||||
controlnet_cond_fuser = sample * 0.0
|
||||
for idx, condition in enumerate(condition_list[:-1]):
|
||||
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
|
||||
alpha = self.spatial_ch_projs(x[:, idx])
|
||||
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
|
||||
controlnet_cond_fuser += condition + alpha
|
||||
if from_multi:
|
||||
controlnet_cond_fuser += condition + alpha
|
||||
else:
|
||||
controlnet_cond_fuser += condition + alpha * scale
|
||||
|
||||
sample = sample + controlnet_cond_fuser
|
||||
|
||||
@@ -806,12 +819,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
if from_multi:
|
||||
scales = scales * conditioning_scale[0]
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
elif from_multi:
|
||||
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
|
||||
@@ -47,9 +47,12 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
|
||||
zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
|
||||
):
|
||||
if scale == 0.0:
|
||||
continue
|
||||
down_samples, mid_sample = controlnet(
|
||||
sample=sample,
|
||||
timestep=timestep,
|
||||
@@ -63,12 +66,13 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
attention_mask=attention_mask,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
from_multi=True,
|
||||
guess_mode=guess_mode,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if i == 0:
|
||||
if down_block_res_samples is None and mid_block_res_sample is None:
|
||||
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||
else:
|
||||
down_block_res_samples = [
|
||||
|
||||
@@ -166,8 +166,12 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
||||
|
||||
# 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
|
||||
last_dtype = None
|
||||
for param in parameter.parameters():
|
||||
|
||||
for name, param in parameter.named_parameters():
|
||||
last_dtype = param.dtype
|
||||
if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
|
||||
continue
|
||||
|
||||
if param.is_floating_point():
|
||||
return param.dtype
|
||||
|
||||
|
||||
Regular → Executable
+2
@@ -19,6 +19,7 @@ if is_torch_available():
|
||||
from .transformer_allegro import AllegroTransformer3DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
@@ -27,3 +28,4 @@ if is_torch_available():
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from .transformer_wan import WanTransformer3DModel
|
||||
|
||||
@@ -0,0 +1,527 @@
|
||||
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class EasyAnimateLayerNormZero(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_dim: int,
|
||||
embedding_dim: int,
|
||||
elementwise_affine: bool = True,
|
||||
eps: float = 1e-5,
|
||||
bias: bool = True,
|
||||
norm_type: str = "fp32_layer_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
|
||||
elif norm_type == "fp32_layer_norm":
|
||||
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale.unsqueeze(1)) + enc_shift.unsqueeze(
|
||||
1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states, gate, enc_gate
|
||||
|
||||
|
||||
class EasyAnimateRotaryPosEmbed(nn.Module):
|
||||
def __init__(self, patch_size: int, rope_dim: List[int]) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.rope_dim = rope_dim
|
||||
|
||||
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
bs, c, num_frames, grid_height, grid_width = hidden_states.size()
|
||||
grid_height = grid_height // self.patch_size
|
||||
grid_width = grid_width // self.patch_size
|
||||
base_size_width = 90 // self.patch_size
|
||||
base_size_height = 60 // self.patch_size
|
||||
|
||||
grid_crops_coords = self.get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
image_rotary_emb = get_3d_rotary_pos_embed(
|
||||
self.rope_dim,
|
||||
grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=hidden_states.size(2),
|
||||
use_real=True,
|
||||
)
|
||||
return image_rotary_emb
|
||||
|
||||
|
||||
class EasyAnimateAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the EasyAnimateTransformer3DModel model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"EasyAnimateAttnProcessor2_0 requires PyTorch 2.0 or above. To use it, please install PyTorch 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# 3. Encoder condition QKV projection and normalization
|
||||
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([encoder_query, query], dim=2)
|
||||
key = torch.cat([encoder_key, key], dim=2)
|
||||
value = torch.cat([encoder_value, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
|
||||
query[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
|
||||
)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
|
||||
key[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
|
||||
)
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# 6. Output projection
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
if getattr(attn, "to_out", None) is not None:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if getattr(attn, "to_add_out", None) is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
else:
|
||||
if getattr(attn, "to_out", None) is not None:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class EasyAnimateTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
time_embed_dim: int,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-6,
|
||||
final_dropout: bool = True,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
qk_norm: bool = True,
|
||||
after_norm: bool = False,
|
||||
norm_type: str = "fp32_layer_norm",
|
||||
is_mmdit_block: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Attention Part
|
||||
self.norm1 = EasyAnimateLayerNormZero(
|
||||
time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
|
||||
)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
qk_norm="layer_norm" if qk_norm else None,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
added_proj_bias=True,
|
||||
added_kv_proj_dim=dim if is_mmdit_block else None,
|
||||
context_pre_only=False if is_mmdit_block else None,
|
||||
processor=EasyAnimateAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# FFN Part
|
||||
self.norm2 = EasyAnimateLayerNormZero(
|
||||
time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
|
||||
)
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
self.txt_ff = None
|
||||
if is_mmdit_block:
|
||||
self.txt_ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
self.norm3 = None
|
||||
if after_norm:
|
||||
self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Attention
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
if self.norm3 is not None:
|
||||
norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
|
||||
if self.txt_ff is not None:
|
||||
norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
|
||||
else:
|
||||
norm_encoder_hidden_states = self.norm3(self.ff(norm_encoder_hidden_states))
|
||||
else:
|
||||
norm_hidden_states = self.ff(norm_hidden_states)
|
||||
if self.txt_ff is not None:
|
||||
norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
|
||||
else:
|
||||
norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states)
|
||||
hidden_states = hidden_states + gate_ff.unsqueeze(1) * norm_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff.unsqueeze(1) * norm_encoder_hidden_states
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate).
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, defaults to `48`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
sample_width (`int`, defaults to `90`):
|
||||
The width of the input latents.
|
||||
sample_height (`int`, defaults to `60`):
|
||||
The height of the input latents.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
||||
Activation function to use when generating the timestep embeddings.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
mmdit_layers (`int`, defaults to `1000`):
|
||||
The number of layers of Multi Modal Transformer blocks to use.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
The epsilon value to use in normalization layers.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use elementwise affine in normalization layers.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
time_position_encoding_type (`str`, defaults to `3d_rope`):
|
||||
Type of time position encoding.
|
||||
after_norm (`bool`, defaults to `False`):
|
||||
Flag to apply normalization after.
|
||||
resize_inpaint_mask_directly (`bool`, defaults to `True`):
|
||||
Flag to resize inpaint mask directly.
|
||||
enable_text_attention_mask (`bool`, defaults to `True`):
|
||||
Flag to enable text attention mask.
|
||||
add_noise_in_inpaint_model (`bool`, defaults to `False`):
|
||||
Flag to add noise in inpaint model.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["EasyAnimateTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["^proj$", "norm", "^proj_out$"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 48,
|
||||
attention_head_dim: int = 64,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
sample_width: int = 90,
|
||||
sample_height: int = 60,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
timestep_activation_fn: str = "silu",
|
||||
freq_shift: int = 0,
|
||||
num_layers: int = 48,
|
||||
mmdit_layers: int = 48,
|
||||
dropout: float = 0.0,
|
||||
time_embed_dim: int = 512,
|
||||
add_norm_text_encoder: bool = False,
|
||||
text_embed_dim: int = 3584,
|
||||
text_embed_dim_t5: int = None,
|
||||
norm_eps: float = 1e-5,
|
||||
norm_elementwise_affine: bool = True,
|
||||
flip_sin_to_cos: bool = True,
|
||||
time_position_encoding_type: str = "3d_rope",
|
||||
after_norm=False,
|
||||
resize_inpaint_mask_directly: bool = True,
|
||||
enable_text_attention_mask: bool = True,
|
||||
add_noise_in_inpaint_model: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Timestep embedding
|
||||
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
||||
self.rope_embedding = EasyAnimateRotaryPosEmbed(patch_size, attention_head_dim)
|
||||
|
||||
# 2. Patch embedding
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
|
||||
)
|
||||
|
||||
# 3. Text refined embedding
|
||||
self.text_proj = None
|
||||
self.text_proj_t5 = None
|
||||
if not add_norm_text_encoder:
|
||||
self.text_proj = nn.Linear(text_embed_dim, inner_dim)
|
||||
if text_embed_dim_t5 is not None:
|
||||
self.text_proj_t5 = nn.Linear(text_embed_dim_t5, inner_dim)
|
||||
else:
|
||||
self.text_proj = nn.Sequential(
|
||||
RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim, inner_dim)
|
||||
)
|
||||
if text_embed_dim_t5 is not None:
|
||||
self.text_proj_t5 = nn.Sequential(
|
||||
RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim_t5, inner_dim)
|
||||
)
|
||||
|
||||
# 4. Transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
EasyAnimateTransformerBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
after_norm=after_norm,
|
||||
is_mmdit_block=True if _ < mmdit_layers else False,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
# 5. Output norm & projection
|
||||
self.norm_out = AdaLayerNorm(
|
||||
embedding_dim=time_embed_dim,
|
||||
output_dim=2 * inner_dim,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
chunk_dim=1,
|
||||
)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_t5: Optional[torch.Tensor] = None,
|
||||
inpaint_latents: Optional[torch.Tensor] = None,
|
||||
control_latents: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
batch_size, channels, video_length, height, width = hidden_states.size()
|
||||
p = self.config.patch_size
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
|
||||
# 1. Time embedding
|
||||
temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
|
||||
temb = self.time_embedding(temb, timestep_cond)
|
||||
image_rotary_emb = self.rope_embedding(hidden_states)
|
||||
|
||||
# 2. Patch embedding
|
||||
if inpaint_latents is not None:
|
||||
hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
|
||||
if control_latents is not None:
|
||||
hidden_states = torch.concat([hidden_states, control_latents], 1)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, F, H, W] -> [BF, C, H, W]
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
|
||||
0, 2, 1, 3, 4
|
||||
) # [BF, C, H, W] -> [B, F, C, H, W]
|
||||
hidden_states = hidden_states.flatten(2, 4).transpose(1, 2) # [B, F, C, H, W] -> [B, FHW, C]
|
||||
|
||||
# 3. Text embedding
|
||||
encoder_hidden_states = self.text_proj(encoder_hidden_states)
|
||||
if encoder_hidden_states_t5 is not None:
|
||||
encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous()
|
||||
|
||||
# 4. Transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, image_rotary_emb
|
||||
)
|
||||
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
|
||||
# 5. Output norm & projection
|
||||
hidden_states = self.norm_out(hidden_states, temb=temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 6. Unpatchify
|
||||
p = self.config.patch_size
|
||||
output = hidden_states.reshape(batch_size, video_length, post_patch_height, post_patch_width, channels, p, p)
|
||||
output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -0,0 +1,459 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanAttnProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
encoder_hidden_states_img = None
|
||||
if attn.add_k_proj is not None:
|
||||
encoder_hidden_states_img = encoder_hidden_states[:, :257]
|
||||
encoder_hidden_states = encoder_hidden_states[:, 257:]
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
if rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
|
||||
x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
|
||||
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
|
||||
return x_out.type_as(hidden_states)
|
||||
|
||||
query = apply_rotary_emb(query, rotary_emb)
|
||||
key = apply_rotary_emb(key, rotary_emb)
|
||||
|
||||
# I2V task
|
||||
hidden_states_img = None
|
||||
if encoder_hidden_states_img is not None:
|
||||
key_img = attn.add_k_proj(encoder_hidden_states_img)
|
||||
key_img = attn.norm_added_k(key_img)
|
||||
value_img = attn.add_v_proj(encoder_hidden_states_img)
|
||||
|
||||
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
hidden_states_img = F.scaled_dot_product_attention(
|
||||
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
if hidden_states_img is not None:
|
||||
hidden_states = hidden_states + hidden_states_img
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanImageEmbedding(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = FP32LayerNorm(in_features)
|
||||
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
|
||||
self.norm2 = FP32LayerNorm(out_features)
|
||||
|
||||
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.norm1(encoder_hidden_states_image)
|
||||
hidden_states = self.ff(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanTimeTextImageEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
time_freq_dim: int,
|
||||
time_proj_dim: int,
|
||||
text_embed_dim: int,
|
||||
image_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
|
||||
self.act_fn = nn.SiLU()
|
||||
self.time_proj = nn.Linear(dim, time_proj_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
|
||||
|
||||
self.image_embedder = None
|
||||
if image_embed_dim is not None:
|
||||
self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
||||
):
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
|
||||
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
|
||||
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
|
||||
timestep = timestep.to(time_embedder_dtype)
|
||||
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
|
||||
timestep_proj = self.time_proj(self.act_fn(temb))
|
||||
|
||||
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
|
||||
if encoder_hidden_states_image is not None:
|
||||
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
|
||||
|
||||
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
|
||||
|
||||
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
|
||||
freqs = []
|
||||
for dim in [t_dim, h_dim, w_dim]:
|
||||
freq = get_1d_rotary_pos_embed(
|
||||
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
|
||||
)
|
||||
freqs.append(freq)
|
||||
self.freqs = torch.cat(freqs, dim=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
self.freqs = self.freqs.to(hidden_states.device)
|
||||
freqs = self.freqs.split_with_sizes(
|
||||
[
|
||||
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
|
||||
self.attention_head_dim // 6,
|
||||
self.attention_head_dim // 6,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
|
||||
return freqs
|
||||
|
||||
|
||||
class WanTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ffn_dim: int,
|
||||
num_heads: int,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
cross_attn_norm: bool = False,
|
||||
eps: float = 1e-6,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Self-attention
|
||||
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 2. Cross-attention
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
added_proj_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
|
||||
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 1. Self-attention
|
||||
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
|
||||
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
|
||||
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
|
||||
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
|
||||
hidden_states
|
||||
)
|
||||
ff_output = self.ffn(norm_hidden_states)
|
||||
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in the Wan model.
|
||||
|
||||
Args:
|
||||
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
|
||||
num_attention_heads (`int`, defaults to `40`):
|
||||
Fixed length for text embeddings.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
text_dim (`int`, defaults to `512`):
|
||||
Input dimension for text embeddings.
|
||||
freq_dim (`int`, defaults to `256`):
|
||||
Dimension for sinusoidal time embeddings.
|
||||
ffn_dim (`int`, defaults to `13824`):
|
||||
Intermediate dimension in feed-forward network.
|
||||
num_layers (`int`, defaults to `40`):
|
||||
The number of layers of transformer blocks to use.
|
||||
window_size (`Tuple[int]`, defaults to `(-1, -1)`):
|
||||
Window size for local attention (-1 indicates global attention).
|
||||
cross_attn_norm (`bool`, defaults to `True`):
|
||||
Enable cross-attention normalization.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Enable query/key normalization.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
add_img_emb (`bool`, defaults to `False`):
|
||||
Whether to use img_emb.
|
||||
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
|
||||
_no_split_modules = ["WanTransformerBlock"]
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
text_dim: int = 4096,
|
||||
freq_dim: int = 256,
|
||||
ffn_dim: int = 13824,
|
||||
num_layers: int = 40,
|
||||
cross_attn_norm: bool = True,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
eps: float = 1e-6,
|
||||
image_dim: Optional[int] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
rope_max_seq_len: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
# 1. Patch & position embedding
|
||||
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
|
||||
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
# 2. Condition embeddings
|
||||
# image_embedding_dim=1280 for I2V model
|
||||
self.condition_embedder = WanTimeTextImageEmbedding(
|
||||
dim=inner_dim,
|
||||
time_freq_dim=freq_dim,
|
||||
time_proj_dim=inner_dim * 6,
|
||||
text_embed_dim=text_dim,
|
||||
image_embed_dim=image_dim,
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
WanTransformerBlock(
|
||||
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
|
||||
rotary_emb = self.rope(hidden_states)
|
||||
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
|
||||
timestep, encoder_hidden_states, encoder_hidden_states_image
|
||||
)
|
||||
timestep_proj = timestep_proj.unflatten(1, (6, -1))
|
||||
|
||||
if encoder_hidden_states_image is not None:
|
||||
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
|
||||
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
else:
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
|
||||
|
||||
# 5. Output norm, projection & unpatchify
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -240,7 +240,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
dropout=dropout,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
|
||||
@@ -216,6 +216,11 @@ else:
|
||||
"IFPipeline",
|
||||
"IFSuperResolutionPipeline",
|
||||
]
|
||||
_import_structure["easyanimate"] = [
|
||||
"EasyAnimatePipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
]
|
||||
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
|
||||
_import_structure["hunyuan_video"] = ["HunyuanVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline"]
|
||||
_import_structure["kandinsky"] = [
|
||||
@@ -347,6 +352,7 @@ else:
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
]
|
||||
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline"]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -545,6 +551,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
from .easyanimate import (
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
)
|
||||
from .flux import (
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
@@ -690,6 +701,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline
|
||||
from .wuerstchen import (
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
@@ -83,6 +83,7 @@ class AnimateDiffPipeline(
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
@@ -125,6 +125,7 @@ class AnimateDiffControlNetPipeline(
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation with ControlNet guidance.
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
@@ -136,6 +136,7 @@ class AnimateDiffSparseControlNetPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
@@ -186,6 +186,7 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for video-to-video generation.
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
@@ -204,6 +204,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for video-to-video generation with ControlNet guidance.
|
||||
|
||||
@@ -143,13 +143,11 @@ class CogView4Pipeline(DiffusionPipeline):
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. CogView4 uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
text_encoder ([`GLMModel`]):
|
||||
Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
|
||||
tokenizer (`PreTrainedTokenizer`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
[PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
|
||||
transformer ([`CogView4Transformer2DModel`]):
|
||||
A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
@@ -215,7 +213,7 @@ class CogView4Pipeline(DiffusionPipeline):
|
||||
)
|
||||
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True
|
||||
text_input_ids.to(self.text_encoder.device), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
@@ -362,10 +360,16 @@ class CogView4Pipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
|
||||
@@ -757,15 +757,9 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
for images_ in image:
|
||||
for image_ in images_:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Check `controlnet_conditioning_scale`
|
||||
# TODO Update for https://github.com/huggingface/diffusers/pull/10723
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if isinstance(controlnet_conditioning_scale, list):
|
||||
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
|
||||
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
|
||||
@@ -776,8 +770,6 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
|
||||
" the same length as the number of controlnets"
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
|
||||
if len(control_guidance_start) != len(control_guidance_end):
|
||||
raise ValueError(
|
||||
@@ -808,8 +800,6 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
|
||||
if max(_control_mode) >= _controlnet.config.num_control_type:
|
||||
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Equal number of `image` and `control_mode` elements
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
@@ -823,8 +813,6 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
|
||||
elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
|
||||
raise ValueError("Expected len(control_image) == len(control_mode)")
|
||||
else:
|
||||
assert False
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -1201,18 +1189,6 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else 1
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
if not isinstance(control_image, list):
|
||||
control_image = [control_image]
|
||||
else:
|
||||
@@ -1221,8 +1197,25 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode]
|
||||
|
||||
if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_image = [[item] for item in control_image]
|
||||
control_mode = [[item] for item in control_mode]
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
if isinstance(controlnet_conditioning_scale, float):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
@@ -1357,9 +1350,6 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
control_image = control_images
|
||||
height, width = control_image[0][0].shape[-2:]
|
||||
|
||||
else:
|
||||
assert False
|
||||
|
||||
# 5. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
@@ -1397,7 +1387,7 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetUnionModel) else keeps)
|
||||
controlnet_keep.append(keeps)
|
||||
|
||||
# 7.2 Prepare added time ids & embeddings
|
||||
original_size = original_size or (height, width)
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_easyanimate"] = ["EasyAnimatePipeline"]
|
||||
_import_structure["pipeline_easyanimate_control"] = ["EasyAnimateControlPipeline"]
|
||||
_import_structure["pipeline_easyanimate_inpaint"] = ["EasyAnimateInpaintPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_easyanimate import EasyAnimatePipeline
|
||||
from .pipeline_easyanimate_control import EasyAnimateControlPipeline
|
||||
from .pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
+770
@@ -0,0 +1,770 @@
|
||||
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BertModel,
|
||||
BertTokenizer,
|
||||
Qwen2Tokenizer,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
)
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from .pipeline_output import EasyAnimatePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import EasyAnimatePipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh"
|
||||
>>> pipe = EasyAnimatePipeline.from_pretrained(
|
||||
... "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
>>> prompt = (
|
||||
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
||||
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
||||
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
||||
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
||||
... "atmosphere of this unique musical performance."
|
||||
... )
|
||||
>>> sample_size = (512, 512)
|
||||
>>> video = pipe(
|
||||
... prompt=prompt,
|
||||
... guidance_scale=6,
|
||||
... negative_prompt="bad detailed",
|
||||
... height=sample_size[0],
|
||||
... width=sample_size[1],
|
||||
... num_inference_steps=50,
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
The predicted noise tensor for the guided diffusion process.
|
||||
noise_pred_text (`torch.Tensor`):
|
||||
The predicted noise tensor for the text-guided diffusion process.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
A rescale factor applied to the noise predictions.
|
||||
|
||||
Returns:
|
||||
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class EasyAnimatePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using EasyAnimate.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKLMagvit`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
||||
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
|
||||
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
||||
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
|
||||
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
|
||||
transformer ([`EasyAnimateTransformer3DModel`]):
|
||||
The EasyAnimate model designed by EasyAnimate Team.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLMagvit,
|
||||
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
|
||||
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
|
||||
transformer: EasyAnimateTransformer3DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.enable_text_attention_mask = (
|
||||
self.transformer.config.enable_text_attention_mask
|
||||
if getattr(self, "transformer", None) is not None
|
||||
else True
|
||||
)
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
|
||||
)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
dtype (`torch.dtype`):
|
||||
torch dtype
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
|
||||
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
|
||||
"""
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
device = device or self.text_encoder.device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
if isinstance(prompt, str):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": prompt}],
|
||||
}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": _prompt}],
|
||||
}
|
||||
for _prompt in prompt
|
||||
]
|
||||
text = [
|
||||
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
|
||||
]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text=text,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
padding_side="right",
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs.to(self.text_encoder.device)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
if self.enable_text_attention_mask:
|
||||
# Inference: Generation of the output
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
else:
|
||||
raise ValueError("LLM needs attention_mask")
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
if negative_prompt is not None and isinstance(negative_prompt, str):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": negative_prompt}],
|
||||
}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": _negative_prompt}],
|
||||
}
|
||||
for _negative_prompt in negative_prompt
|
||||
]
|
||||
text = [
|
||||
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
|
||||
]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text=text,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
padding_side="right",
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs.to(self.text_encoder.device)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
negative_prompt_attention_mask = text_inputs.attention_mask
|
||||
if self.enable_text_attention_mask:
|
||||
# Inference: Generation of the output
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=negative_prompt_attention_mask,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
else:
|
||||
raise ValueError("LLM needs attention_mask")
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
(num_frames - 1) // self.vae_temporal_compression_ratio + 1,
|
||||
height // self.vae_spatial_compression_ratio,
|
||||
width // self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
if hasattr(self.scheduler, "init_noise_sigma"):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_frames: Optional[int] = 49,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
guidance_rescale: float = 0.0,
|
||||
):
|
||||
r"""
|
||||
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
|
||||
|
||||
Examples:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
|
||||
num_frames (`int`, *optional*):
|
||||
Length of the generated video (in frames).
|
||||
height (`int`, *optional*):
|
||||
Height of the generated image in pixels.
|
||||
width (`int`, *optional*):
|
||||
Width of the generated image in pixels.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
Number of denoising steps during generation. More steps generally yield higher quality images but slow
|
||||
down inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Encourages the model to align outputs with prompts. A higher value may decrease image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of images to generate for each prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A generator to ensure reproducibility in image generation.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Predefined latent tensors to condition generation.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Embeddings for negative prompts. Overrides string inputs if defined.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the primary prompt embeddings.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for negative prompt embeddings.
|
||||
output_type (`str`, *optional*, defaults to "latent"):
|
||||
Format of the generated output, either as a PIL image or as a NumPy array.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
If `True`, returns a structured output. Otherwise returns a simple tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
Functions called at the end of each denoising step.
|
||||
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
||||
Tensor names to be included in callback function calls.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Adjusts noise levels based on guidance scale.
|
||||
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
|
||||
Original dimensions of the output.
|
||||
target_size (`Tuple[int, int]`, *optional*):
|
||||
Desired output dimensions for calculations.
|
||||
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
|
||||
Coordinates for cropping.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 0. default height and width
|
||||
height = int((height // 16) * 16)
|
||||
width = int((width // 16) * 16)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = self.transformer.dtype
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, mu=1
|
||||
)
|
||||
else:
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
if hasattr(self.scheduler, "scale_model_input"):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
||||
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
||||
dtype=latent_model_input.dtype
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
t_expand,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
||||
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return EasyAnimatePipelineOutput(frames=video)
|
||||
@@ -0,0 +1,994 @@
|
||||
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
BertModel,
|
||||
BertTokenizer,
|
||||
Qwen2Tokenizer,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
)
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from .pipeline_output import EasyAnimatePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import EasyAnimateControlPipeline
|
||||
>>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_control import get_video_to_video_latent
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> pipe = EasyAnimateControlPipeline.from_pretrained(
|
||||
... "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> control_video = load_video(
|
||||
... "https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control/blob/main/asset/pose.mp4"
|
||||
... )
|
||||
>>> prompt = (
|
||||
... "In this sunlit outdoor garden, a beautiful woman is dressed in a knee-length, sleeveless white dress. "
|
||||
... "The hem of her dress gently sways with her graceful dance, much like a butterfly fluttering in the breeze. "
|
||||
... "Sunlight filters through the leaves, casting dappled shadows that highlight her soft features and clear eyes, "
|
||||
... "making her appear exceptionally elegant. It seems as if every movement she makes speaks of youth and vitality. "
|
||||
... "As she twirls on the grass, her dress flutters, as if the entire garden is rejoicing in her dance. "
|
||||
... "The colorful flowers around her sway in the gentle breeze, with roses, chrysanthemums, and lilies each "
|
||||
... "releasing their fragrances, creating a relaxed and joyful atmosphere."
|
||||
... )
|
||||
>>> sample_size = (672, 384)
|
||||
>>> num_frames = 49
|
||||
|
||||
>>> input_video, _, _ = get_video_to_video_latent(control_video, num_frames, sample_size)
|
||||
>>> video = pipe(
|
||||
... prompt,
|
||||
... num_frames=num_frames,
|
||||
... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
|
||||
... height=sample_size[0],
|
||||
... width=sample_size[1],
|
||||
... control_video=input_video,
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def preprocess_image(image, sample_size):
|
||||
"""
|
||||
Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor.
|
||||
"""
|
||||
if isinstance(image, torch.Tensor):
|
||||
# If input is a tensor, assume it's in CHW format and resize using interpolation
|
||||
image = torch.nn.functional.interpolate(
|
||||
image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False
|
||||
).squeeze(0)
|
||||
elif isinstance(image, Image.Image):
|
||||
# If input is a PIL image, resize and convert to numpy array
|
||||
image = image.resize((sample_size[1], sample_size[0]))
|
||||
image = np.array(image)
|
||||
elif isinstance(image, np.ndarray):
|
||||
# If input is a numpy array, resize using PIL
|
||||
image = Image.fromarray(image).resize((sample_size[1], sample_size[0]))
|
||||
image = np.array(image)
|
||||
else:
|
||||
raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.")
|
||||
|
||||
# Convert to tensor if not already
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1]
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def get_video_to_video_latent(input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None):
|
||||
if input_video is not None:
|
||||
# Convert each frame in the list to tensor
|
||||
input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video]
|
||||
|
||||
# Stack all frames into a single tensor (F, C, H, W)
|
||||
input_video = torch.stack(input_video)[:num_frames]
|
||||
|
||||
# Add batch dimension (B, F, C, H, W)
|
||||
input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0)
|
||||
|
||||
if validation_video_mask is not None:
|
||||
# Handle mask input
|
||||
validation_video_mask = preprocess_image(validation_video_mask, size=sample_size)
|
||||
input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255)
|
||||
|
||||
# Adjust mask dimensions to match video
|
||||
input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
|
||||
input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
|
||||
input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
|
||||
else:
|
||||
input_video_mask = torch.zeros_like(input_video[:, :1])
|
||||
input_video_mask[:, :, :] = 255
|
||||
else:
|
||||
input_video, input_video_mask = None, None
|
||||
|
||||
if ref_image is not None:
|
||||
# Convert reference image to tensor
|
||||
ref_image = preprocess_image(ref_image, size=sample_size)
|
||||
ref_image = ref_image.permute(1, 0, 2, 3).unsqueeze(0) # Add batch dimension (B, C, H, W)
|
||||
else:
|
||||
ref_image = None
|
||||
|
||||
return input_video, input_video_mask, ref_image
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
The predicted noise tensor for the guided diffusion process.
|
||||
noise_pred_text (`torch.Tensor`):
|
||||
The predicted noise tensor for the text-guided diffusion process.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
A rescale factor applied to the noise predictions.
|
||||
|
||||
Returns:
|
||||
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Resize mask information in magvit
|
||||
def resize_mask(mask, latent, process_first_frame_only=True):
|
||||
latent_size = latent.size()
|
||||
|
||||
if process_first_frame_only:
|
||||
target_size = list(latent_size[2:])
|
||||
target_size[0] = 1
|
||||
first_frame_resized = F.interpolate(
|
||||
mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False
|
||||
)
|
||||
|
||||
target_size = list(latent_size[2:])
|
||||
target_size[0] = target_size[0] - 1
|
||||
if target_size[0] != 0:
|
||||
remaining_frames_resized = F.interpolate(
|
||||
mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False
|
||||
)
|
||||
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
||||
else:
|
||||
resized_mask = first_frame_resized
|
||||
else:
|
||||
target_size = list(latent_size[2:])
|
||||
resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False)
|
||||
return resized_mask
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class EasyAnimateControlPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using EasyAnimate.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKLMagvit`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
||||
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
|
||||
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
||||
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
|
||||
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
|
||||
transformer ([`EasyAnimateTransformer3DModel`]):
|
||||
The EasyAnimate model designed by EasyAnimate Team.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLMagvit,
|
||||
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
|
||||
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
|
||||
transformer: EasyAnimateTransformer3DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.enable_text_attention_mask = (
|
||||
self.transformer.config.enable_text_attention_mask
|
||||
if getattr(self, "transformer", None) is not None
|
||||
else True
|
||||
)
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_spatial_compression_ratio,
|
||||
do_normalize=False,
|
||||
do_binarize=True,
|
||||
do_convert_grayscale=True,
|
||||
)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
|
||||
# Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
dtype (`torch.dtype`):
|
||||
torch dtype
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
|
||||
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
|
||||
"""
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
device = device or self.text_encoder.device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
if isinstance(prompt, str):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": prompt}],
|
||||
}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": _prompt}],
|
||||
}
|
||||
for _prompt in prompt
|
||||
]
|
||||
text = [
|
||||
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
|
||||
]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text=text,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
padding_side="right",
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs.to(self.text_encoder.device)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
if self.enable_text_attention_mask:
|
||||
# Inference: Generation of the output
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
else:
|
||||
raise ValueError("LLM needs attention_mask")
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
if negative_prompt is not None and isinstance(negative_prompt, str):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": negative_prompt}],
|
||||
}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": _negative_prompt}],
|
||||
}
|
||||
for _negative_prompt in negative_prompt
|
||||
]
|
||||
text = [
|
||||
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
|
||||
]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text=text,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
padding_side="right",
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs.to(self.text_encoder.device)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
negative_prompt_attention_mask = text_inputs.attention_mask
|
||||
if self.enable_text_attention_mask:
|
||||
# Inference: Generation of the output
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=negative_prompt_attention_mask,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
else:
|
||||
raise ValueError("LLM needs attention_mask")
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
(num_frames - 1) // self.vae_temporal_compression_ratio + 1,
|
||||
height // self.vae_spatial_compression_ratio,
|
||||
width // self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
if hasattr(self.scheduler, "init_noise_sigma"):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_control_latents(
|
||||
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
||||
):
|
||||
# resize the control to latents shape as we concatenate the control to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
|
||||
if control is not None:
|
||||
control = control.to(device=device, dtype=dtype)
|
||||
bs = 1
|
||||
new_control = []
|
||||
for i in range(0, control.shape[0], bs):
|
||||
control_bs = control[i : i + bs]
|
||||
control_bs = self.vae.encode(control_bs)[0]
|
||||
control_bs = control_bs.mode()
|
||||
new_control.append(control_bs)
|
||||
control = torch.cat(new_control, dim=0)
|
||||
control = control * self.vae.config.scaling_factor
|
||||
|
||||
if control_image is not None:
|
||||
control_image = control_image.to(device=device, dtype=dtype)
|
||||
bs = 1
|
||||
new_control_pixel_values = []
|
||||
for i in range(0, control_image.shape[0], bs):
|
||||
control_pixel_values_bs = control_image[i : i + bs]
|
||||
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
|
||||
control_pixel_values_bs = control_pixel_values_bs.mode()
|
||||
new_control_pixel_values.append(control_pixel_values_bs)
|
||||
control_image_latents = torch.cat(new_control_pixel_values, dim=0)
|
||||
control_image_latents = control_image_latents * self.vae.config.scaling_factor
|
||||
else:
|
||||
control_image_latents = None
|
||||
|
||||
return control, control_image_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_frames: Optional[int] = 49,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
control_video: Union[torch.FloatTensor] = None,
|
||||
control_camera_video: Union[torch.FloatTensor] = None,
|
||||
ref_image: Union[torch.FloatTensor] = None,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
guidance_rescale: float = 0.0,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
):
|
||||
r"""
|
||||
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
|
||||
|
||||
Examples:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
|
||||
num_frames (`int`, *optional*):
|
||||
Length of the generated video (in frames).
|
||||
height (`int`, *optional*):
|
||||
Height of the generated image in pixels.
|
||||
width (`int`, *optional*):
|
||||
Width of the generated image in pixels.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
Number of denoising steps during generation. More steps generally yield higher quality images but slow
|
||||
down inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Encourages the model to align outputs with prompts. A higher value may decrease image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of images to generate for each prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A generator to ensure reproducibility in image generation.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Predefined latent tensors to condition generation.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Embeddings for negative prompts. Overrides string inputs if defined.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the primary prompt embeddings.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for negative prompt embeddings.
|
||||
output_type (`str`, *optional*, defaults to "latent"):
|
||||
Format of the generated output, either as a PIL image or as a NumPy array.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
If `True`, returns a structured output. Otherwise returns a simple tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
Functions called at the end of each denoising step.
|
||||
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
||||
Tensor names to be included in callback function calls.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Adjusts noise levels based on guidance scale.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 0. default height and width
|
||||
height = int((height // 16) * 16)
|
||||
width = int((width // 16) * 16)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = self.transformer.dtype
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
text_encoder_index=0,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, mu=1
|
||||
)
|
||||
else:
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if control_camera_video is not None:
|
||||
control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True)
|
||||
control_video_latents = control_video_latents * 6
|
||||
control_latents = (
|
||||
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
|
||||
).to(device, dtype)
|
||||
elif control_video is not None:
|
||||
batch_size, channels, num_frames, height_video, width_video = control_video.shape
|
||||
control_video = self.image_processor.preprocess(
|
||||
control_video.permute(0, 2, 1, 3, 4).reshape(
|
||||
batch_size * num_frames, channels, height_video, width_video
|
||||
),
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
control_video = control_video.to(dtype=torch.float32)
|
||||
control_video = control_video.reshape(batch_size, num_frames, channels, height, width).permute(
|
||||
0, 2, 1, 3, 4
|
||||
)
|
||||
control_video_latents = self.prepare_control_latents(
|
||||
None,
|
||||
control_video,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
self.do_classifier_free_guidance,
|
||||
)[1]
|
||||
control_latents = (
|
||||
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
|
||||
).to(device, dtype)
|
||||
else:
|
||||
control_video_latents = torch.zeros_like(latents).to(device, dtype)
|
||||
control_latents = (
|
||||
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
|
||||
).to(device, dtype)
|
||||
|
||||
if ref_image is not None:
|
||||
batch_size, channels, num_frames, height_video, width_video = ref_image.shape
|
||||
ref_image = self.image_processor.preprocess(
|
||||
ref_image.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video),
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
ref_image = ref_image.to(dtype=torch.float32)
|
||||
ref_image = ref_image.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
|
||||
|
||||
ref_image_latents = self.prepare_control_latents(
|
||||
None,
|
||||
ref_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
self.do_classifier_free_guidance,
|
||||
)[1]
|
||||
|
||||
ref_image_latents_conv_in = torch.zeros_like(latents)
|
||||
if latents.size()[2] != 1:
|
||||
ref_image_latents_conv_in[:, :, :1] = ref_image_latents
|
||||
ref_image_latents_conv_in = (
|
||||
torch.cat([ref_image_latents_conv_in] * 2)
|
||||
if self.do_classifier_free_guidance
|
||||
else ref_image_latents_conv_in
|
||||
).to(device, dtype)
|
||||
control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
|
||||
else:
|
||||
ref_image_latents_conv_in = torch.zeros_like(latents)
|
||||
ref_image_latents_conv_in = (
|
||||
torch.cat([ref_image_latents_conv_in] * 2)
|
||||
if self.do_classifier_free_guidance
|
||||
else ref_image_latents_conv_in
|
||||
).to(device, dtype)
|
||||
control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
||||
|
||||
# To latents.device
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
if hasattr(self.scheduler, "scale_model_input"):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
||||
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
||||
dtype=latent_model_input.dtype
|
||||
)
|
||||
# predict the noise residual
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
t_expand,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
control_latents=control_latents,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
||||
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
# Convert to tensor
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return EasyAnimatePipelineOutput(frames=video)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class EasyAnimatePipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for EasyAnimate pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
@@ -438,7 +438,6 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
|
||||
@@ -477,7 +477,6 @@ class FluxControlInpaintPipeline(
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
|
||||
@@ -18,14 +18,16 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
@@ -171,7 +173,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin):
|
||||
r"""
|
||||
The Flux pipeline for text-to-image generation.
|
||||
|
||||
@@ -198,8 +200,8 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
||||
_optional_components = ["image_encoder", "feature_extractor"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
@@ -214,6 +216,8 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
controlnet: Union[
|
||||
FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
|
||||
],
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
@@ -228,6 +232,8 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
controlnet=controlnet,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
@@ -413,14 +419,67 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
||||
):
|
||||
image_embeds = []
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_ip_adapter_image in ip_adapter_image:
|
||||
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
||||
image_embeds.append(single_image_embeds[None, :])
|
||||
else:
|
||||
if not isinstance(ip_adapter_image_embeds, list):
|
||||
ip_adapter_image_embeds = [ip_adapter_image_embeds]
|
||||
|
||||
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
ip_adapter_image_embeds = []
|
||||
for single_image_embeds in image_embeds:
|
||||
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_image_embeds = single_image_embeds.to(device=device)
|
||||
ip_adapter_image_embeds.append(single_image_embeds)
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
@@ -455,10 +514,33 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
@@ -597,6 +679,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
true_cfg_scale: float = 1.0,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
@@ -612,6 +697,12 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -679,6 +770,17 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
negative_ip_adapter_image:
|
||||
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -727,8 +829,12 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
@@ -752,6 +858,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
|
||||
(
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
@@ -766,6 +873,21 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
if do_true_cfg:
|
||||
(
|
||||
negative_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
_,
|
||||
) = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_2=negative_prompt_2,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 3. Prepare control image
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
@@ -899,12 +1021,43 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
|
||||
|
||||
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
||||
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
||||
):
|
||||
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
||||
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
||||
):
|
||||
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
|
||||
if self.joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {}
|
||||
|
||||
image_embeds = None
|
||||
negative_image_embeds = None
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
)
|
||||
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
||||
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
negative_ip_adapter_image,
|
||||
negative_ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
@@ -960,6 +1113,25 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
|
||||
controlnet_blocks_repeat=controlnet_blocks_repeat,
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
if negative_image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
controlnet_blocks_repeat=controlnet_blocks_repeat,
|
||||
)[0]
|
||||
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -17,10 +17,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -159,7 +166,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin):
|
||||
r"""
|
||||
The Flux pipeline for image inpainting.
|
||||
|
||||
@@ -186,8 +193,8 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
||||
_optional_components = ["image_encoder", "feature_extractor"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
@@ -199,6 +206,8 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
text_encoder_2: T5EncoderModel,
|
||||
tokenizer_2: T5TokenizerFast,
|
||||
transformer: FluxTransformer2DModel,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -210,6 +219,8 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
@@ -395,6 +406,55 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
||||
):
|
||||
image_embeds = []
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_ip_adapter_image in ip_adapter_image:
|
||||
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
||||
image_embeds.append(single_image_embeds[None, :])
|
||||
else:
|
||||
if not isinstance(ip_adapter_image_embeds, list):
|
||||
ip_adapter_image_embeds = [ip_adapter_image_embeds]
|
||||
|
||||
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
ip_adapter_image_embeds = []
|
||||
for single_image_embeds in image_embeds:
|
||||
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_image_embeds = single_image_embeds.to(device=device)
|
||||
ip_adapter_image_embeds.append(single_image_embeds)
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
@@ -429,8 +489,12 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
strength,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
@@ -468,10 +532,33 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
@@ -586,6 +673,9 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
true_cfg_scale: float = 1.0,
|
||||
image: PipelineImageInput = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
@@ -598,6 +688,12 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -659,6 +755,17 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
negative_ip_adapter_image:
|
||||
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -697,8 +804,12 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
strength,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
@@ -724,6 +835,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
|
||||
(
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
@@ -738,6 +850,21 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
if do_true_cfg:
|
||||
(
|
||||
negative_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
_,
|
||||
) = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_2=negative_prompt_2,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 4.Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
@@ -791,12 +918,43 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
||||
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
||||
):
|
||||
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
||||
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
||||
):
|
||||
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
|
||||
if self.joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {}
|
||||
|
||||
image_embeds = None
|
||||
negative_image_embeds = None
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
)
|
||||
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
||||
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
negative_ip_adapter_image,
|
||||
negative_ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
noise_pred = self.transformer(
|
||||
@@ -811,6 +969,22 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
if negative_image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -18,10 +18,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -156,7 +163,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterMixin):
|
||||
r"""
|
||||
The Flux pipeline for image inpainting.
|
||||
|
||||
@@ -183,8 +190,8 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
||||
_optional_components = ["image_encoder", "feature_extractor"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
@@ -196,6 +203,8 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
text_encoder_2: T5EncoderModel,
|
||||
tokenizer_2: T5TokenizerFast,
|
||||
transformer: FluxTransformer2DModel,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -207,6 +216,8 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
@@ -400,6 +411,55 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
return image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
||||
):
|
||||
image_embeds = []
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_ip_adapter_image in ip_adapter_image:
|
||||
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
||||
image_embeds.append(single_image_embeds[None, :])
|
||||
else:
|
||||
if not isinstance(ip_adapter_image_embeds, list):
|
||||
ip_adapter_image_embeds = [ip_adapter_image_embeds]
|
||||
|
||||
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
ip_adapter_image_embeds = []
|
||||
for single_image_embeds in image_embeds:
|
||||
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_image_embeds = single_image_embeds.to(device=device)
|
||||
ip_adapter_image_embeds.append(single_image_embeds)
|
||||
|
||||
return ip_adapter_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
@@ -437,8 +497,12 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
height,
|
||||
width,
|
||||
output_type,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
padding_mask_crop=None,
|
||||
max_sequence_length=None,
|
||||
@@ -477,10 +541,33 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
@@ -684,6 +771,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
true_cfg_scale: float = 1.0,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
masked_image_latents: PipelineImageInput = None,
|
||||
@@ -699,6 +789,12 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -777,6 +873,17 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
negative_ip_adapter_image:
|
||||
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -818,8 +925,12 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
height,
|
||||
width,
|
||||
output_type=output_type,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
padding_mask_crop=padding_mask_crop,
|
||||
max_sequence_length=max_sequence_length,
|
||||
@@ -856,6 +967,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
|
||||
(
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
@@ -870,6 +982,21 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
if do_true_cfg:
|
||||
(
|
||||
negative_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
_,
|
||||
) = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_2=negative_prompt_2,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 4.Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
@@ -946,12 +1073,43 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
||||
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
||||
):
|
||||
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
||||
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
||||
):
|
||||
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
|
||||
if self.joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {}
|
||||
|
||||
image_embeds = None
|
||||
negative_image_embeds = None
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
)
|
||||
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
||||
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
negative_ip_adapter_image,
|
||||
negative_ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
noise_pred = self.transformer(
|
||||
@@ -966,6 +1124,22 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
if negative_image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=negative_pooled_prompt_embeds,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_wan"] = ["WanPipeline"]
|
||||
_import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_wan import WanPipeline
|
||||
from .pipeline_wan_i2v import WanImageToVideoPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class WanPipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for Wan pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
@@ -0,0 +1,574 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...models import AutoencoderKLWan, WanTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import WanPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import AutoencoderKLWan, WanPipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
|
||||
>>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A cat walks on the grass, realistic"
|
||||
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... height=480,
|
||||
... width=832,
|
||||
... num_frames=81,
|
||||
... guidance_scale=5.0,
|
||||
... ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=15)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using Wan.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_latent_frames,
|
||||
int(height) // self.vae_scale_factor_spatial,
|
||||
int(width) // self.vae_scale_factor_spatial,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `480`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `832`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `81`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
||||
The dtype to use for the torch.amp.autocast.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~WanPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return WanPipelineOutput(frames=video)
|
||||
@@ -0,0 +1,650 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import ftfy
|
||||
import PIL
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...models import AutoencoderKLWan, WanTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import WanPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
|
||||
>>> from diffusers.utils import export_to_video, load_image
|
||||
|
||||
>>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-1.3B-720P-Diffusers
|
||||
>>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> height, width = 480, 832
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
... ).resize((width, height))
|
||||
>>> prompt = (
|
||||
... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
||||
... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
||||
... )
|
||||
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
>>> output = pipe(
|
||||
... image=image, prompt=prompt, negative_prompt=negative_prompt, num_frames=81, guidance_scale=5.0
|
||||
... ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=15)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for image-to-video generation using Wan.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
image_encoder ([`CLIPVisionModel`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
|
||||
the
|
||||
[clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
|
||||
variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
image_encoder: CLIPVisionModelWithProjection,
|
||||
image_processor: CLIPImageProcessor,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
image_encoder=image_encoder,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
self.image_processor = image_processor
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_image(self, image: PipelineImageInput):
|
||||
image = self.image_processor(images=image, return_tensors="pt").to(self.device)
|
||||
image_embeds = self.image_encoder(**image, output_hidden_states=True)
|
||||
return image_embeds.hidden_states[-1]
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_height = height // self.vae_scale_factor_spatial
|
||||
latent_width = width // self.vae_scale_factor_spatial
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
image = image.unsqueeze(2)
|
||||
video_condition = torch.cat(
|
||||
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
|
||||
)
|
||||
video_condition = video_condition.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list):
|
||||
latent_condition = [retrieve_latents(self.vae.encode(video_condition), g) for g in generator]
|
||||
latents = latent_condition = torch.cat(latent_condition)
|
||||
else:
|
||||
latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
|
||||
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(latent_condition.device)
|
||||
|
||||
return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PipelineImageInput`):
|
||||
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
height (`int`, defaults to `480`):
|
||||
The height of the generated video.
|
||||
width (`int`, defaults to `832`):
|
||||
The width of the generated video.
|
||||
num_frames (`int`, defaults to `81`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to `512`):
|
||||
The maximum sequence length of the prompt.
|
||||
shift (`float`, *optional*, defaults to `5.0`):
|
||||
The shift of the flow.
|
||||
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
||||
The dtype to use for the torch.amp.autocast.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~WanPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
image,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Encode image embedding
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
image_embeds = self.encode_image(image)
|
||||
image_embeds = image_embeds.repeat(batch_size, 1, 1)
|
||||
image_embeds = image_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.z_dim
|
||||
image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
|
||||
latents, condition = self.prepare_latents(
|
||||
image,
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_image=image_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states_image=image_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return WanPipelineOutput(frames=video)
|
||||
@@ -47,6 +47,16 @@ class QuantizationMethod(str, Enum):
|
||||
TORCHAO = "torchao"
|
||||
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.quantization.quant_primitives import MappingType
|
||||
|
||||
class TorchAoJSONEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if isinstance(obj, MappingType):
|
||||
return obj.name
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizationConfigMixin:
|
||||
"""
|
||||
@@ -673,4 +683,6 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
```
|
||||
"""
|
||||
config_dict = self.to_dict()
|
||||
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
||||
return (
|
||||
f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
|
||||
)
|
||||
|
||||
@@ -116,6 +116,7 @@ from .peft_utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
|
||||
from .remote_utils import remote_decode
|
||||
from .state_dict_utils import (
|
||||
convert_all_state_dict_to_peft,
|
||||
convert_state_dict_to_diffusers,
|
||||
|
||||
@@ -171,6 +171,21 @@ class AutoencoderKLLTXVideo(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLMagvit(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLMochi(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -201,6 +216,21 @@ class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLWan(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderOobleck(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -381,6 +411,21 @@ class DiTTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class EasyAnimateTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FluxControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -966,6 +1011,21 @@ class VQModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class WanTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
def get_constant_schedule(*args, **kwargs):
|
||||
requires_backends(get_constant_schedule, ["torch"])
|
||||
|
||||
|
||||
@@ -407,6 +407,51 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class EasyAnimateControlPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class EasyAnimateInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class EasyAnimatePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class FluxControlImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -2597,6 +2642,36 @@ class VQDiffusionPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanImageToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WuerstchenCombinedPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -0,0 +1,334 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import json
|
||||
from typing import List, Literal, Optional, Union, cast
|
||||
|
||||
import requests
|
||||
|
||||
from .deprecation_utils import deprecate
|
||||
from .import_utils import is_safetensors_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from ..image_processor import VaeImageProcessor
|
||||
from ..video_processor import VideoProcessor
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors.torch
|
||||
|
||||
DTYPE_MAP = {
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"uint8": torch.uint8,
|
||||
}
|
||||
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def detect_image_type(data: bytes) -> str:
|
||||
if data.startswith(b"\xff\xd8"):
|
||||
return "jpeg"
|
||||
elif data.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "png"
|
||||
elif data.startswith(b"GIF87a") or data.startswith(b"GIF89a"):
|
||||
return "gif"
|
||||
elif data.startswith(b"BM"):
|
||||
return "bmp"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def check_inputs(
|
||||
endpoint: str,
|
||||
tensor: "torch.Tensor",
|
||||
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
|
||||
do_scaling: bool = True,
|
||||
scaling_factor: Optional[float] = None,
|
||||
shift_factor: Optional[float] = None,
|
||||
output_type: Literal["mp4", "pil", "pt"] = "pil",
|
||||
return_type: Literal["mp4", "pil", "pt"] = "pil",
|
||||
image_format: Literal["png", "jpg"] = "jpg",
|
||||
partial_postprocess: bool = False,
|
||||
input_tensor_type: Literal["binary"] = "binary",
|
||||
output_tensor_type: Literal["binary"] = "binary",
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
):
|
||||
if tensor.ndim == 3 and height is None and width is None:
|
||||
raise ValueError("`height` and `width` required for packed latents.")
|
||||
if (
|
||||
output_type == "pt"
|
||||
and return_type == "pil"
|
||||
and not partial_postprocess
|
||||
and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
|
||||
):
|
||||
raise ValueError("`processor` is required.")
|
||||
if do_scaling and scaling_factor is None:
|
||||
deprecate(
|
||||
"do_scaling",
|
||||
"1.0.0",
|
||||
"`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
|
||||
standard_warn=False,
|
||||
)
|
||||
|
||||
|
||||
def postprocess(
|
||||
response: requests.Response,
|
||||
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
|
||||
output_type: Literal["mp4", "pil", "pt"] = "pil",
|
||||
return_type: Literal["mp4", "pil", "pt"] = "pil",
|
||||
partial_postprocess: bool = False,
|
||||
):
|
||||
if output_type == "pt" or (output_type == "pil" and processor is not None):
|
||||
output_tensor = response.content
|
||||
parameters = response.headers
|
||||
shape = json.loads(parameters["shape"])
|
||||
dtype = parameters["dtype"]
|
||||
torch_dtype = DTYPE_MAP[dtype]
|
||||
output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
|
||||
if output_type == "pt":
|
||||
if partial_postprocess:
|
||||
if return_type == "pil":
|
||||
output = [Image.fromarray(image.numpy()) for image in output_tensor]
|
||||
if len(output) == 1:
|
||||
output = output[0]
|
||||
elif return_type == "pt":
|
||||
output = output_tensor
|
||||
else:
|
||||
if processor is None or return_type == "pt":
|
||||
output = output_tensor
|
||||
else:
|
||||
if isinstance(processor, VideoProcessor):
|
||||
output = cast(
|
||||
List[Image.Image],
|
||||
processor.postprocess_video(output_tensor, output_type="pil")[0],
|
||||
)
|
||||
else:
|
||||
output = cast(
|
||||
Image.Image,
|
||||
processor.postprocess(output_tensor, output_type="pil")[0],
|
||||
)
|
||||
elif output_type == "pil" and return_type == "pil" and processor is None:
|
||||
output = Image.open(io.BytesIO(response.content)).convert("RGB")
|
||||
detected_format = detect_image_type(response.content)
|
||||
output.format = detected_format
|
||||
elif output_type == "pil" and processor is not None:
|
||||
if return_type == "pil":
|
||||
output = [
|
||||
Image.fromarray(image)
|
||||
for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
|
||||
]
|
||||
elif return_type == "pt":
|
||||
output = output_tensor
|
||||
elif output_type == "mp4" and return_type == "mp4":
|
||||
output = response.content
|
||||
return output
|
||||
|
||||
|
||||
def prepare(
|
||||
tensor: "torch.Tensor",
|
||||
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
|
||||
do_scaling: bool = True,
|
||||
scaling_factor: Optional[float] = None,
|
||||
shift_factor: Optional[float] = None,
|
||||
output_type: Literal["mp4", "pil", "pt"] = "pil",
|
||||
image_format: Literal["png", "jpg"] = "jpg",
|
||||
partial_postprocess: bool = False,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
):
|
||||
headers = {}
|
||||
parameters = {
|
||||
"image_format": image_format,
|
||||
"output_type": output_type,
|
||||
"partial_postprocess": partial_postprocess,
|
||||
"shape": list(tensor.shape),
|
||||
"dtype": str(tensor.dtype).split(".")[-1],
|
||||
}
|
||||
if do_scaling and scaling_factor is not None:
|
||||
parameters["scaling_factor"] = scaling_factor
|
||||
if do_scaling and shift_factor is not None:
|
||||
parameters["shift_factor"] = shift_factor
|
||||
if do_scaling and scaling_factor is None:
|
||||
parameters["do_scaling"] = do_scaling
|
||||
elif do_scaling and scaling_factor is None and shift_factor is None:
|
||||
parameters["do_scaling"] = do_scaling
|
||||
if height is not None and width is not None:
|
||||
parameters["height"] = height
|
||||
parameters["width"] = width
|
||||
headers["Content-Type"] = "tensor/binary"
|
||||
headers["Accept"] = "tensor/binary"
|
||||
if output_type == "pil" and image_format == "jpg" and processor is None:
|
||||
headers["Accept"] = "image/jpeg"
|
||||
elif output_type == "pil" and image_format == "png" and processor is None:
|
||||
headers["Accept"] = "image/png"
|
||||
elif output_type == "mp4":
|
||||
headers["Accept"] = "text/plain"
|
||||
tensor_data = safetensors.torch._tobytes(tensor, "tensor")
|
||||
return {"data": tensor_data, "params": parameters, "headers": headers}
|
||||
|
||||
|
||||
def remote_decode(
|
||||
endpoint: str,
|
||||
tensor: "torch.Tensor",
|
||||
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
|
||||
do_scaling: bool = True,
|
||||
scaling_factor: Optional[float] = None,
|
||||
shift_factor: Optional[float] = None,
|
||||
output_type: Literal["mp4", "pil", "pt"] = "pil",
|
||||
return_type: Literal["mp4", "pil", "pt"] = "pil",
|
||||
image_format: Literal["png", "jpg"] = "jpg",
|
||||
partial_postprocess: bool = False,
|
||||
input_tensor_type: Literal["binary"] = "binary",
|
||||
output_tensor_type: Literal["binary"] = "binary",
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
) -> Union[Image.Image, List[Image.Image], bytes, "torch.Tensor"]:
|
||||
"""
|
||||
Hugging Face Hybrid Inference that allow running VAE decode remotely.
|
||||
|
||||
Args:
|
||||
endpoint (`str`):
|
||||
Endpoint for Remote Decode.
|
||||
tensor (`torch.Tensor`):
|
||||
Tensor to be decoded.
|
||||
processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
|
||||
Used with `return_type="pt"`, and `return_type="pil"` for Video models.
|
||||
do_scaling (`bool`, default `True`, *optional*):
|
||||
**DEPRECATED**. **pass `scaling_factor`/`shift_factor` instead.** **still set
|
||||
do_scaling=None/do_scaling=False for no scaling until option is removed** When `True` scaling e.g. `latents
|
||||
/ self.vae.config.scaling_factor` is applied remotely. If `False`, input must be passed with scaling
|
||||
applied.
|
||||
scaling_factor (`float`, *optional*):
|
||||
Scaling is applied when passed e.g. [`latents /
|
||||
self.vae.config.scaling_factor`](https://github.com/huggingface/diffusers/blob/7007febae5cff000d4df9059d9cf35133e8b2ca9/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L1083C37-L1083C77).
|
||||
- SD v1: 0.18215
|
||||
- SD XL: 0.13025
|
||||
- Flux: 0.3611
|
||||
If `None`, input must be passed with scaling applied.
|
||||
shift_factor (`float`, *optional*):
|
||||
Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`.
|
||||
- Flux: 0.1159
|
||||
If `None`, input must be passed with scaling applied.
|
||||
output_type (`"mp4"` or `"pil"` or `"pt", default `"pil"):
|
||||
**Endpoint** output type. Subject to change. Report feedback on preferred type.
|
||||
|
||||
`"mp4": Supported by video models. Endpoint returns `bytes` of video. `"pil"`: Supported by image and video
|
||||
models.
|
||||
Image models: Endpoint returns `bytes` of an image in `image_format`. Video models: Endpoint returns
|
||||
`torch.Tensor` with partial `postprocessing` applied.
|
||||
Requires `processor` as a flag (any `None` value will work).
|
||||
`"pt"`: Support by image and video models. Endpoint returns `torch.Tensor`.
|
||||
With `partial_postprocess=True` the tensor is postprocessed `uint8` image tensor.
|
||||
|
||||
Recommendations:
|
||||
`"pt"` with `partial_postprocess=True` is the smallest transfer for full quality. `"pt"` with
|
||||
`partial_postprocess=False` is the most compatible with third party code. `"pil"` with
|
||||
`image_format="jpg"` is the smallest transfer overall.
|
||||
|
||||
return_type (`"mp4"` or `"pil"` or `"pt", default `"pil"):
|
||||
**Function** return type.
|
||||
|
||||
`"mp4": Function returns `bytes` of video. `"pil"`: Function returns `PIL.Image.Image`.
|
||||
With `output_type="pil" no further processing is applied. With `output_type="pt" a `PIL.Image.Image` is
|
||||
created.
|
||||
`partial_postprocess=False` `processor` is required. `partial_postprocess=True` `processor` is
|
||||
**not** required.
|
||||
`"pt"`: Function returns `torch.Tensor`.
|
||||
`processor` is **not** required. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without
|
||||
denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized.
|
||||
|
||||
image_format (`"png"` or `"jpg"`, default `jpg`):
|
||||
Used with `output_type="pil"`. Endpoint returns `jpg` or `png`.
|
||||
|
||||
partial_postprocess (`bool`, default `False`):
|
||||
Used with `output_type="pt"`. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without
|
||||
denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized.
|
||||
|
||||
input_tensor_type (`"binary"`, default `"binary"`):
|
||||
Tensor transfer type.
|
||||
|
||||
output_tensor_type (`"binary"`, default `"binary"`):
|
||||
Tensor transfer type.
|
||||
|
||||
height (`int`, **optional**):
|
||||
Required for `"packed"` latents.
|
||||
|
||||
width (`int`, **optional**):
|
||||
Required for `"packed"` latents.
|
||||
|
||||
Returns:
|
||||
output (`Image.Image` or `List[Image.Image]` or `bytes` or `torch.Tensor`).
|
||||
"""
|
||||
if input_tensor_type == "base64":
|
||||
deprecate(
|
||||
"input_tensor_type='base64'",
|
||||
"1.0.0",
|
||||
"input_tensor_type='base64' is deprecated. Using `binary`.",
|
||||
standard_warn=False,
|
||||
)
|
||||
input_tensor_type = "binary"
|
||||
if output_tensor_type == "base64":
|
||||
deprecate(
|
||||
"output_tensor_type='base64'",
|
||||
"1.0.0",
|
||||
"output_tensor_type='base64' is deprecated. Using `binary`.",
|
||||
standard_warn=False,
|
||||
)
|
||||
output_tensor_type = "binary"
|
||||
check_inputs(
|
||||
endpoint,
|
||||
tensor,
|
||||
processor,
|
||||
do_scaling,
|
||||
scaling_factor,
|
||||
shift_factor,
|
||||
output_type,
|
||||
return_type,
|
||||
image_format,
|
||||
partial_postprocess,
|
||||
input_tensor_type,
|
||||
output_tensor_type,
|
||||
height,
|
||||
width,
|
||||
)
|
||||
kwargs = prepare(
|
||||
tensor=tensor,
|
||||
processor=processor,
|
||||
do_scaling=do_scaling,
|
||||
scaling_factor=scaling_factor,
|
||||
shift_factor=shift_factor,
|
||||
output_type=output_type,
|
||||
image_format=image_format,
|
||||
partial_postprocess=partial_postprocess,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
response = requests.post(endpoint, **kwargs)
|
||||
if not response.ok:
|
||||
raise RuntimeError(response.json())
|
||||
output = postprocess(
|
||||
response=response,
|
||||
processor=processor,
|
||||
output_type=output_type,
|
||||
return_type=return_type,
|
||||
partial_postprocess=partial_postprocess,
|
||||
)
|
||||
return output
|
||||
@@ -33,11 +33,12 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
load_image,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_peft_backend,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -101,7 +102,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
# Keeping this test here makes sense because it doesn't look any integration
|
||||
# (value assertions on logits).
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_integration_move_lora_cpu(self):
|
||||
path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
lora_id = "takuma104/lora-test-text-encoder-lora-target"
|
||||
@@ -158,7 +159,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
self.assertTrue(m.weight.device != torch.device("cpu"))
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_integration_move_lora_dora_cpu(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
@@ -209,18 +210,18 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
class LoraIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_integration_logits_with_scale(self):
|
||||
path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
@@ -378,7 +379,7 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
|
||||
lora_filename = "light_and_shadow.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
@@ -400,7 +401,7 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
pipe.enable_sequential_cpu_offload(device=torch_device)
|
||||
lora_model_id = "hf-internal-testing/civitai-light-shadow-lora"
|
||||
lora_filename = "light_and_shadow.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
@@ -656,7 +657,7 @@ class LoraIntegrationTests(unittest.TestCase):
|
||||
See: https://github.com/huggingface/diffusers/issues/5606
|
||||
"""
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
|
||||
pipeline.enable_sequential_cpu_offload()
|
||||
pipeline.enable_sequential_cpu_offload(device=torch_device)
|
||||
civitai_path = hf_hub_download("ybelkada/test-ahi-civitai", "ahi_lora_weights.safetensors")
|
||||
pipeline.load_lora_weights(civitai_path, adapter_name="ahri")
|
||||
|
||||
|
||||
@@ -30,12 +30,13 @@ from diffusers import (
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_flaky,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_peft_backend,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -93,7 +94,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
def output_shape(self):
|
||||
return (1, 32, 32, 3)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_sd3_lora(self):
|
||||
"""
|
||||
Test loading the loras that are saved with the diffusers and peft formats.
|
||||
@@ -135,7 +136,7 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@pytest.mark.big_gpu_with_torch_cuda
|
||||
@@ -146,12 +147,12 @@ class SD3LoraIntegrationTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_inputs(self, device, seed=0):
|
||||
init_image = load_image(
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
WanPipeline,
|
||||
WanTransformer3DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = WanPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
|
||||
scheduler_kwargs = {}
|
||||
|
||||
transformer_kwargs = {
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"text_dim": 32,
|
||||
"freq_dim": 256,
|
||||
"ffn_dim": 32,
|
||||
"num_layers": 2,
|
||||
"cross_attn_norm": True,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 32,
|
||||
}
|
||||
transformer_cls = WanTransformer3DModel
|
||||
vae_kwargs = {
|
||||
"base_dim": 3,
|
||||
"z_dim": 16,
|
||||
"dim_mult": [1, 1, 1, 1],
|
||||
"num_res_blocks": 1,
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
vae_cls = AutoencoderKLWan
|
||||
has_two_text_encoders = True
|
||||
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
|
||||
text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
|
||||
|
||||
text_encoder_target_modules = ["q", "k", "v", "o"]
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 9, 32, 32, 3)
|
||||
|
||||
def get_dummy_inputs(self, with_generator=True):
|
||||
batch_size = 1
|
||||
sequence_length = 16
|
||||
num_channels = 4
|
||||
num_frames = 9
|
||||
num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
|
||||
sizes = (4, 4)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
|
||||
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
|
||||
|
||||
pipeline_inputs = {
|
||||
"prompt": "",
|
||||
"num_frames": num_frames,
|
||||
"num_inference_steps": 1,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": sequence_length,
|
||||
"output_type": "np",
|
||||
}
|
||||
if with_generator:
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
|
||||
@unittest.skip("Not supported in Wan.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Wan.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Wan.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Wan.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
+8
-2
@@ -1594,11 +1594,17 @@ class PeftLoraLoaderMixinTests:
|
||||
].weight += float("inf")
|
||||
else:
|
||||
named_modules = [name for name, _ in pipe.transformer.named_modules()]
|
||||
tower_name = (
|
||||
"transformer_blocks"
|
||||
if any(name == "transformer_blocks" for name in named_modules)
|
||||
else "blocks"
|
||||
)
|
||||
transformer_tower = getattr(pipe.transformer, tower_name)
|
||||
has_attn1 = any("attn1" in name for name in named_modules)
|
||||
if has_attn1:
|
||||
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
else:
|
||||
pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLMagvit
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLMagvit
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_magvit_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"latent_channels": 4,
|
||||
"out_channels": 3,
|
||||
"block_out_channels": [8, 8, 8, 8],
|
||||
"down_block_types": [
|
||||
"SpatialDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
],
|
||||
"up_block_types": [
|
||||
"SpatialUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
],
|
||||
"layers_per_block": 1,
|
||||
"norm_num_groups": 8,
|
||||
"spatial_group_norm": True,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
height = 16
|
||||
width = 16
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_magvit_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"EasyAnimateEncoder", "EasyAnimateDecoder"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Not quite sure why this test fails. Revisit later.")
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
@@ -0,0 +1,79 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLWan
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLWan
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_wan_config(self):
|
||||
return {
|
||||
"base_dim": 3,
|
||||
"z_dim": 16,
|
||||
"dim_mult": [1, 1, 1, 1],
|
||||
"num_res_blocks": 1,
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Gradient checkpointing has not been implemented yet")
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
@@ -739,8 +739,14 @@ class ModelTesterMixin:
|
||||
model.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
if (
|
||||
hasattr(self.model_class, "_keep_in_fp32_modules")
|
||||
and self.model_class._keep_in_fp32_modules is None
|
||||
):
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
|
||||
)
|
||||
assert new_model.dtype == dtype
|
||||
|
||||
def test_determinism(self, expected_max_diff=1e-5):
|
||||
if self.forward_requires_fresh_args:
|
||||
@@ -987,6 +993,10 @@ class ModelTesterMixin:
|
||||
continue
|
||||
if name in skip:
|
||||
continue
|
||||
# TODO(aryan): remove the below lines after looking into easyanimate transformer a little more
|
||||
# It currently errors out the gradient checkpointing test because the gradients for attn2.to_out is None
|
||||
if param.grad is None:
|
||||
continue
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
|
||||
@@ -1169,17 +1179,16 @@ class ModelTesterMixin:
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
# Force disk offload by setting very small CPU memory
|
||||
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
# This errors out because it's missing an offload folder
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
|
||||
)
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import EasyAnimateTransformer3DModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class EasyAnimateTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = EasyAnimateTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 4
|
||||
num_frames = 2
|
||||
height = 16
|
||||
width = 16
|
||||
embedding_dim = 16
|
||||
sequence_length = 16
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"timestep_cond": None,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_t5": None,
|
||||
"inpaint_latents": None,
|
||||
"control_latents": None,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 2, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 2, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"in_channels": 4,
|
||||
"mmdit_layers": 2,
|
||||
"num_layers": 2,
|
||||
"out_channels": 4,
|
||||
"patch_size": 2,
|
||||
"sample_height": 60,
|
||||
"sample_width": 90,
|
||||
"text_embed_dim": 16,
|
||||
"time_embed_dim": 8,
|
||||
"time_position_encoding_type": "3d_rope",
|
||||
"timestep_activation_fn": "silu",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"EasyAnimateTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@@ -30,6 +30,7 @@ class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = OmniGenTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
model_split_percents = [0.1, 0.1, 0.1]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
@@ -73,9 +74,9 @@ class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 4,
|
||||
"intermediate_size": 32,
|
||||
"num_layers": 1,
|
||||
"num_layers": 20,
|
||||
"pad_token_id": 0,
|
||||
"vocab_size": 100,
|
||||
"vocab_size": 1000,
|
||||
"in_channels": 4,
|
||||
"time_step_dim": 4,
|
||||
"rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
|
||||
|
||||
@@ -33,6 +33,7 @@ enable_full_determinism()
|
||||
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = SD3Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
model_split_percents = [0.8, 0.8, 0.9]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
@@ -67,7 +68,7 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"sample_size": 32,
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_layers": 4,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 4,
|
||||
"caption_projection_dim": 32,
|
||||
@@ -107,6 +108,7 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = SD3Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
model_split_percents = [0.8, 0.8, 0.9]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
@@ -141,7 +143,7 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"sample_size": 32,
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 2,
|
||||
"num_layers": 4,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 4,
|
||||
"caption_projection_dim": 32,
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import WanTransformer3DModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 2
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"text_dim": 16,
|
||||
"freq_dim": 256,
|
||||
"ffn_dim": 32,
|
||||
"num_layers": 2,
|
||||
"cross_attn_norm": True,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 32,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"WanTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@@ -36,6 +36,9 @@ from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
is_peft_available,
|
||||
@@ -1002,7 +1005,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
@@ -1013,7 +1016,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
@@ -1024,7 +1027,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
@@ -1039,7 +1042,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
@@ -1054,7 +1057,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
@@ -1064,7 +1067,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
@@ -1164,11 +1167,11 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
return model
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_set_attention_slice_auto(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_max_memory_allocated(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
unet = self.get_unet_model()
|
||||
unet.set_attention_slice("auto")
|
||||
@@ -1180,15 +1183,15 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
mem_bytes = backend_max_memory_allocated(torch_device)
|
||||
|
||||
assert mem_bytes < 5 * 10**9
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_set_attention_slice_max(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_max_memory_allocated(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
unet = self.get_unet_model()
|
||||
unet.set_attention_slice("max")
|
||||
@@ -1200,15 +1203,15 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
mem_bytes = backend_max_memory_allocated(torch_device)
|
||||
|
||||
assert mem_bytes < 5 * 10**9
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_set_attention_slice_int(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_max_memory_allocated(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
unet = self.get_unet_model()
|
||||
unet.set_attention_slice(2)
|
||||
@@ -1220,15 +1223,15 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
mem_bytes = backend_max_memory_allocated(torch_device)
|
||||
|
||||
assert mem_bytes < 5 * 10**9
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_set_attention_slice_list(self):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_max_memory_allocated(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
# there are 32 sliceable layers
|
||||
slice_list = 16 * [2, 3]
|
||||
@@ -1242,7 +1245,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
mem_bytes = torch.cuda.max_memory_allocated()
|
||||
mem_bytes = backend_max_memory_allocated(torch_device)
|
||||
|
||||
assert mem_bytes < 5 * 10**9
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
|
||||
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
|
||||
)
|
||||
pipe.to("cuda")
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
|
||||
@@ -40,7 +40,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -245,7 +245,7 @@ class ControlNetPipelineSDXLFastTests(
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_stable_diffusion_xl_offloads(self):
|
||||
pipes = []
|
||||
components = self.get_dummy_components()
|
||||
@@ -254,12 +254,12 @@ class ControlNetPipelineSDXLFastTests(
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_model_cpu_offload()
|
||||
sd_pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_sequential_cpu_offload()
|
||||
sd_pipe.enable_sequential_cpu_offload(device=torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
image_slices = []
|
||||
|
||||
@@ -223,12 +223,12 @@ class StableDiffusionXLControlNetPipelineFastTests(
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_model_cpu_offload()
|
||||
sd_pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_sequential_cpu_offload()
|
||||
sd_pipe.enable_sequential_cpu_offload(device=torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
image_slices = []
|
||||
|
||||
@@ -31,6 +31,7 @@ from diffusers import (
|
||||
from diffusers.models import FluxControlNetModel
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
@@ -39,13 +40,13 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
|
||||
pipeline_class = FluxControlNetPipeline
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
@@ -128,6 +129,8 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"controlnet": controlnet,
|
||||
"image_encoder": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
@@ -215,12 +218,12 @@ class FluxControlNetPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_canny(self):
|
||||
controlnet = FluxControlNetModel.from_pretrained(
|
||||
|
||||
@@ -158,30 +158,6 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi
|
||||
|
||||
assert max_diff > 1e-6
|
||||
|
||||
def test_flux_controlnet_prompt_embeds(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
output_with_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = inputs.pop("prompt")
|
||||
|
||||
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
|
||||
prompt,
|
||||
prompt_2=None,
|
||||
device=torch_device,
|
||||
max_sequence_length=inputs["max_sequence_length"],
|
||||
)
|
||||
output_with_embeds = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
**inputs,
|
||||
).images[0]
|
||||
|
||||
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -239,7 +239,7 @@ class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
|
||||
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_model_cpu_offload(device=torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
|
||||
@@ -0,0 +1,294 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2Tokenizer, Qwen2VLForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLMagvit,
|
||||
EasyAnimatePipeline,
|
||||
EasyAnimateTransformer3DModel,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class EasyAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = EasyAnimatePipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = EasyAnimateTransformer3DModel(
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=16,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
time_embed_dim=2,
|
||||
text_embed_dim=16, # Must match with tiny-random-t5
|
||||
num_layers=1,
|
||||
sample_width=16, # latent width: 2 -> final width: 16
|
||||
sample_height=16, # latent height: 2 -> final height: 16
|
||||
patch_size=2,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLMagvit(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=(
|
||||
"SpatialDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
"SpatialTemporalDownBlock3D",
|
||||
),
|
||||
up_block_types=(
|
||||
"SpatialUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
"SpatialTemporalUpBlock3D",
|
||||
),
|
||||
block_out_channels=(8, 8, 8, 8),
|
||||
latent_channels=4,
|
||||
layers_per_block=1,
|
||||
norm_num_groups=2,
|
||||
spatial_group_norm=False,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
|
||||
)
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 16,
|
||||
"width": 16,
|
||||
"num_frames": 5,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
|
||||
self.assertEqual(generated_video.shape, (5, 3, 16, 16))
|
||||
expected_video = torch.randn(5, 3, 16, 16)
|
||||
max_diff = np.abs(generated_video - expected_video).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
|
||||
has_callback_step_end = "callback_on_step_end" in sig.parameters
|
||||
|
||||
if not (has_callback_tensor_inputs and has_callback_step_end):
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
|
||||
# iterate over callback args
|
||||
for tensor_name, tensor_value in callback_kwargs.items():
|
||||
# check that we're only passing in allowed tensor inputs
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
# Test passing in a subset
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
# Test passing in a everything
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
|
||||
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
|
||||
is_last = i == (pipe.num_timesteps - 1)
|
||||
if is_last:
|
||||
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
|
||||
return callback_kwargs
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_change_tensor
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
output = pipe(**inputs)[0]
|
||||
assert output.abs().sum() < 1e10
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_dict_tuple_outputs_equivalent(self, expected_slice=None, expected_max_difference=0.001):
|
||||
# Seems to need a higher tolerance
|
||||
return super().test_dict_tuple_outputs_equivalent(expected_slice, expected_max_difference)
|
||||
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
# Seems to need a higher tolerance
|
||||
return super().test_encode_prompt_works_in_isolation(atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class EasyAnimatePipelineIntegrationTests(unittest.TestCase):
|
||||
prompt = "A painting of a squirrel eating a burger."
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_EasyAnimate(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
pipe = EasyAnimatePipeline.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", torch_dtype=torch.float16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
prompt = self.prompt
|
||||
|
||||
videos = pipe(
|
||||
prompt=prompt,
|
||||
height=480,
|
||||
width=720,
|
||||
num_frames=5,
|
||||
generator=generator,
|
||||
num_inference_steps=2,
|
||||
output_type="pt",
|
||||
).frames
|
||||
|
||||
video = videos[0]
|
||||
expected_video = torch.randn(1, 5, 480, 720, 3).numpy()
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(video, expected_video)
|
||||
assert max_diff < 1e-3, f"Max diff is too high. got {video}"
|
||||
@@ -9,6 +9,7 @@ from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPToken
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
@@ -136,30 +137,6 @@ class FluxPipelineFastTests(
|
||||
# For some reasons, they don't show large differences
|
||||
assert max_diff > 1e-6
|
||||
|
||||
def test_flux_prompt_embeds(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
output_with_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = inputs.pop("prompt")
|
||||
|
||||
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
|
||||
prompt,
|
||||
prompt_2=None,
|
||||
device=torch_device,
|
||||
max_sequence_length=inputs["max_sequence_length"],
|
||||
)
|
||||
output_with_embeds = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
**inputs,
|
||||
).images[0]
|
||||
|
||||
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
|
||||
assert max_diff < 1e-4
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
@@ -236,12 +213,12 @@ class FluxPipelineSlowTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_inputs(self, device, seed=0):
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user