Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 07798babac |
@@ -9,43 +9,119 @@ permissions:
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
style:
|
||||
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
|
||||
with:
|
||||
python_quality_dependencies: "[quality]"
|
||||
pre_commit_script_name: "Download and Compare files from the main branch"
|
||||
pre_commit_script: |
|
||||
echo "Downloading the files from the main branch"
|
||||
run-style-bot:
|
||||
if: >
|
||||
contains(github.event.comment.body, '@bot /style') &&
|
||||
github.event.issue.pull_request != null
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
|
||||
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
|
||||
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
|
||||
steps:
|
||||
- name: Extract PR details
|
||||
id: pr_info
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = context.payload.issue.number;
|
||||
const { data: pr } = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber
|
||||
});
|
||||
|
||||
// We capture both the branch ref and the "full_name" of the head repo
|
||||
// so that we can check out the correct repository & branch (including forks).
|
||||
core.setOutput("prNumber", prNumber);
|
||||
core.setOutput("headRef", pr.head.ref);
|
||||
core.setOutput("headRepoFullName", pr.head.repo.full_name);
|
||||
|
||||
echo "Compare the files and raise error if needed"
|
||||
- name: Check out PR branch
|
||||
uses: actions/checkout@v3
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
with:
|
||||
# Instead of checking out the base repo, use the contributor's repo name
|
||||
repository: ${{ env.HEADREPOFULLNAME }}
|
||||
ref: ${{ env.HEADREF }}
|
||||
# You may need fetch-depth: 0 for being able to push
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Debug
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
run: |
|
||||
echo "PR number: ${{ env.PRNUMBER }}"
|
||||
echo "Head Ref: ${{ env.HEADREF }}"
|
||||
echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}"
|
||||
|
||||
diff_failed=0
|
||||
if ! diff -q main_Makefile Makefile; then
|
||||
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
|
||||
if ! diff -q main_setup.py setup.py; then
|
||||
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install .[quality]
|
||||
|
||||
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
|
||||
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
- name: Download Makefile from main branch
|
||||
run: |
|
||||
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
|
||||
|
||||
- name: Compare Makefiles
|
||||
run: |
|
||||
if ! diff -q main_Makefile Makefile; then
|
||||
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "No changes in Makefile. Proceeding..."
|
||||
rm -rf main_Makefile
|
||||
|
||||
if [ $diff_failed -eq 1 ]; then
|
||||
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
|
||||
exit 1
|
||||
fi
|
||||
- name: Run make style and make quality
|
||||
run: |
|
||||
make style && make quality
|
||||
|
||||
echo "No changes in the files. Proceeding..."
|
||||
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
|
||||
style_command: "make style && make quality"
|
||||
secrets:
|
||||
bot_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Commit and push changes
|
||||
id: commit_and_push
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}"
|
||||
# Configure git with the Actions bot user
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Make sure your 'origin' remote is set to the contributor's fork
|
||||
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git"
|
||||
|
||||
# If there are changes after running style/quality, commit them
|
||||
if [ -n "$(git status --porcelain)" ]; then
|
||||
git add .
|
||||
git commit -m "Apply style fixes"
|
||||
# Push to the original contributor's forked branch
|
||||
git push origin HEAD:${{ env.HEADREF }}
|
||||
echo "changes_pushed=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No changes to commit."
|
||||
echo "changes_pushed=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Comment on PR with workflow run link
|
||||
if: steps.commit_and_push.outputs.changes_pushed == 'true'
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = parseInt(process.env.prNumber, 10);
|
||||
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
|
||||
});
|
||||
env:
|
||||
prNumber: ${{ steps.pr_info.outputs.prNumber }}
|
||||
|
||||
@@ -3,6 +3,7 @@ name: Fast tests for PRs
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
types: [synchronize]
|
||||
paths:
|
||||
- "src/diffusers/**.py"
|
||||
- "benchmarks/**.py"
|
||||
|
||||
@@ -1,250 +0,0 @@
|
||||
name: Fast GPU Tests on PR
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: main
|
||||
paths:
|
||||
- "src/diffusers/models/modeling_utils.py"
|
||||
- "src/diffusers/models/model_loading_utils.py"
|
||||
- "src/diffusers/pipelines/pipeline_utils.py"
|
||||
- "src/diffusers/pipeline_loading_utils.py"
|
||||
- "src/diffusers/loaders/lora_base.py"
|
||||
- "src/diffusers/loaders/lora_pipeline.py"
|
||||
- "src/diffusers/loaders/peft.py"
|
||||
- "tests/pipelines/test_pipelines_common.py"
|
||||
- "tests/models/test_modeling_common.py"
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
PYTEST_TIMEOUT: 600
|
||||
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
|
||||
|
||||
jobs:
|
||||
setup_torch_cuda_pipeline_matrix:
|
||||
name: Setup Torch Pipelines CUDA Slow Tests Matrix
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
outputs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Fetch Pipeline Matrix
|
||||
id: fetch_pipeline_matrix
|
||||
run: |
|
||||
matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
|
||||
echo $matrix
|
||||
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
|
||||
torch_pipelines_cuda_tests:
|
||||
name: Torch Pipelines CUDA Tests
|
||||
needs: setup_torch_cuda_pipeline_matrix
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 8
|
||||
matrix:
|
||||
module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Extract tests
|
||||
id: extract_tests
|
||||
run: |
|
||||
pattern=$(python utils/extract_tests_from_mixin.py --type pipeline)
|
||||
echo "$pattern" > /tmp/test_pattern.txt
|
||||
echo "pattern_file=/tmp/test_pattern.txt" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: PyTorch CUDA checkpoint tests on Ubuntu
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
else
|
||||
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and $pattern" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
fi
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
|
||||
torch_cuda_tests:
|
||||
name: Torch CUDA Tests
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
module: [models, schedulers, lora, others]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Extract tests
|
||||
id: extract_tests
|
||||
run: |
|
||||
pattern=$(python utils/extract_tests_from_mixin.py --type ${{ matrix.module }})
|
||||
echo "$pattern" > /tmp/test_pattern.txt
|
||||
echo "pattern_file=/tmp/test_pattern.txt" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Run PyTorch CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
|
||||
if [ -z "$pattern" ]; then
|
||||
python -m pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }}
|
||||
else
|
||||
python -m pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }}
|
||||
fi
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_torch_cuda_${{ matrix.module }}_stats.txt
|
||||
cat reports/tests_torch_cuda_${{ matrix.module }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_cuda_test_reports_${{ matrix.module }}
|
||||
path: reports
|
||||
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install timm
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/examples_torch_cuda_stats.txt
|
||||
cat reports/examples_torch_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: examples_test_reports
|
||||
path: reports
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
name: Fast GPU Tests on main
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: main
|
||||
paths:
|
||||
- "src/diffusers/models/modeling_utils.py"
|
||||
- "src/diffusers/models/model_loading_utils.py"
|
||||
- "src/diffusers/pipelines/pipeline_utils.py"
|
||||
- "src/diffusers/pipeline_loading_utils.py"
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
@@ -160,6 +167,7 @@ jobs:
|
||||
path: reports
|
||||
|
||||
flax_tpu_tests:
|
||||
if: ${{ github.event_name != 'pull_request' }}
|
||||
name: Flax TPU Tests
|
||||
runs-on:
|
||||
group: gcp-ct5lp-hightpu-8t
|
||||
@@ -208,6 +216,7 @@ jobs:
|
||||
path: reports
|
||||
|
||||
onnx_cuda_tests:
|
||||
if: ${{ github.event_name != 'pull_request' }}
|
||||
name: ONNX CUDA Tests
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
@@ -256,6 +265,7 @@ jobs:
|
||||
path: reports
|
||||
|
||||
run_torch_compile_tests:
|
||||
if: ${{ github.event_name != 'pull_request' }}
|
||||
name: PyTorch Compile CUDA tests
|
||||
|
||||
runs-on:
|
||||
@@ -299,6 +309,7 @@ jobs:
|
||||
path: reports
|
||||
|
||||
run_xformers_tests:
|
||||
if: ${{ github.event_name != 'pull_request' }}
|
||||
name: PyTorch xformers CUDA tests
|
||||
|
||||
runs-on:
|
||||
|
||||
@@ -76,14 +76,6 @@
|
||||
- local: advanced_inference/outpaint
|
||||
title: Outpainting
|
||||
title: Advanced inference
|
||||
- sections:
|
||||
- local: hybrid_inference/overview
|
||||
title: Overview
|
||||
- local: hybrid_inference/vae_decode
|
||||
title: VAE Decode
|
||||
- local: hybrid_inference/api_reference
|
||||
title: API Reference
|
||||
title: Hybrid Inference
|
||||
- sections:
|
||||
- local: using-diffusers/cogvideox
|
||||
title: CogVideoX
|
||||
@@ -290,8 +282,6 @@
|
||||
title: CogView4Transformer2DModel
|
||||
- local: api/models/dit_transformer2d
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/easyanimate_transformer3d
|
||||
title: EasyAnimateTransformer3DModel
|
||||
- local: api/models/flux_transformer
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
@@ -324,8 +314,6 @@
|
||||
title: Transformer2DModel
|
||||
- local: api/models/transformer_temporal
|
||||
title: TransformerTemporalModel
|
||||
- local: api/models/wan_transformer_3d
|
||||
title: WanTransformer3DModel
|
||||
title: Transformers
|
||||
- sections:
|
||||
- local: api/models/stable_cascade_unet
|
||||
@@ -354,12 +342,8 @@
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoderkl_ltx_video
|
||||
title: AutoencoderKLLTXVideo
|
||||
- local: api/models/autoencoderkl_magvit
|
||||
title: AutoencoderKLMagvit
|
||||
- local: api/models/autoencoderkl_mochi
|
||||
title: AutoencoderKLMochi
|
||||
- local: api/models/autoencoder_kl_wan
|
||||
title: AutoencoderKLWan
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/autoencoder_dc
|
||||
@@ -434,8 +418,6 @@
|
||||
title: DiffEdit
|
||||
- local: api/pipelines/dit
|
||||
title: DiT
|
||||
- local: api/pipelines/easyanimate
|
||||
title: EasyAnimate
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
@@ -552,8 +534,6 @@
|
||||
title: UniDiffuser
|
||||
- local: api/pipelines/value_guided_sampling
|
||||
title: Value-guided sampling
|
||||
- local: api/pipelines/wan
|
||||
title: Wan
|
||||
- local: api/pipelines/wuerstchen
|
||||
title: Wuerstchen
|
||||
title: Pipelines
|
||||
@@ -563,10 +543,6 @@
|
||||
title: Overview
|
||||
- local: api/schedulers/cm_stochastic_iterative
|
||||
title: CMStochasticIterativeScheduler
|
||||
- local: api/schedulers/ddim_cogvideox
|
||||
title: CogVideoXDDIMScheduler
|
||||
- local: api/schedulers/multistep_dpm_solver_cogvideox
|
||||
title: CogVideoXDPMScheduler
|
||||
- local: api/schedulers/consistency_decoder
|
||||
title: ConsistencyDecoderScheduler
|
||||
- local: api/schedulers/cosine_dpm
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLWan
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLWan
|
||||
|
||||
vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
|
||||
```
|
||||
|
||||
## AutoencoderKLWan
|
||||
|
||||
[[autodoc]] AutoencoderKLWan
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -1,37 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLMagvit
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLMagvit
|
||||
|
||||
vae = AutoencoderKLMagvit.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLMagvit
|
||||
|
||||
[[autodoc]] AutoencoderKLMagvit
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -1,30 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# EasyAnimateTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data from [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import EasyAnimateTransformer3DModel
|
||||
|
||||
transformer = EasyAnimateTransformer3DModel.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
## EasyAnimateTransformer3DModel
|
||||
|
||||
[[autodoc]] EasyAnimateTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -1,30 +0,0 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# WanTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data was introduced in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import WanTransformer3DModel
|
||||
|
||||
transformer = WanTransformer3DModel.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## WanTransformer3DModel
|
||||
|
||||
[[autodoc]] WanTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -1,88 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# EasyAnimate
|
||||
[EasyAnimate](https://github.com/aigc-apps/EasyAnimate) by Alibaba PAI.
|
||||
|
||||
The description from it's GitHub page:
|
||||
*EasyAnimate is a pipeline based on the transformer architecture, designed for generating AI images and videos, and for training baseline models and Lora models for Diffusion Transformer. We support direct prediction from pre-trained EasyAnimate models, allowing for the generation of videos with various resolutions, approximately 6 seconds in length, at 8fps (EasyAnimateV5.1, 1 to 49 frames). Additionally, users can train their own baseline and Lora models for specific style transformations.*
|
||||
|
||||
This pipeline was contributed by [bubbliiiing](https://github.com/bubbliiiing). The original codebase can be found [here](https://huggingface.co/alibaba-pai). The original weights can be found under [hf.co/alibaba-pai](https://huggingface.co/alibaba-pai).
|
||||
|
||||
There are two official EasyAnimate checkpoints for text-to-video and video-to-video.
|
||||
|
||||
| checkpoints | recommended inference dtype |
|
||||
|:---:|:---:|
|
||||
| [`alibaba-pai/EasyAnimateV5.1-12b-zh`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh) | torch.float16 |
|
||||
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |
|
||||
|
||||
There is one official EasyAnimate checkpoints available for image-to-video and video-to-video.
|
||||
|
||||
| checkpoints | recommended inference dtype |
|
||||
|:---:|:---:|
|
||||
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |
|
||||
|
||||
There are two official EasyAnimate checkpoints available for control-to-video.
|
||||
|
||||
| checkpoints | recommended inference dtype |
|
||||
|:---:|:---:|
|
||||
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control) | torch.float16 |
|
||||
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera) | torch.float16 |
|
||||
|
||||
For the EasyAnimateV5.1 series:
|
||||
- Text-to-video (T2V) and Image-to-video (I2V) works for multiple resolutions. The width and height can vary from 256 to 1024.
|
||||
- Both T2V and I2V models support generation with 1~49 frames and work best at this value. Exporting videos at 8 FPS is recommended.
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`EasyAnimatePipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel, EasyAnimatePipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = EasyAnimateTransformer3DModel.from_pretrained(
|
||||
"alibaba-pai/EasyAnimateV5.1-12b-zh",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = EasyAnimatePipeline.from_pretrained(
|
||||
"alibaba-pai/EasyAnimateV5.1-12b-zh",
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
prompt = "A cat walks on the grass, realistic style."
|
||||
negative_prompt = "bad detailed"
|
||||
video = pipeline(prompt=prompt, negative_prompt=negative_prompt, num_frames=49, num_inference_steps=30).frames[0]
|
||||
export_to_video(video, "cat.mp4", fps=8)
|
||||
```
|
||||
|
||||
## EasyAnimatePipeline
|
||||
|
||||
[[autodoc]] EasyAnimatePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## EasyAnimatePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput
|
||||
@@ -359,74 +359,8 @@ image.save('flux_ip_adapter_output.jpg')
|
||||
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "wearing sunglasses"</figcaption>
|
||||
</div>
|
||||
|
||||
## Optimize
|
||||
|
||||
Flux is a very large model and requires ~50GB of RAM/VRAM to load all the modeling components. Enable some of the optimizations below to lower the memory requirements.
|
||||
|
||||
### Group offloading
|
||||
|
||||
[Group offloading](../../optimization/memory#group-offloading) lowers VRAM usage by offloading groups of internal layers rather than the whole model or weights. You need to use [`~hooks.apply_group_offloading`] on all the model components of a pipeline. The `offload_type` parameter allows you to toggle between block and leaf-level offloading. Setting it to `leaf_level` offloads the lowest leaf-level parameters to the CPU instead of offloading at the module-level.
|
||||
|
||||
On CUDA devices that support asynchronous data streaming, set `use_stream=True` to overlap data transfer and computation to accelerate inference.
|
||||
|
||||
> [!TIP]
|
||||
> It is possible to mix block and leaf-level offloading for different components in a pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
|
||||
model_id = "black-forest-labs/FLUX.1-dev"
|
||||
dtype = torch.bfloat16
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
|
||||
apply_group_offloading(
|
||||
pipe.transformer,
|
||||
offload_type="leaf_level",
|
||||
offload_device=torch.device("cpu"),
|
||||
onload_device=torch.device("cuda"),
|
||||
use_stream=True,
|
||||
)
|
||||
apply_group_offloading(
|
||||
pipe.text_encoder,
|
||||
offload_device=torch.device("cpu"),
|
||||
onload_device=torch.device("cuda"),
|
||||
offload_type="leaf_level",
|
||||
use_stream=True,
|
||||
)
|
||||
apply_group_offloading(
|
||||
pipe.text_encoder_2,
|
||||
offload_device=torch.device("cpu"),
|
||||
onload_device=torch.device("cuda"),
|
||||
offload_type="leaf_level",
|
||||
use_stream=True,
|
||||
)
|
||||
apply_group_offloading(
|
||||
pipe.vae,
|
||||
offload_device=torch.device("cpu"),
|
||||
onload_device=torch.device("cuda"),
|
||||
offload_type="leaf_level",
|
||||
use_stream=True,
|
||||
)
|
||||
|
||||
prompt="A cat wearing sunglasses and working as a lifeguard at pool."
|
||||
|
||||
generator = torch.Generator().manual_seed(181201)
|
||||
image = pipe(
|
||||
prompt,
|
||||
width=576,
|
||||
height=1024,
|
||||
num_inference_steps=30,
|
||||
generator=generator
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
### Running FP16 inference
|
||||
## Running FP16 inference
|
||||
|
||||
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
|
||||
|
||||
@@ -455,7 +389,7 @@ out = pipe(
|
||||
out.save("image.png")
|
||||
```
|
||||
|
||||
### Quantization
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
|
||||
@@ -49,8 +49,7 @@ The following models are available for the image-to-video pipeline:
|
||||
|
||||
| Model name | Description |
|
||||
|:---|:---|
|
||||
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
|
||||
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
|
||||
| [`https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
|
||||
|
||||
## Quantization
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
<!--
|
||||
Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
|
||||
Copyright 2024-2025 The HuggingFace Team. All rights reserved.
|
||||
<!--Copyright 2024 Marigold authors and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
@@ -12,120 +10,67 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Marigold Computer Vision
|
||||
# Marigold Pipelines for Computer Vision Tasks
|
||||
|
||||

|
||||
|
||||
Marigold was proposed in
|
||||
[Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation](https://huggingface.co/papers/2312.02145),
|
||||
a CVPR 2024 Oral paper by
|
||||
[Bingxin Ke](http://www.kebingxin.com/),
|
||||
[Anton Obukhov](https://www.obukhov.ai/),
|
||||
[Shengyu Huang](https://shengyuh.github.io/),
|
||||
[Nando Metzger](https://nandometzger.github.io/),
|
||||
[Rodrigo Caye Daudt](https://rcdaudt.github.io/), and
|
||||
[Konrad Schindler](https://scholar.google.com/citations?user=FZuNgqIAAAAJ&hl=en).
|
||||
The core idea is to **repurpose the generative prior of Text-to-Image Latent Diffusion Models (LDMs) for traditional
|
||||
computer vision tasks**.
|
||||
This approach was explored by fine-tuning Stable Diffusion for **Monocular Depth Estimation**, as demonstrated in the
|
||||
teaser above.
|
||||
Marigold was proposed in [Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation](https://huggingface.co/papers/2312.02145), a CVPR 2024 Oral paper by [Bingxin Ke](http://www.kebingxin.com/), [Anton Obukhov](https://www.obukhov.ai/), [Shengyu Huang](https://shengyuh.github.io/), [Nando Metzger](https://nandometzger.github.io/), [Rodrigo Caye Daudt](https://rcdaudt.github.io/), and [Konrad Schindler](https://scholar.google.com/citations?user=FZuNgqIAAAAJ&hl=en).
|
||||
The idea is to repurpose the rich generative prior of Text-to-Image Latent Diffusion Models (LDMs) for traditional computer vision tasks.
|
||||
Initially, this idea was explored to fine-tune Stable Diffusion for Monocular Depth Estimation, as shown in the teaser above.
|
||||
Later,
|
||||
- [Tianfu Wang](https://tianfwang.github.io/) trained the first Latent Consistency Model (LCM) of Marigold, which unlocked fast single-step inference;
|
||||
- [Kevin Qu](https://www.linkedin.com/in/kevin-qu-b3417621b/?locale=en_US) extended the approach to Surface Normals Estimation;
|
||||
- [Anton Obukhov](https://www.obukhov.ai/) contributed the pipelines and documentation into diffusers (enabled and supported by [YiYi Xu](https://yiyixuxu.github.io/) and [Sayak Paul](https://sayak.dev/)).
|
||||
|
||||
Marigold was later extended in the follow-up paper,
|
||||
[Marigold: Affordable Adaptation of Diffusion-Based Image Generators for Image Analysis](https://huggingface.co/papers/2312.02145),
|
||||
authored by
|
||||
[Bingxin Ke](http://www.kebingxin.com/),
|
||||
[Kevin Qu](https://www.linkedin.com/in/kevin-qu-b3417621b/?locale=en_US),
|
||||
[Tianfu Wang](https://tianfwang.github.io/),
|
||||
[Nando Metzger](https://nandometzger.github.io/),
|
||||
[Shengyu Huang](https://shengyuh.github.io/),
|
||||
[Bo Li](https://www.linkedin.com/in/bobboli0202/),
|
||||
[Anton Obukhov](https://www.obukhov.ai/), and
|
||||
[Konrad Schindler](https://scholar.google.com/citations?user=FZuNgqIAAAAJ&hl=en).
|
||||
This work expanded Marigold to support new modalities such as **Surface Normals** and **Intrinsic Image Decomposition**
|
||||
(IID), introduced a training protocol for **Latent Consistency Models** (LCM), and demonstrated **High-Resolution** (HR)
|
||||
processing capability.
|
||||
The abstract from the paper is:
|
||||
|
||||
<Tip>
|
||||
|
||||
The early Marigold models (`v1-0` and earlier) were optimized for best results with at least 10 inference steps.
|
||||
LCM models were later developed to enable high-quality inference in just 1 to 4 steps.
|
||||
Marigold models `v1-1` and later use the DDIM scheduler to achieve optimal
|
||||
results in as few as 1 to 4 steps.
|
||||
|
||||
</Tip>
|
||||
*Monocular depth estimation is a fundamental computer vision task. Recovering 3D depth from a single image is geometrically ill-posed and requires scene understanding, so it is not surprising that the rise of deep learning has led to a breakthrough. The impressive progress of monocular depth estimators has mirrored the growth in model capacity, from relatively modest CNNs to large Transformer architectures. Still, monocular depth estimators tend to struggle when presented with images with unfamiliar content and layout, since their knowledge of the visual world is restricted by the data seen during training, and challenged by zero-shot generalization to new domains. This motivates us to explore whether the extensive priors captured in recent generative diffusion models can enable better, more generalizable depth estimation. We introduce Marigold, a method for affine-invariant monocular depth estimation that is derived from Stable Diffusion and retains its rich prior knowledge. The estimator can be fine-tuned in a couple of days on a single GPU using only synthetic training data. It delivers state-of-the-art performance across a wide range of datasets, including over 20% performance gains in specific cases. Project page: https://marigoldmonodepth.github.io.*
|
||||
|
||||
## Available Pipelines
|
||||
|
||||
Each pipeline is tailored for a specific computer vision task, processing an input RGB image and generating a
|
||||
corresponding prediction.
|
||||
Currently, the following computer vision tasks are implemented:
|
||||
Each pipeline supports one Computer Vision task, which takes an input RGB image as input and produces a *prediction* of the modality of interest, such as a depth map of the input image.
|
||||
Currently, the following tasks are implemented:
|
||||
|
||||
| Pipeline | Predicted Modalities | Demos |
|
||||
|---------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------:|
|
||||
| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-lcm), [Slow Original Demo (DDIM)](https://huggingface.co/spaces/prs-eth/marigold) |
|
||||
| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-normals-lcm) |
|
||||
|
||||
| Pipeline | Recommended Model Checkpoints | Spaces (Interactive Apps) | Predicted Modalities |
|
||||
|---------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1) | [Depth Estimation](https://huggingface.co/spaces/prs-eth/marigold) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) |
|
||||
| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [prs-eth/marigold-normals-v1-1](https://huggingface.co/prs-eth/marigold-normals-v1-1) | [Surface Normals Estimation](https://huggingface.co/spaces/prs-eth/marigold-normals) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) |
|
||||
| [MarigoldIntrinsicsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py) | [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1),<br>[prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | [Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-iid) | [Albedo](https://en.wikipedia.org/wiki/Albedo), [Materials](https://www.n.aiq3d.com/wiki/roughnessmetalnessao-map), [Lighting](https://en.wikipedia.org/wiki/Diffuse_reflection) |
|
||||
|
||||
## Available Checkpoints
|
||||
|
||||
All original checkpoints are available under the [PRS-ETH](https://huggingface.co/prs-eth/) organization on Hugging Face.
|
||||
They are designed for use with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold), which can also be used to train
|
||||
new model checkpoints.
|
||||
The following is a summary of the recommended checkpoints, all of which produce reliable results with 1 to 4 steps.
|
||||
|
||||
| Checkpoint | Modality | Comment |
|
||||
|-----------------------------------------------------------------------------------------------------|--------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1) | Depth | Affine-invariant depth prediction assigns each pixel a value between 0 (near plane) and 1 (far plane), with both planes determined by the model during inference. |
|
||||
| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | The surface normals predictions are unit-length 3D vectors in the screen space camera, with values in the range from -1 to 1. |
|
||||
| [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1) | Intrinsics | InteriorVerse decomposition is comprised of Albedo and two BRDF material properties: Roughness and Metallicity. |
|
||||
| [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | Intrinsics | HyperSim decomposition of an image  \\(I\\)  is comprised of Albedo  \\(A\\), Diffuse shading  \\(S\\), and Non-diffuse residual  \\(R\\):  \\(I = A*S+R\\). |
|
||||
The original checkpoints can be found under the [PRS-ETH](https://huggingface.co/prs-eth/) Hugging Face organization.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff
|
||||
between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to
|
||||
efficiently load the same components into multiple pipelines.
|
||||
Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section
|
||||
[here](../../using-diffusers/svd#reduce-memory-usage).
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Marigold pipelines were designed and tested with the scheduler embedded in the model checkpoint.
|
||||
The optimal number of inference steps varies by scheduler, with no universal value that works best across all cases.
|
||||
To accommodate this, the `num_inference_steps` parameter in the pipeline's `__call__` method defaults to `None` (see the
|
||||
API reference).
|
||||
Unless set explicitly, it inherits the value from the `default_denoising_steps` field in the checkpoint configuration
|
||||
file (`model_index.json`).
|
||||
This ensures high-quality predictions when invoking the pipeline with only the `image` argument.
|
||||
Marigold pipelines were designed and tested only with `DDIMScheduler` and `LCMScheduler`.
|
||||
Depending on the scheduler, the number of inference steps required to get reliable predictions varies, and there is no universal value that works best across schedulers.
|
||||
Because of that, the default value of `num_inference_steps` in the `__call__` method of the pipeline is set to `None` (see the API reference).
|
||||
Unless set explicitly, its value will be taken from the checkpoint configuration `model_index.json`.
|
||||
This is done to ensure high-quality predictions when calling the pipeline with just the `image` argument.
|
||||
|
||||
</Tip>
|
||||
|
||||
See also Marigold [usage examples](../../using-diffusers/marigold_usage).
|
||||
|
||||
## Marigold Depth Prediction API
|
||||
See also Marigold [usage examples](marigold_usage).
|
||||
|
||||
## MarigoldDepthPipeline
|
||||
[[autodoc]] MarigoldDepthPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## MarigoldNormalsPipeline
|
||||
[[autodoc]] MarigoldNormalsPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## MarigoldDepthOutput
|
||||
[[autodoc]] pipelines.marigold.pipeline_marigold_depth.MarigoldDepthOutput
|
||||
|
||||
[[autodoc]] pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth
|
||||
|
||||
## Marigold Normals Estimation API
|
||||
[[autodoc]] MarigoldNormalsPipeline
|
||||
- __call__
|
||||
|
||||
[[autodoc]] pipelines.marigold.pipeline_marigold_normals.MarigoldNormalsOutput
|
||||
|
||||
[[autodoc]] pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals
|
||||
|
||||
## Marigold Intrinsic Image Decomposition API
|
||||
|
||||
[[autodoc]] MarigoldIntrinsicsPipeline
|
||||
- __call__
|
||||
|
||||
[[autodoc]] pipelines.marigold.pipeline_marigold_intrinsics.MarigoldIntrinsicsOutput
|
||||
|
||||
[[autodoc]] pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_intrinsics
|
||||
## MarigoldNormalsOutput
|
||||
[[autodoc]] pipelines.marigold.pipeline_marigold_normals.MarigoldNormalsOutput
|
||||
@@ -65,7 +65,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [Latte](latte) | text2image |
|
||||
| [LEDITS++](ledits_pp) | image editing |
|
||||
| [Lumina-T2X](lumina) | text2image |
|
||||
| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
|
||||
| [Marigold](marigold) | depth |
|
||||
| [MultiDiffusion](panorama) | text2image |
|
||||
| [MusicLDM](musicldm) | text2audio |
|
||||
| [PAG](pag) | text2image |
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# Wan
|
||||
|
||||
[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
|
||||
|
||||
<!-- TODO(aryan): update abstract once paper is out -->
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
Recommendations for inference:
|
||||
- VAE in `torch.float32` for better decoding quality.
|
||||
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
|
||||
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
|
||||
|
||||
### Using a custom scheduler
|
||||
|
||||
Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
|
||||
|
||||
```python
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, WanPipeline
|
||||
|
||||
scheduler_a = FlowMatchEulerDiscreteScheduler(shift=5.0)
|
||||
scheduler_b = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=4.0)
|
||||
|
||||
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler=<CUSTOM_SCHEDULER_HERE>)
|
||||
|
||||
# or,
|
||||
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
|
||||
```
|
||||
|
||||
### Using single file loading with Wan
|
||||
|
||||
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
|
||||
method.
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import WanPipeline, WanTransformer3DModel
|
||||
|
||||
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
|
||||
transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
|
||||
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
|
||||
```
|
||||
|
||||
## WanPipeline
|
||||
|
||||
[[autodoc]] WanPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanImageToVideoPipeline
|
||||
|
||||
[[autodoc]] WanImageToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
|
||||
@@ -1,19 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# CogVideoXDDIMScheduler
|
||||
|
||||
`CogVideoXDDIMScheduler` is based on [Denoising Diffusion Implicit Models](https://huggingface.co/papers/2010.02502), specifically for CogVideoX models.
|
||||
|
||||
## CogVideoXDDIMScheduler
|
||||
|
||||
[[autodoc]] CogVideoXDDIMScheduler
|
||||
@@ -1,19 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# CogVideoXDPMScheduler
|
||||
|
||||
`CogVideoXDPMScheduler` is based on [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models](https://huggingface.co/papers/2211.01095), specifically for CogVideoX models.
|
||||
|
||||
## CogVideoXDPMScheduler
|
||||
|
||||
[[autodoc]] CogVideoXDPMScheduler
|
||||
@@ -16,11 +16,6 @@ specific language governing permissions and limitations under the License.
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||
</a>
|
||||
|
||||
> [!TIP]
|
||||
> This document has now grown outdated given the emergence of existing evaluation frameworks for diffusion models for image generation. Please check
|
||||
> out works like [HEIM](https://crfm.stanford.edu/helm/heim/latest/), [T2I-Compbench](https://arxiv.org/abs/2307.06350),
|
||||
> [GenEval](https://arxiv.org/abs/2310.11513).
|
||||
|
||||
Evaluation of generative models like [Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion) is subjective in nature. But as practitioners and researchers, we often have to make careful choices amongst many different possibilities. So, when working with different generative models (like GANs, Diffusion, etc.), how do we choose one over the other?
|
||||
|
||||
Qualitative evaluation of such models can be error-prone and might incorrectly influence a decision.
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
# Hybrid Inference API Reference
|
||||
|
||||
## Remote Decode
|
||||
|
||||
[[autodoc]] utils.remote_utils.remote_decode
|
||||
@@ -1,54 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Hybrid Inference
|
||||
|
||||
**Empowering local AI builders with Hybrid Inference**
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae).
|
||||
> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
|
||||
## Why use Hybrid Inference?
|
||||
|
||||
Hybrid Inference offers a fast and simple way to offload local generation requirements.
|
||||
|
||||
- 🚀 **Reduced Requirements:** Access powerful models without expensive hardware.
|
||||
- 💎 **Without Compromise:** Achieve the highest quality without sacrificing performance.
|
||||
- 💰 **Cost Effective:** It's free! 🤑
|
||||
- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community.
|
||||
- 🔧 **Developer-Friendly:** Simple requests, fast responses.
|
||||
|
||||
---
|
||||
|
||||
## Available Models
|
||||
|
||||
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
|
||||
* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
|
||||
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
|
||||
|
||||
---
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
|
||||
## Contents
|
||||
|
||||
The documentation is organized into two sections:
|
||||
|
||||
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
|
||||
* **API Reference** Dive into task-specific settings and parameters.
|
||||
@@ -1,345 +0,0 @@
|
||||
# Getting Started: VAE Decode with Hybrid Inference
|
||||
|
||||
VAE decode is an essential component of diffusion models - turning latent representations into images or videos.
|
||||
|
||||
## Memory
|
||||
|
||||
These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs.
|
||||
|
||||
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality.
|
||||
|
||||
<details><summary>SD v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
|
||||
|
||||
</details>
|
||||
|
||||
## Available VAEs
|
||||
|
||||
| | **Endpoint** | **Model** |
|
||||
|:-:|:-----------:|:--------:|
|
||||
| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
|
||||
| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
|
||||
| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
|
||||
| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
## Code
|
||||
|
||||
> [!TIP]
|
||||
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
|
||||
|
||||
|
||||
A helper method simplifies interacting with Hybrid Inference.
|
||||
|
||||
```python
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
```
|
||||
|
||||
### Basic example
|
||||
|
||||
Here, we show how to use the remote VAE on random tensors.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/output.png"/>
|
||||
</figure>
|
||||
|
||||
Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
image = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
|
||||
height=1024,
|
||||
width=1024,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/flux_random_latent.png"/>
|
||||
</figure>
|
||||
|
||||
Finally, an example for HunyuanVideo.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
video = remote_decode(
|
||||
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16),
|
||||
output_type="mp4",
|
||||
)
|
||||
with open("video.mp4", "wb") as f:
|
||||
f.write(video)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video_1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
|
||||
### Generation
|
||||
|
||||
But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
|
||||
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
image.save("test.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test.jpg"/>
|
||||
</figure>
|
||||
|
||||
Here’s another example with Flux.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16,
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
|
||||
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
guidance_scale=0.0,
|
||||
num_inference_steps=4,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
height=1024,
|
||||
width=1024,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
image.save("test.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test_1.jpg"/>
|
||||
</figure>
|
||||
|
||||
Here’s an example with HunyuanVideo.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
|
||||
model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = HunyuanVideoPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
latent = pipe(
|
||||
prompt="A cat walks on the grass, realistic",
|
||||
height=320,
|
||||
width=512,
|
||||
num_frames=61,
|
||||
num_inference_steps=30,
|
||||
output_type="latent",
|
||||
).frames
|
||||
|
||||
video = remote_decode(
|
||||
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
output_type="mp4",
|
||||
)
|
||||
|
||||
if isinstance(video, bytes):
|
||||
with open("video.mp4", "wb") as f:
|
||||
f.write(video)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
|
||||
### Queueing
|
||||
|
||||
One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency.
|
||||
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
import queue
|
||||
import threading
|
||||
from IPython.display import display
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
def decode_worker(q: queue.Queue):
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is None:
|
||||
break
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=item,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
display(image)
|
||||
q.task_done()
|
||||
|
||||
q = queue.Queue()
|
||||
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
def decode(latent: torch.Tensor):
|
||||
q.put(latent)
|
||||
|
||||
prompts = [
|
||||
"Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious",
|
||||
"Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore",
|
||||
"Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.",
|
||||
"Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP",
|
||||
"A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting",
|
||||
"Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,",
|
||||
]
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"Lykon/dreamshaper-8",
|
||||
torch_dtype=torch.float16,
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
_ = pipe(
|
||||
prompt=prompts[0],
|
||||
output_type="latent",
|
||||
)
|
||||
|
||||
for prompt in prompts:
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
output_type="latent",
|
||||
).images
|
||||
decode(latent)
|
||||
|
||||
q.put(None)
|
||||
thread.join()
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/queue.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
@@ -157,84 +157,6 @@ pipeline(
|
||||
)
|
||||
```
|
||||
|
||||
## IP Adapter Cutoff
|
||||
|
||||
IP Adapter is an image prompt adapter that can be used for diffusion models without any changes to the underlying model. We can use the IP Adapter Cutoff Callback to disable the IP Adapter after a certain number of steps. To set up the callback, you need to specify the number of denoising steps after which the callback comes into effect. You can do so by using either one of these two arguments:
|
||||
|
||||
- `cutoff_step_ratio`: Float number with the ratio of the steps.
|
||||
- `cutoff_step_index`: Integer number with the exact number of the step.
|
||||
|
||||
We need to download the diffusion model and load the ip_adapter for it as follows:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
pipeline.set_ip_adapter_scale(0.6)
|
||||
```
|
||||
The setup for the callback should look something like this:
|
||||
|
||||
```py
|
||||
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.callbacks import IPAdapterScaleCutoffCallback
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
|
||||
pipeline.load_ip_adapter(
|
||||
"h94/IP-Adapter",
|
||||
subfolder="sdxl_models",
|
||||
weight_name="ip-adapter_sdxl.bin"
|
||||
)
|
||||
|
||||
pipeline.set_ip_adapter_scale(0.6)
|
||||
|
||||
|
||||
callback = IPAdapterScaleCutoffCallback(
|
||||
cutoff_step_ratio=None,
|
||||
cutoff_step_index=5
|
||||
)
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png"
|
||||
)
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(2628670641)
|
||||
|
||||
images = pipeline(
|
||||
prompt="a tiger sitting in a chair drinking orange juice",
|
||||
ip_adapter_image=image,
|
||||
negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
|
||||
generator=generator,
|
||||
num_inference_steps=50,
|
||||
callback_on_step_end=callback,
|
||||
).images
|
||||
|
||||
images[0].save("custom_callback_img.png")
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/without_callback.png" alt="generated image of a tiger sitting in a chair drinking orange juice" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">without IPAdapterScaleCutoffCallback</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_callback2.png" alt="generated image of a tiger sitting in a chair drinking orange juice with ip adapter callback" />
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">with IPAdapterScaleCutoffCallback</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
## Display image after each generation step
|
||||
|
||||
> [!TIP]
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
<!--
|
||||
Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
|
||||
Copyright 2024-2025 The HuggingFace Team. All rights reserved.
|
||||
<!--Copyright 2024 Marigold authors and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
@@ -12,38 +10,31 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Marigold Computer Vision
|
||||
# Marigold Pipelines for Computer Vision Tasks
|
||||
|
||||
**Marigold** is a diffusion-based [method](https://huggingface.co/papers/2312.02145) and a collection of [pipelines](../api/pipelines/marigold) designed for
|
||||
dense computer vision tasks, including **monocular depth prediction**, **surface normals estimation**, and **intrinsic
|
||||
image decomposition**.
|
||||
[Marigold](../api/pipelines/marigold) is a novel diffusion-based dense prediction approach, and a set of pipelines for various computer vision tasks, such as monocular depth estimation.
|
||||
|
||||
This guide will walk you through using Marigold to generate fast and high-quality predictions for images and videos.
|
||||
This guide will show you how to use Marigold to obtain fast and high-quality predictions for images and videos.
|
||||
|
||||
Each pipeline is tailored for a specific computer vision task, processing an input RGB image and generating a
|
||||
corresponding prediction.
|
||||
Currently, the following computer vision tasks are implemented:
|
||||
Each pipeline supports one Computer Vision task, which takes an input RGB image as input and produces a *prediction* of the modality of interest, such as a depth map of the input image.
|
||||
Currently, the following tasks are implemented:
|
||||
|
||||
| Pipeline | Recommended Model Checkpoints | Spaces (Interactive Apps) | Predicted Modalities |
|
||||
|---------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1) | [Depth Estimation](https://huggingface.co/spaces/prs-eth/marigold) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) |
|
||||
| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [prs-eth/marigold-normals-v1-1](https://huggingface.co/prs-eth/marigold-normals-v1-1) | [Surface Normals Estimation](https://huggingface.co/spaces/prs-eth/marigold-normals) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) |
|
||||
| [MarigoldIntrinsicsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py) | [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1),<br>[prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | [Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-iid) | [Albedo](https://en.wikipedia.org/wiki/Albedo), [Materials](https://www.n.aiq3d.com/wiki/roughnessmetalnessao-map), [Lighting](https://en.wikipedia.org/wiki/Diffuse_reflection) |
|
||||
| Pipeline | Predicted Modalities | Demos |
|
||||
|---------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------:|
|
||||
| [MarigoldDepthPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py) | [Depth](https://en.wikipedia.org/wiki/Depth_map), [Disparity](https://en.wikipedia.org/wiki/Binocular_disparity) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-lcm), [Slow Original Demo (DDIM)](https://huggingface.co/spaces/prs-eth/marigold) |
|
||||
| [MarigoldNormalsPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py) | [Surface normals](https://en.wikipedia.org/wiki/Normal_mapping) | [Fast Demo (LCM)](https://huggingface.co/spaces/prs-eth/marigold-normals-lcm) |
|
||||
|
||||
All original checkpoints are available under the [PRS-ETH](https://huggingface.co/prs-eth/) organization on Hugging Face.
|
||||
They are designed for use with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold), which can also be used to train
|
||||
new model checkpoints.
|
||||
The following is a summary of the recommended checkpoints, all of which produce reliable results with 1 to 4 steps.
|
||||
The original checkpoints can be found under the [PRS-ETH](https://huggingface.co/prs-eth/) Hugging Face organization.
|
||||
These checkpoints are meant to work with diffusers pipelines and the [original codebase](https://github.com/prs-eth/marigold).
|
||||
The original code can also be used to train new checkpoints.
|
||||
|
||||
| Checkpoint | Modality | Comment |
|
||||
|-----------------------------------------------------------------------------------------------------|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1) | Depth | Affine-invariant depth prediction assigns each pixel a value between 0 (near plane) and 1 (far plane), with both planes determined by the model during inference. |
|
||||
| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | The surface normals predictions are unit-length 3D vectors in the screen space camera, with values in the range from -1 to 1. |
|
||||
| [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1) | Intrinsics | InteriorVerse decomposition is comprised of Albedo and two BRDF material properties: Roughness and Metallicity. |
|
||||
| [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | Intrinsics | HyperSim decomposition of an image \\(I\\) is comprised of Albedo \\(A\\), Diffuse shading \\(S\\), and Non-diffuse residual \\(R\\): \\(I = A*S+R\\). |
|
||||
|
||||
The examples below are mostly given for depth prediction, but they can be universally applied to other supported
|
||||
modalities.
|
||||
| Checkpoint | Modality | Comment |
|
||||
|-----------------------------------------------------------------------------------------------|----------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [prs-eth/marigold-v1-0](https://huggingface.co/prs-eth/marigold-v1-0) | Depth | The first Marigold Depth checkpoint, which predicts *affine-invariant depth* maps. The performance of this checkpoint in benchmarks was studied in the original [paper](https://huggingface.co/papers/2312.02145). Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. Affine-invariant depth prediction has a range of values in each pixel between 0 (near plane) and 1 (far plane); both planes are chosen by the model as part of the inference process. See the `MarigoldImageProcessor` reference for visualization utilities. |
|
||||
| [prs-eth/marigold-depth-lcm-v1-0](https://huggingface.co/prs-eth/marigold-depth-lcm-v1-0) | Depth | The fast Marigold Depth checkpoint, fine-tuned from `prs-eth/marigold-v1-0`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. |
|
||||
| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | A preview checkpoint for the Marigold Normals pipeline. Designed to be used with the `DDIMScheduler` at inference, it requires at least 10 steps to get reliable predictions. The surface normals predictions are unit-length 3D vectors with values in the range from -1 to 1. *This checkpoint will be phased out after the release of `v1-0` version.* |
|
||||
| [prs-eth/marigold-normals-lcm-v0-1](https://huggingface.co/prs-eth/marigold-normals-lcm-v0-1) | Normals | The fast Marigold Normals checkpoint, fine-tuned from `prs-eth/marigold-normals-v0-1`. Designed to be used with the `LCMScheduler` at inference, it requires as little as 1 step to get reliable predictions. The prediction reliability saturates at 4 steps and declines after that. *This checkpoint will be phased out after the release of `v1-0` version.* |
|
||||
The examples below are mostly given for depth prediction, but they can be universally applied with other supported modalities.
|
||||
We showcase the predictions using the same input image of Albert Einstein generated by Midjourney.
|
||||
This makes it easier to compare visualizations of the predictions across various modalities and checkpoints.
|
||||
|
||||
@@ -56,21 +47,19 @@ This makes it easier to compare visualizations of the predictions across various
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Depth Prediction
|
||||
### Depth Prediction Quick Start
|
||||
|
||||
To get a depth prediction, load the `prs-eth/marigold-depth-v1-1` checkpoint into [`MarigoldDepthPipeline`],
|
||||
put the image through the pipeline, and save the predictions:
|
||||
To get the first depth prediction, load `prs-eth/marigold-depth-lcm-v1-0` checkpoint into `MarigoldDepthPipeline` pipeline, put the image through the pipeline, and save the predictions:
|
||||
|
||||
```python
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
|
||||
"prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
|
||||
"prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
depth = pipe(image)
|
||||
|
||||
vis = pipe.image_processor.visualize_depth(depth.prediction)
|
||||
@@ -80,13 +69,10 @@ depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction)
|
||||
depth_16bit[0].save("einstein_depth_16bit.png")
|
||||
```
|
||||
|
||||
The [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth`] function applies one of
|
||||
[matplotlib's colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html) (`Spectral` by default) to map the predicted pixel values from a single-channel `[0, 1]`
|
||||
depth range into an RGB image.
|
||||
With the `Spectral` colormap, pixels with near depth are painted red, and far pixels are blue.
|
||||
The visualization function for depth [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_depth`] applies one of [matplotlib's colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html) (`Spectral` by default) to map the predicted pixel values from a single-channel `[0, 1]` depth range into an RGB image.
|
||||
With the `Spectral` colormap, pixels with near depth are painted red, and far pixels are assigned blue color.
|
||||
The 16-bit PNG file stores the single channel values mapped linearly from the `[0, 1]` range into `[0, 65535]`.
|
||||
Below are the raw and the visualized predictions. The darker and closer areas (mustache) are easier to distinguish in
|
||||
the visualization.
|
||||
Below are the raw and the visualized predictions; as can be seen, dark areas (mustache) are easier to distinguish in the visualization:
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
@@ -103,33 +89,28 @@ the visualization.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Surface Normals Estimation
|
||||
### Surface Normals Prediction Quick Start
|
||||
|
||||
Load the `prs-eth/marigold-normals-v1-1` checkpoint into [`MarigoldNormalsPipeline`], put the image through the
|
||||
pipeline, and save the predictions:
|
||||
Load `prs-eth/marigold-normals-lcm-v0-1` checkpoint into `MarigoldNormalsPipeline` pipeline, put the image through the pipeline, and save the predictions:
|
||||
|
||||
```python
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(
|
||||
"prs-eth/marigold-normals-v1-1", variant="fp16", torch_dtype=torch.float16
|
||||
"prs-eth/marigold-normals-lcm-v0-1", variant="fp16", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
normals = pipe(image)
|
||||
|
||||
vis = pipe.image_processor.visualize_normals(normals.prediction)
|
||||
vis[0].save("einstein_normals.png")
|
||||
```
|
||||
|
||||
The [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals`] maps the three-dimensional
|
||||
prediction with pixel values in the range `[-1, 1]` into an RGB image.
|
||||
The visualization function supports flipping surface normals axes to make the visualization compatible with other
|
||||
choices of the frame of reference.
|
||||
Conceptually, each pixel is painted according to the surface normal vector in the frame of reference, where `X` axis
|
||||
points right, `Y` axis points up, and `Z` axis points at the viewer.
|
||||
The visualization function for normals [`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_normals`] maps the three-dimensional prediction with pixel values in the range `[-1, 1]` into an RGB image.
|
||||
The visualization function supports flipping surface normals axes to make the visualization compatible with other choices of the frame of reference.
|
||||
Conceptually, each pixel is painted according to the surface normal vector in the frame of reference, where `X` axis points right, `Y` axis points up, and `Z` axis points at the viewer.
|
||||
Below is the visualized prediction:
|
||||
|
||||
<div class="flex gap-4" style="justify-content: center; width: 100%;">
|
||||
@@ -141,121 +122,25 @@ Below is the visualized prediction:
|
||||
</div>
|
||||
</div>
|
||||
|
||||
In this example, the nose tip almost certainly has a point on the surface, in which the surface normal vector points
|
||||
straight at the viewer, meaning that its coordinates are `[0, 0, 1]`.
|
||||
In this example, the nose tip almost certainly has a point on the surface, in which the surface normal vector points straight at the viewer, meaning that its coordinates are `[0, 0, 1]`.
|
||||
This vector maps to the RGB `[128, 128, 255]`, which corresponds to the violet-blue color.
|
||||
Similarly, a surface normal on the cheek in the right part of the image has a large `X` component, which increases the
|
||||
red hue.
|
||||
Similarly, a surface normal on the cheek in the right part of the image has a large `X` component, which increases the red hue.
|
||||
Points on the shoulders pointing up with a large `Y` promote green color.
|
||||
|
||||
## Intrinsic Image Decomposition
|
||||
### Speeding up inference
|
||||
|
||||
Marigold provides two models for Intrinsic Image Decomposition (IID): "Appearance" and "Lighting".
|
||||
Each model produces Albedo maps, derived from InteriorVerse and Hypersim annotations, respectively.
|
||||
|
||||
- The "Appearance" model also estimates Material properties: Roughness and Metallicity.
|
||||
- The "Lighting" model generates Diffuse Shading and Non-diffuse Residual.
|
||||
|
||||
Here is the sample code saving predictions made by the "Appearance" model:
|
||||
|
||||
```python
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
pipe = diffusers.MarigoldIntrinsicsPipeline.from_pretrained(
|
||||
"prs-eth/marigold-iid-appearance-v1-1", variant="fp16", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
intrinsics = pipe(image)
|
||||
|
||||
vis = pipe.image_processor.visualize_intrinsics(intrinsics.prediction, pipe.target_properties)
|
||||
vis[0]["albedo"].save("einstein_albedo.png")
|
||||
vis[0]["roughness"].save("einstein_roughness.png")
|
||||
vis[0]["metallicity"].save("einstein_metallicity.png")
|
||||
```
|
||||
|
||||
Another example demonstrating the predictions made by the "Lighting" model:
|
||||
|
||||
```python
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
pipe = diffusers.MarigoldIntrinsicsPipeline.from_pretrained(
|
||||
"prs-eth/marigold-iid-lighting-v1-1", variant="fp16", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
intrinsics = pipe(image)
|
||||
|
||||
vis = pipe.image_processor.visualize_intrinsics(intrinsics.prediction, pipe.target_properties)
|
||||
vis[0]["albedo"].save("einstein_albedo.png")
|
||||
vis[0]["shading"].save("einstein_shading.png")
|
||||
vis[0]["residual"].save("einstein_residual.png")
|
||||
```
|
||||
|
||||
Both models share the same pipeline while supporting different decomposition types.
|
||||
The exact decomposition parameterization (e.g., sRGB vs. linear space) is stored in the
|
||||
`pipe.target_properties` dictionary, which is passed into the
|
||||
[`~pipelines.marigold.marigold_image_processing.MarigoldImageProcessor.visualize_intrinsics`] function.
|
||||
|
||||
Below are some examples showcasing the predicted decomposition outputs.
|
||||
All modalities can be inspected in the
|
||||
[Intrinsic Image Decomposition](https://huggingface.co/spaces/prs-eth/marigold-iid) Space.
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/8c7986eaaab5eb9604eb88336311f46a7b0ff5ab/marigold/marigold_einstein_albedo.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Predicted albedo ("Appearance" model)
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/8c7986eaaab5eb9604eb88336311f46a7b0ff5ab/marigold/marigold_einstein_diffuse.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Predicted diffuse shading ("Lighting" model)
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Speeding up inference
|
||||
|
||||
The above quick start snippets are already optimized for quality and speed, loading the checkpoint, utilizing the
|
||||
`fp16` variant of weights and computation, and performing the default number (4) of denoising diffusion steps.
|
||||
The first step to accelerate inference, at the expense of prediction quality, is to reduce the denoising diffusion
|
||||
steps to the minimum:
|
||||
The above quick start snippets are already optimized for speed: they load the LCM checkpoint, use the `fp16` variant of weights and computation, and perform just one denoising diffusion step.
|
||||
The `pipe(image)` call completes in 280ms on RTX 3090 GPU.
|
||||
Internally, the input image is encoded with the Stable Diffusion VAE encoder, then the U-Net performs one denoising step, and finally, the prediction latent is decoded with the VAE decoder into pixel space.
|
||||
In this case, two out of three module calls are dedicated to converting between pixel and latent space of LDM.
|
||||
Because Marigold's latent space is compatible with the base Stable Diffusion, it is possible to speed up the pipeline call by more than 3x (85ms on RTX 3090) by using a [lightweight replacement of the SD VAE](../api/models/autoencoder_tiny):
|
||||
|
||||
```diff
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
|
||||
"prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
- depth = pipe(image)
|
||||
+ depth = pipe(image, num_inference_steps=1)
|
||||
```
|
||||
|
||||
With this change, the `pipe` call completes in 280ms on RTX 3090 GPU.
|
||||
Internally, the input image is first encoded using the Stable Diffusion VAE encoder, followed by a single denoising
|
||||
step performed by the U-Net.
|
||||
Finally, the prediction latent is decoded with the VAE decoder into pixel space.
|
||||
In this setup, two out of three module calls are dedicated to converting between the pixel and latent spaces of the LDM.
|
||||
Since Marigold's latent space is compatible with Stable Diffusion 2.0, inference can be accelerated by more than 3x,
|
||||
reducing the call time to 85ms on an RTX 3090, by using a [lightweight replacement of the SD VAE](../api/models/autoencoder_tiny).
|
||||
Note that using a lightweight VAE may slightly reduce the visual quality of the predictions.
|
||||
|
||||
```diff
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
|
||||
"prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
|
||||
"prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
+ pipe.vae = diffusers.AutoencoderTiny.from_pretrained(
|
||||
@@ -263,77 +148,78 @@ Note that using a lightweight VAE may slightly reduce the visual quality of the
|
||||
+ ).cuda()
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
depth = pipe(image, num_inference_steps=1)
|
||||
depth = pipe(image)
|
||||
```
|
||||
|
||||
So far, we have optimized the number of diffusion steps and model components. Self-attention operations account for a
|
||||
significant portion of computations.
|
||||
Speeding them up can be achieved by using a more efficient attention processor:
|
||||
As suggested in [Optimizations](../optimization/torch2.0#torch.compile), adding `torch.compile` may squeeze extra performance depending on the target hardware:
|
||||
|
||||
```diff
|
||||
import diffusers
|
||||
import torch
|
||||
+ from diffusers.models.attention_processor import AttnProcessor2_0
|
||||
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
|
||||
"prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
|
||||
"prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
+ pipe.vae.set_attn_processor(AttnProcessor2_0())
|
||||
+ pipe.unet.set_attn_processor(AttnProcessor2_0())
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
depth = pipe(image, num_inference_steps=1)
|
||||
```
|
||||
|
||||
Finally, as suggested in [Optimizations](../optimization/torch2.0#torch.compile), enabling `torch.compile` can further enhance performance depending on
|
||||
the target hardware.
|
||||
However, compilation incurs a significant overhead during the first pipeline invocation, making it beneficial only when
|
||||
the same pipeline instance is called repeatedly, such as within a loop.
|
||||
|
||||
```diff
|
||||
import diffusers
|
||||
import torch
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0
|
||||
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
|
||||
"prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
pipe.vae.set_attn_processor(AttnProcessor2_0())
|
||||
pipe.unet.set_attn_processor(AttnProcessor2_0())
|
||||
|
||||
+ pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
|
||||
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
depth = pipe(image, num_inference_steps=1)
|
||||
depth = pipe(image)
|
||||
```
|
||||
|
||||
## Qualitative Comparison with Depth Anything
|
||||
|
||||
With the above speed optimizations, Marigold delivers predictions with more details and faster than [Depth Anything](https://huggingface.co/docs/transformers/main/en/model_doc/depth_anything) with the largest checkpoint [LiheYoung/depth-anything-large-hf](https://huggingface.co/LiheYoung/depth-anything-large-hf):
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_lcm_depth.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Marigold LCM fp16 with Tiny AutoEncoder
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/einstein_depthanything_large.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Depth Anything Large
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Maximizing Precision and Ensembling
|
||||
|
||||
Marigold pipelines have a built-in ensembling mechanism combining multiple predictions from different random latents.
|
||||
This is a brute-force way of improving the precision of predictions, capitalizing on the generative nature of diffusion.
|
||||
The ensembling path is activated automatically when the `ensemble_size` argument is set greater or equal than `3`.
|
||||
The ensembling path is activated automatically when the `ensemble_size` argument is set greater than `1`.
|
||||
When aiming for maximum precision, it makes sense to adjust `num_inference_steps` simultaneously with `ensemble_size`.
|
||||
The recommended values vary across checkpoints but primarily depend on the scheduler type.
|
||||
The effect of ensembling is particularly well-seen with surface normals:
|
||||
|
||||
```diff
|
||||
import diffusers
|
||||
```python
|
||||
import diffusers
|
||||
|
||||
pipe = diffusers.MarigoldNormalsPipeline.from_pretrained("prs-eth/marigold-normals-v1-1").to("cuda")
|
||||
model_path = "prs-eth/marigold-normals-v1-0"
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
model_paper_kwargs = {
|
||||
diffusers.schedulers.DDIMScheduler: {
|
||||
"num_inference_steps": 10,
|
||||
"ensemble_size": 10,
|
||||
},
|
||||
diffusers.schedulers.LCMScheduler: {
|
||||
"num_inference_steps": 4,
|
||||
"ensemble_size": 5,
|
||||
},
|
||||
}
|
||||
|
||||
- depth = pipe(image)
|
||||
+ depth = pipe(image, num_inference_steps=10, ensemble_size=5)
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
vis = pipe.image_processor.visualize_normals(depth.prediction)
|
||||
vis[0].save("einstein_normals.png")
|
||||
pipe = diffusers.MarigoldNormalsPipeline.from_pretrained(model_path).to("cuda")
|
||||
pipe_kwargs = model_paper_kwargs[type(pipe.scheduler)]
|
||||
|
||||
depth = pipe(image, **pipe_kwargs)
|
||||
|
||||
vis = pipe.image_processor.visualize_normals(depth.prediction)
|
||||
vis[0].save("einstein_normals.png")
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
@@ -351,16 +237,93 @@ The effect of ensembling is particularly well-seen with surface normals:
|
||||
</div>
|
||||
</div>
|
||||
|
||||
As can be seen, all areas with fine-grained structurers, such as hair, got more conservative and on average more
|
||||
correct predictions.
|
||||
As can be seen, all areas with fine-grained structurers, such as hair, got more conservative and on average more correct predictions.
|
||||
Such a result is more suitable for precision-sensitive downstream tasks, such as 3D reconstruction.
|
||||
|
||||
## Quantitative Evaluation
|
||||
|
||||
To evaluate Marigold quantitatively in standard leaderboards and benchmarks (such as NYU, KITTI, and other datasets), follow the evaluation protocol outlined in the paper: load the full precision fp32 model and use appropriate values for `num_inference_steps` and `ensemble_size`.
|
||||
Optionally seed randomness to ensure reproducibility. Maximizing `batch_size` will deliver maximum device utilization.
|
||||
|
||||
```python
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
device = "cuda"
|
||||
seed = 2024
|
||||
model_path = "prs-eth/marigold-v1-0"
|
||||
|
||||
model_paper_kwargs = {
|
||||
diffusers.schedulers.DDIMScheduler: {
|
||||
"num_inference_steps": 50,
|
||||
"ensemble_size": 10,
|
||||
},
|
||||
diffusers.schedulers.LCMScheduler: {
|
||||
"num_inference_steps": 4,
|
||||
"ensemble_size": 10,
|
||||
},
|
||||
}
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(model_path).to(device)
|
||||
pipe_kwargs = model_paper_kwargs[type(pipe.scheduler)]
|
||||
|
||||
depth = pipe(image, generator=generator, **pipe_kwargs)
|
||||
|
||||
# evaluate metrics
|
||||
```
|
||||
|
||||
## Using Predictive Uncertainty
|
||||
|
||||
The ensembling mechanism built into Marigold pipelines combines multiple predictions obtained from different random latents.
|
||||
As a side effect, it can be used to quantify epistemic (model) uncertainty; simply specify `ensemble_size` greater than 1 and set `output_uncertainty=True`.
|
||||
The resulting uncertainty will be available in the `uncertainty` field of the output.
|
||||
It can be visualized as follows:
|
||||
|
||||
```python
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
|
||||
"prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
depth = pipe(
|
||||
image,
|
||||
ensemble_size=10, # any number greater than 1; higher values yield higher precision
|
||||
output_uncertainty=True,
|
||||
)
|
||||
|
||||
uncertainty = pipe.image_processor.visualize_uncertainty(depth.uncertainty)
|
||||
uncertainty[0].save("einstein_depth_uncertainty.png")
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_depth_uncertainty.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Depth uncertainty
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_normals_uncertainty.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Surface normals uncertainty
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
The interpretation of uncertainty is easy: higher values (white) correspond to pixels, where the model struggles to make consistent predictions.
|
||||
Evidently, the depth model is the least confident around edges with discontinuity, where the object depth changes drastically.
|
||||
The surface normals model is the least confident in fine-grained structures, such as hair, and dark areas, such as the collar.
|
||||
|
||||
## Frame-by-frame Video Processing with Temporal Consistency
|
||||
|
||||
Due to Marigold's generative nature, each prediction is unique and defined by the random noise sampled for the latent
|
||||
initialization.
|
||||
This becomes an obvious drawback compared to traditional end-to-end dense regression networks, as exemplified in the
|
||||
following videos:
|
||||
Due to Marigold's generative nature, each prediction is unique and defined by the random noise sampled for the latent initialization.
|
||||
This becomes an obvious drawback compared to traditional end-to-end dense regression networks, as exemplified in the following videos:
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div style="flex: 1 1 50%; max-width: 50%;">
|
||||
@@ -373,32 +336,26 @@ following videos:
|
||||
</div>
|
||||
</div>
|
||||
|
||||
To address this issue, it is possible to pass `latents` argument to the pipelines, which defines the starting point of
|
||||
diffusion.
|
||||
Empirically, we found that a convex combination of the very same starting point noise latent and the latent
|
||||
corresponding to the previous frame prediction give sufficiently smooth results, as implemented in the snippet below:
|
||||
To address this issue, it is possible to pass `latents` argument to the pipelines, which defines the starting point of diffusion.
|
||||
Empirically, we found that a convex combination of the very same starting point noise latent and the latent corresponding to the previous frame prediction give sufficiently smooth results, as implemented in the snippet below:
|
||||
|
||||
```python
|
||||
import imageio
|
||||
import diffusers
|
||||
import torch
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
device = "cuda"
|
||||
path_in = "https://huggingface.co/spaces/prs-eth/marigold-lcm/resolve/c7adb5427947d2680944f898cd91d386bf0d4924/files/video/obama.mp4"
|
||||
path_in = "obama.mp4"
|
||||
path_out = "obama_depth.gif"
|
||||
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
|
||||
"prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
|
||||
"prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
|
||||
).to(device)
|
||||
pipe.vae = diffusers.AutoencoderTiny.from_pretrained(
|
||||
"madebyollin/taesd", torch_dtype=torch.float16
|
||||
).to(device)
|
||||
pipe.unet.set_attn_processor(AttnProcessor2_0())
|
||||
pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
with imageio.get_reader(path_in) as reader:
|
||||
@@ -416,11 +373,7 @@ with imageio.get_reader(path_in) as reader:
|
||||
latents = 0.9 * latents + 0.1 * last_frame_latent
|
||||
|
||||
depth = pipe(
|
||||
frame,
|
||||
num_inference_steps=1,
|
||||
match_input_resolution=False,
|
||||
latents=latents,
|
||||
output_latent=True,
|
||||
frame, match_input_resolution=False, latents=latents, output_latent=True
|
||||
)
|
||||
last_frame_latent = depth.latent
|
||||
out.append(pipe.image_processor.visualize_depth(depth.prediction)[0])
|
||||
@@ -429,8 +382,7 @@ with imageio.get_reader(path_in) as reader:
|
||||
```
|
||||
|
||||
Here, the diffusion process starts from the given computed latent.
|
||||
The pipeline sets `output_latent=True` to access `out.latent` and computes its contribution to the next frame's latent
|
||||
initialization.
|
||||
The pipeline sets `output_latent=True` to access `out.latent` and computes its contribution to the next frame's latent initialization.
|
||||
The result is much more stable now:
|
||||
|
||||
<div class="flex gap-4">
|
||||
@@ -462,7 +414,7 @@ image = diffusers.utils.load_image(
|
||||
)
|
||||
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
|
||||
"prs-eth/marigold-depth-v1-1", torch_dtype=torch.float16, variant="fp16"
|
||||
"prs-eth/marigold-depth-lcm-v1-0", torch_dtype=torch.float16, variant="fp16"
|
||||
).to(device)
|
||||
|
||||
depth_image = pipe(image, generator=generator).prediction
|
||||
@@ -511,95 +463,4 @@ controlnet_out[0].save("motorcycle_controlnet_out.png")
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Quantitative Evaluation
|
||||
|
||||
To evaluate Marigold quantitatively in standard leaderboards and benchmarks (such as NYU, KITTI, and other datasets),
|
||||
follow the evaluation protocol outlined in the paper: load the full precision fp32 model and use appropriate values
|
||||
for `num_inference_steps` and `ensemble_size`.
|
||||
Optionally seed randomness to ensure reproducibility.
|
||||
Maximizing `batch_size` will deliver maximum device utilization.
|
||||
|
||||
```python
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
device = "cuda"
|
||||
seed = 2024
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained("prs-eth/marigold-depth-v1-1").to(device)
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
depth = pipe(
|
||||
image,
|
||||
num_inference_steps=4, # set according to the evaluation protocol from the paper
|
||||
ensemble_size=10, # set according to the evaluation protocol from the paper
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
# evaluate metrics
|
||||
```
|
||||
|
||||
## Using Predictive Uncertainty
|
||||
|
||||
The ensembling mechanism built into Marigold pipelines combines multiple predictions obtained from different random
|
||||
latents.
|
||||
As a side effect, it can be used to quantify epistemic (model) uncertainty; simply specify `ensemble_size` greater
|
||||
or equal than 3 and set `output_uncertainty=True`.
|
||||
The resulting uncertainty will be available in the `uncertainty` field of the output.
|
||||
It can be visualized as follows:
|
||||
|
||||
```python
|
||||
import diffusers
|
||||
import torch
|
||||
|
||||
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
|
||||
"prs-eth/marigold-depth-v1-1", variant="fp16", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = diffusers.utils.load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
|
||||
|
||||
depth = pipe(
|
||||
image,
|
||||
ensemble_size=10, # any number >= 3
|
||||
output_uncertainty=True,
|
||||
)
|
||||
|
||||
uncertainty = pipe.image_processor.visualize_uncertainty(depth.uncertainty)
|
||||
uncertainty[0].save("einstein_depth_uncertainty.png")
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div style="flex: 1 1 33%; max-width: 33%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_depth_uncertainty.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Depth uncertainty
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 33%; max-width: 33%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/marigold/marigold_einstein_normals_uncertainty.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Surface normals uncertainty
|
||||
</figcaption>
|
||||
</div>
|
||||
<div style="flex: 1 1 33%; max-width: 33%;">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/4f83035d84a24e5ec44fdda129b1d51eba12ce04/marigold/marigold_einstein_albedo_uncertainty.png"/>
|
||||
<figcaption class="mt-1 text-center text-sm text-gray-500">
|
||||
Albedo uncertainty
|
||||
</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
The interpretation of uncertainty is easy: higher values (white) correspond to pixels, where the model struggles to
|
||||
make consistent predictions.
|
||||
- The depth model exhibits the most uncertainty around discontinuities, where object depth changes abruptly.
|
||||
- The surface normals model is least confident in fine-grained structures like hair and in dark regions such as the
|
||||
collar area.
|
||||
- Albedo uncertainty is represented as an RGB image, as it captures uncertainty independently for each color channel,
|
||||
unlike depth and surface normals. It is also higher in shaded regions and at discontinuities.
|
||||
|
||||
## Conclusion
|
||||
|
||||
We hope Marigold proves valuable for your downstream tasks, whether as part of a broader generative workflow or for
|
||||
perception-based applications like 3D reconstruction.
|
||||
Hopefully, you will find Marigold useful for solving your downstream tasks, be it a part of a more broad generative workflow, or a perception task, such as 3D reconstruction.
|
||||
|
||||
@@ -215,7 +215,7 @@ image
|
||||
|
||||
Prompt weighting provides a way to emphasize or de-emphasize certain parts of a prompt, allowing for more control over the generated image. A prompt can include several concepts, which gets turned into contextualized text embeddings. The embeddings are used by the model to condition its cross-attention layers to generate an image (read the Stable Diffusion [blog post](https://huggingface.co/blog/stable_diffusion) to learn more about how it works).
|
||||
|
||||
Prompt weighting works by increasing or decreasing the scale of the text embedding vector that corresponds to its concept in the prompt because you may not necessarily want the model to focus on all concepts equally. The easiest way to prepare the prompt embeddings is to use [Stable Diffusion Long Prompt Weighted Embedding](https://github.com/xhinker/sd_embed) (sd_embed). Once you have the prompt-weighted embeddings, you can pass them to any pipeline that has a [prompt_embeds](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) (and optionally [negative_prompt_embeds](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.negative_prompt_embeds)) parameter, such as [`StableDiffusionPipeline`], [`StableDiffusionControlNetPipeline`], and [`StableDiffusionXLPipeline`].
|
||||
Prompt weighting works by increasing or decreasing the scale of the text embedding vector that corresponds to its concept in the prompt because you may not necessarily want the model to focus on all concepts equally. The easiest way to prepare the prompt-weighted embeddings is to use [Compel](https://github.com/damian0815/compel), a text prompt-weighting and blending library. Once you have the prompt-weighted embeddings, you can pass them to any pipeline that has a [`prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) (and optionally [`negative_prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.negative_prompt_embeds)) parameter, such as [`StableDiffusionPipeline`], [`StableDiffusionControlNetPipeline`], and [`StableDiffusionXLPipeline`].
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -223,99 +223,136 @@ If your favorite pipeline doesn't have a `prompt_embeds` parameter, please open
|
||||
|
||||
</Tip>
|
||||
|
||||
This guide will show you how to weight your prompts with sd_embed.
|
||||
This guide will show you how to weight and blend your prompts with Compel in 🤗 Diffusers.
|
||||
|
||||
Before you begin, make sure you have the latest version of sd_embed installed:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/xhinker/sd_embed.git@main
|
||||
```
|
||||
|
||||
For this example, let's use [`StableDiffusionXLPipeline`].
|
||||
Before you begin, make sure you have the latest version of Compel installed:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline, UniPCMultistepScheduler
|
||||
# uncomment to install in Colab
|
||||
#!pip install compel --upgrade
|
||||
```
|
||||
|
||||
For this guide, let's generate an image with the prompt `"a red cat playing with a ball"` using the [`StableDiffusionPipeline`]:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained("Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_safetensors=True)
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
To upweight or downweight a concept, surround the text with parentheses. More parentheses applies a heavier weight on the text. You can also append a numerical multiplier to the text to indicate how much you want to increase or decrease its weights by.
|
||||
prompt = "a red cat playing with a ball"
|
||||
|
||||
| format | multiplier |
|
||||
|---|---|
|
||||
| `(hippo)` | increase by 1.1x |
|
||||
| `((hippo))` | increase by 1.21x |
|
||||
| `(hippo:1.5)` | increase by 1.5x |
|
||||
| `(hippo:0.5)` | decrease by 4x |
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
|
||||
Create a prompt and use a combination of parentheses and numerical multipliers to upweight various text.
|
||||
|
||||
```py
|
||||
from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl
|
||||
|
||||
prompt = """A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
|
||||
This imaginative creature features the distinctive, bulky body of a hippo,
|
||||
but with a texture and appearance resembling a golden-brown, crispy waffle.
|
||||
The creature might have elements like waffle squares across its skin and a syrup-like sheen.
|
||||
It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
|
||||
possibly including oversized utensils or plates in the background.
|
||||
The image should evoke a sense of playful absurdity and culinary fantasy.
|
||||
"""
|
||||
|
||||
neg_prompt = """\
|
||||
skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
|
||||
(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
|
||||
extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
|
||||
(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
|
||||
bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
|
||||
(normal quality:2),lowres,((monochrome)),((grayscale))
|
||||
"""
|
||||
```
|
||||
|
||||
Use the `get_weighted_text_embeddings_sdxl` function to generate the prompt embeddings and the negative prompt embeddings. It'll also generated the pooled and negative pooled prompt embeddings since you're using the SDXL model.
|
||||
|
||||
> [!TIP]
|
||||
> You can safely ignore the error message below about the token index length exceeding the models maximum sequence length. All your tokens will be used in the embedding process.
|
||||
>
|
||||
> ```
|
||||
> Token indices sequence length is longer than the specified maximum sequence length for this model
|
||||
> ```
|
||||
|
||||
```py
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_neg_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds
|
||||
) = get_weighted_text_embeddings_sdxl(
|
||||
pipe,
|
||||
prompt=prompt,
|
||||
neg_prompt=neg_prompt
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=prompt_neg_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
num_inference_steps=30,
|
||||
height=1024,
|
||||
width=1024 + 512,
|
||||
guidance_scale=4.0,
|
||||
generator=torch.Generator("cuda").manual_seed(2)
|
||||
).images[0]
|
||||
image = pipe(prompt, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd_embed_sdxl.png"/>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_0.png"/>
|
||||
</div>
|
||||
|
||||
> [!TIP]
|
||||
> Refer to the [sd_embed](https://github.com/xhinker/sd_embed) repository for additional details about long prompt weighting for FLUX.1, Stable Cascade, and Stable Diffusion 1.5.
|
||||
### Weighting
|
||||
|
||||
You'll notice there is no "ball" in the image! Let's use compel to upweight the concept of "ball" in the prompt. Create a [`Compel`](https://github.com/damian0815/compel/blob/main/doc/compel.md#compel-objects) object, and pass it a tokenizer and text encoder:
|
||||
|
||||
```py
|
||||
from compel import Compel
|
||||
|
||||
compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
|
||||
```
|
||||
|
||||
compel uses `+` or `-` to increase or decrease the weight of a word in the prompt. To increase the weight of "ball":
|
||||
|
||||
<Tip>
|
||||
|
||||
`+` corresponds to the value `1.1`, `++` corresponds to `1.1^2`, and so on. Similarly, `-` corresponds to `0.9` and `--` corresponds to `0.9^2`. Feel free to experiment with adding more `+` or `-` in your prompt!
|
||||
|
||||
</Tip>
|
||||
|
||||
```py
|
||||
prompt = "a red cat playing with a ball++"
|
||||
```
|
||||
|
||||
Pass the prompt to `compel_proc` to create the new prompt embeddings which are passed to the pipeline:
|
||||
|
||||
```py
|
||||
prompt_embeds = compel_proc(prompt)
|
||||
generator = torch.manual_seed(33)
|
||||
|
||||
image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_1.png"/>
|
||||
</div>
|
||||
|
||||
To downweight parts of the prompt, use the `-` suffix:
|
||||
|
||||
```py
|
||||
prompt = "a red------- cat playing with a ball"
|
||||
prompt_embeds = compel_proc(prompt)
|
||||
|
||||
generator = torch.manual_seed(33)
|
||||
|
||||
image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"/>
|
||||
</div>
|
||||
|
||||
You can even up or downweight multiple concepts in the same prompt:
|
||||
|
||||
```py
|
||||
prompt = "a red cat++ playing with a ball----"
|
||||
prompt_embeds = compel_proc(prompt)
|
||||
|
||||
generator = torch.manual_seed(33)
|
||||
|
||||
image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-pos-neg.png"/>
|
||||
</div>
|
||||
|
||||
### Blending
|
||||
|
||||
You can also create a weighted *blend* of prompts by adding `.blend()` to a list of prompts and passing it some weights. Your blend may not always produce the result you expect because it breaks some assumptions about how the text encoder functions, so just have fun and experiment with it!
|
||||
|
||||
```py
|
||||
prompt_embeds = compel_proc('("a red cat playing with a ball", "jungle").blend(0.7, 0.8)')
|
||||
generator = torch.Generator(device="cuda").manual_seed(33)
|
||||
|
||||
image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-blend.png"/>
|
||||
</div>
|
||||
|
||||
### Conjunction
|
||||
|
||||
A conjunction diffuses each prompt independently and concatenates their results by their weighted sum. Add `.and()` to the end of a list of prompts to create a conjunction:
|
||||
|
||||
```py
|
||||
prompt_embeds = compel_proc('["a red cat", "playing with a", "ball"].and()')
|
||||
generator = torch.Generator(device="cuda").manual_seed(55)
|
||||
|
||||
image = pipe(prompt_embeds=prompt_embeds, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-conj.png"/>
|
||||
</div>
|
||||
|
||||
### Textual inversion
|
||||
|
||||
@@ -326,63 +363,35 @@ Create a pipeline and use the [`~loaders.TextualInversionLoaderMixin.load_textua
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from compel import Compel, DiffusersTextualInversionManager
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16,
|
||||
use_safetensors=True, variant="fp16").to("cuda")
|
||||
pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
|
||||
```
|
||||
|
||||
Add the `<midjourney-style>` text to the prompt to trigger the textual inversion.
|
||||
Compel provides a `DiffusersTextualInversionManager` class to simplify prompt weighting with textual inversion. Instantiate `DiffusersTextualInversionManager` and pass it to the `Compel` class:
|
||||
|
||||
```py
|
||||
from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
|
||||
|
||||
prompt = """<midjourney-style> A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
|
||||
This imaginative creature features the distinctive, bulky body of a hippo,
|
||||
but with a texture and appearance resembling a golden-brown, crispy waffle.
|
||||
The creature might have elements like waffle squares across its skin and a syrup-like sheen.
|
||||
It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
|
||||
possibly including oversized utensils or plates in the background.
|
||||
The image should evoke a sense of playful absurdity and culinary fantasy.
|
||||
"""
|
||||
|
||||
neg_prompt = """\
|
||||
skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
|
||||
(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
|
||||
extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
|
||||
(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
|
||||
bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
|
||||
(normal quality:2),lowres,((monochrome)),((grayscale))
|
||||
"""
|
||||
textual_inversion_manager = DiffusersTextualInversionManager(pipe)
|
||||
compel_proc = Compel(
|
||||
tokenizer=pipe.tokenizer,
|
||||
text_encoder=pipe.text_encoder,
|
||||
textual_inversion_manager=textual_inversion_manager)
|
||||
```
|
||||
|
||||
Use the `get_weighted_text_embeddings_sd15` function to generate the prompt embeddings and the negative prompt embeddings.
|
||||
Incorporate the concept to condition a prompt with using the `<concept>` syntax:
|
||||
|
||||
```py
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_neg_embeds,
|
||||
) = get_weighted_text_embeddings_sd15(
|
||||
pipe,
|
||||
prompt=prompt,
|
||||
neg_prompt=neg_prompt
|
||||
)
|
||||
prompt_embeds = compel_proc('("A red cat++ playing with a ball <midjourney-style>")')
|
||||
|
||||
image = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=prompt_neg_embeds,
|
||||
height=768,
|
||||
width=896,
|
||||
guidance_scale=4.0,
|
||||
generator=torch.Generator("cuda").manual_seed(2)
|
||||
).images[0]
|
||||
image = pipe(prompt_embeds=prompt_embeds).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd_embed_textual_inversion.png"/>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-text-inversion.png"/>
|
||||
</div>
|
||||
|
||||
### DreamBooth
|
||||
@@ -392,44 +401,70 @@ image
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, UniPCMultistepScheduler
|
||||
from compel import Compel
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("sd-dreambooth-library/dndcoverart-v1", torch_dtype=torch.float16).to("cuda")
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
```
|
||||
|
||||
Depending on the model you use, you'll need to incorporate the model's unique identifier into your prompt. For example, the `dndcoverart-v1` model uses the identifier `dndcoverart`:
|
||||
Create a `Compel` class with a tokenizer and text encoder, and pass your prompt to it. Depending on the model you use, you'll need to incorporate the model's unique identifier into your prompt. For example, the `dndcoverart-v1` model uses the identifier `dndcoverart`:
|
||||
|
||||
```py
|
||||
from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
|
||||
|
||||
prompt = """dndcoverart of A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
|
||||
This imaginative creature features the distinctive, bulky body of a hippo,
|
||||
but with a texture and appearance resembling a golden-brown, crispy waffle.
|
||||
The creature might have elements like waffle squares across its skin and a syrup-like sheen.
|
||||
It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
|
||||
possibly including oversized utensils or plates in the background.
|
||||
The image should evoke a sense of playful absurdity and culinary fantasy.
|
||||
"""
|
||||
|
||||
neg_prompt = """\
|
||||
skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
|
||||
(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
|
||||
extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
|
||||
(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
|
||||
bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
|
||||
(normal quality:2),lowres,((monochrome)),((grayscale))
|
||||
"""
|
||||
|
||||
(
|
||||
prompt_embeds
|
||||
, prompt_neg_embeds
|
||||
) = get_weighted_text_embeddings_sd15(
|
||||
pipe
|
||||
, prompt = prompt
|
||||
, neg_prompt = neg_prompt
|
||||
)
|
||||
compel_proc = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
|
||||
prompt_embeds = compel_proc('("magazine cover of a dndcoverart dragon, high quality, intricate details, larry elmore art style").and()')
|
||||
image = pipe(prompt_embeds=prompt_embeds).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd_embed_dreambooth.png"/>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-dreambooth.png"/>
|
||||
</div>
|
||||
|
||||
### Stable Diffusion XL
|
||||
|
||||
Stable Diffusion XL (SDXL) has two tokenizers and text encoders so it's usage is a bit different. To address this, you should pass both tokenizers and encoders to the `Compel` class:
|
||||
|
||||
```py
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import make_image_grid
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
compel = Compel(
|
||||
tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2] ,
|
||||
text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
|
||||
requires_pooled=[False, True]
|
||||
)
|
||||
```
|
||||
|
||||
This time, let's upweight "ball" by a factor of 1.5 for the first prompt, and downweight "ball" by 0.6 for the second prompt. The [`StableDiffusionXLPipeline`] also requires [`pooled_prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLInpaintPipeline.__call__.pooled_prompt_embeds) (and optionally [`negative_pooled_prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLInpaintPipeline.__call__.negative_pooled_prompt_embeds)) so you should pass those to the pipeline along with the conditioning tensors:
|
||||
|
||||
```py
|
||||
# apply weights
|
||||
prompt = ["a red cat playing with a (ball)1.5", "a red cat playing with a (ball)0.6"]
|
||||
conditioning, pooled = compel(prompt)
|
||||
|
||||
# generate image
|
||||
generator = [torch.Generator().manual_seed(33) for _ in range(len(prompt))]
|
||||
images = pipeline(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, generator=generator, num_inference_steps=30).images
|
||||
make_image_grid(images, rows=1, cols=2)
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/sdxl_ball1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">"a red cat playing with a (ball)1.5"</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/sdxl_ball2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">"a red cat playing with a (ball)0.6"</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -227,7 +227,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
with autocast_ctx:
|
||||
@@ -880,7 +880,9 @@ class TokenEmbeddingsHandler:
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
|
||||
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
embeds = (
|
||||
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
|
||||
)
|
||||
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
|
||||
new_token_embeddings = embeds.weight.data[train_ids]
|
||||
|
||||
@@ -902,7 +904,9 @@ class TokenEmbeddingsHandler:
|
||||
@torch.no_grad()
|
||||
def retract_embeddings(self):
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
|
||||
embeds = (
|
||||
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
|
||||
)
|
||||
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
|
||||
embeds.weight.data[index_no_updates] = (
|
||||
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
|
||||
@@ -1745,7 +1749,7 @@ def main(args):
|
||||
if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well
|
||||
text_lora_parameters_two = []
|
||||
for name, param in text_encoder_two.named_parameters():
|
||||
if "shared" in name:
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param.data = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
|
||||
@@ -1883,11 +1883,7 @@ def main(args):
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = (
|
||||
torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
if args.seed is not None
|
||||
else None
|
||||
)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
@@ -1991,9 +1987,7 @@ def main(args):
|
||||
)
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = (
|
||||
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -269,7 +269,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
|
||||
@@ -722,7 +722,7 @@ def log_validation(
|
||||
# pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
|
||||
videos = []
|
||||
for _ in range(args.num_validation_videos):
|
||||
|
||||
@@ -739,7 +739,7 @@ def log_validation(
|
||||
# pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
|
||||
videos = []
|
||||
for _ in range(args.num_validation_videos):
|
||||
|
||||
@@ -53,7 +53,6 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Stable Diffusion Mixture Tiling Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SD 1.5](#stable-diffusion-mixture-tiling-pipeline-sd-15) | [](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
|
||||
| Stable Diffusion Mixture Canvas Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending. Works by defining a list of Text2Image region objects that detail the region of influence of each diffuser. | [Stable Diffusion Mixture Canvas Pipeline SD 1.5](#stable-diffusion-mixture-canvas-pipeline-sd-15) | [](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
|
||||
| Stable Diffusion Mixture Tiling Pipeline SDXL | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SDXL](#stable-diffusion-mixture-tiling-pipeline-sdxl) | [](https://huggingface.co/spaces/elismasilva/mixture-of-diffusers-sdxl-tiling) | [Eliseu Silva](https://github.com/DEVAIEXP/) |
|
||||
| Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL | This is an advanced pipeline that leverages ControlNet Tile and Mixture-of-Diffusers techniques, integrating tile diffusion directly into the latent space denoising process. Designed to overcome the limitations of conventional pixel-space tile processing, this pipeline delivers Super Resolution (SR) upscaling for higher-quality images, reduced processing time, and greater adaptability. | [Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL](#stable-diffusion-mod-controlnet-tile-sr-pipeline-sdxl) | [](https://huggingface.co/spaces/elismasilva/mod-control-tile-upscaler-sdxl) | [Eliseu Silva](https://github.com/DEVAIEXP/) |
|
||||
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
|
||||
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
|
||||
@@ -83,7 +82,6 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
|
||||
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
|
||||
| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
|
||||
| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|
|
||||
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
@@ -2632,103 +2630,6 @@ image = pipe(
|
||||
|
||||

|
||||
|
||||
### Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL
|
||||
|
||||
This pipeline implements the [MoD (Mixture-of-Diffusers)]("https://arxiv.org/pdf/2408.06072") tiled diffusion technique and combines it with SDXL's ControlNet Tile process to generate SR images.
|
||||
|
||||
This works better with 4x scales, but you can try adjusts parameters to higher scales.
|
||||
|
||||
````python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, ControlNetUnionModel, AutoencoderKL, UniPCMultistepScheduler, UNet2DConditionModel
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
|
||||
device = "cuda"
|
||||
|
||||
# Initialize the models and pipeline
|
||||
controlnet = ControlNetUnionModel.from_pretrained(
|
||||
"brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
|
||||
).to(device=device)
|
||||
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device=device)
|
||||
|
||||
model_id = "SG161222/RealVisXL_V5.0"
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.float16,
|
||||
vae=vae,
|
||||
controlnet=controlnet,
|
||||
custom_pipeline="mod_controlnet_tile_sr_sdxl",
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
).to(device)
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
|
||||
|
||||
#pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
|
||||
pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
|
||||
pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
|
||||
|
||||
# Set selected scheduler
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
# Load image
|
||||
control_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/1.jpg")
|
||||
original_height = control_image.height
|
||||
original_width = control_image.width
|
||||
print(f"Current resolution: H:{original_height} x W:{original_width}")
|
||||
|
||||
# Pre-upscale image for tiling
|
||||
resolution = 4096
|
||||
tile_gaussian_sigma = 0.3
|
||||
max_tile_size = 1024 # or 1280
|
||||
|
||||
current_size = max(control_image.size)
|
||||
scale_factor = max(2, resolution / current_size)
|
||||
new_size = (int(control_image.width * scale_factor), int(control_image.height * scale_factor))
|
||||
image = control_image.resize(new_size, Image.LANCZOS)
|
||||
|
||||
# Update target height and width
|
||||
target_height = image.height
|
||||
target_width = image.width
|
||||
print(f"Target resolution: H:{target_height} x W:{target_width}")
|
||||
|
||||
# Calculate overlap size
|
||||
normal_tile_overlap, border_tile_overlap = pipe.calculate_overlap(target_width, target_height)
|
||||
|
||||
# Set other params
|
||||
tile_weighting_method = pipe.TileWeightingMethod.COSINE.value
|
||||
guidance_scale = 4
|
||||
num_inference_steps = 35
|
||||
denoising_strenght = 0.65
|
||||
controlnet_strength = 1.0
|
||||
prompt = "high-quality, noise-free edges, high quality, 4k, hd, 8k"
|
||||
negative_prompt = "blurry, pixelated, noisy, low resolution, artifacts, poor details"
|
||||
|
||||
# Image generation
|
||||
generated_image = pipe(
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
control_mode=[6],
|
||||
controlnet_conditioning_scale=float(controlnet_strength),
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
normal_tile_overlap=normal_tile_overlap,
|
||||
border_tile_overlap=border_tile_overlap,
|
||||
height=target_height,
|
||||
width=target_width,
|
||||
original_size=(original_width, original_height),
|
||||
target_size=(target_width, target_height),
|
||||
guidance_scale=guidance_scale,
|
||||
strength=float(denoising_strenght),
|
||||
tile_weighting_method=tile_weighting_method,
|
||||
max_tile_size=max_tile_size,
|
||||
tile_gaussian_sigma=float(tile_gaussian_sigma),
|
||||
num_inference_steps=num_inference_steps,
|
||||
)["images"][0]
|
||||
````
|
||||

|
||||
|
||||
### TensorRT Inpainting Stable Diffusion Pipeline
|
||||
|
||||
The TensorRT Pipeline can be used to accelerate the Inpainting Stable Diffusion Inference run.
|
||||
@@ -5223,39 +5124,3 @@ with torch.no_grad():
|
||||
|
||||
In the folder examples/pixart there is also a script that can be used to train new models.
|
||||
Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
|
||||
|
||||
# CogVideoX DDIM Inversion Pipeline
|
||||
|
||||
This implementation performs DDIM inversion on the video based on CogVideoX and uses guided attention to reconstruct or edit the inversion latents.
|
||||
|
||||
## Example Usage
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from examples.community.cogvideox_ddim_inversion import CogVideoXPipelineForDDIMInversion
|
||||
|
||||
|
||||
# Load pretrained pipeline
|
||||
pipeline = CogVideoXPipelineForDDIMInversion.from_pretrained(
|
||||
"THUDM/CogVideoX1.5-5B",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
# Run DDIM inversion, and the videos will be generated in the output_path
|
||||
output = pipeline_for_inversion(
|
||||
prompt="prompt that describes the edited video",
|
||||
video_path="path/to/input.mp4",
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=50,
|
||||
skip_frames_start=0,
|
||||
skip_frames_end=0,
|
||||
frame_sample_step=None,
|
||||
max_num_frames=81,
|
||||
width=720,
|
||||
height=480,
|
||||
seed=42,
|
||||
)
|
||||
pipeline.export_latents_to_video(output.inverse_latents[-1], "path/to/inverse_video.mp4", fps=8)
|
||||
pipeline.export_latents_to_video(output.recon_latents[-1], "path/to/recon_video.mp4", fps=8)
|
||||
```
|
||||
|
||||
@@ -92,13 +92,9 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
token = kwargs.pop("token", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
|
||||
|
||||
alpha = kwargs.pop("alpha", 0.5)
|
||||
interp = kwargs.pop("interp", None)
|
||||
|
||||
|
||||
@@ -1,645 +0,0 @@
|
||||
"""
|
||||
This script performs DDIM inversion for video frames using a pre-trained model and generates
|
||||
a video reconstruction based on a provided prompt. It utilizes the CogVideoX pipeline to
|
||||
process video frames, apply the DDIM inverse scheduler, and produce an output video.
|
||||
|
||||
**Please notice that this script is based on the CogVideoX 5B model, and would not generate
|
||||
a good result for 2B variants.**
|
||||
|
||||
Usage:
|
||||
python cogvideox_ddim_inversion.py
|
||||
--model-path /path/to/model
|
||||
--prompt "a prompt"
|
||||
--video-path /path/to/video.mp4
|
||||
--output-path /path/to/output
|
||||
|
||||
For more details about the cli arguments, please run `python cogvideox_ddim_inversion.py --help`.
|
||||
|
||||
Author:
|
||||
LittleNyima <littlenyima[at]163[dot]com>
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as T
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0
|
||||
from diffusers.models.autoencoders import AutoencoderKLCogVideoX
|
||||
from diffusers.models.embeddings import apply_rotary_emb
|
||||
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
|
||||
from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps
|
||||
from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
|
||||
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.
|
||||
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
|
||||
import decord # isort: skip
|
||||
|
||||
|
||||
class DDIMInversionArguments(TypedDict):
|
||||
model_path: str
|
||||
prompt: str
|
||||
video_path: str
|
||||
output_path: str
|
||||
guidance_scale: float
|
||||
num_inference_steps: int
|
||||
skip_frames_start: int
|
||||
skip_frames_end: int
|
||||
frame_sample_step: Optional[int]
|
||||
max_num_frames: int
|
||||
width: int
|
||||
height: int
|
||||
fps: int
|
||||
dtype: torch.dtype
|
||||
seed: int
|
||||
device: torch.device
|
||||
|
||||
|
||||
def get_args() -> DDIMInversionArguments:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--model_path", type=str, required=True, help="Path of the pretrained model")
|
||||
parser.add_argument("--prompt", type=str, required=True, help="Prompt for the direct sample procedure")
|
||||
parser.add_argument("--video_path", type=str, required=True, help="Path of the video for inversion")
|
||||
parser.add_argument("--output_path", type=str, default="output", help="Path of the output videos")
|
||||
parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
|
||||
parser.add_argument("--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start")
|
||||
parser.add_argument("--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end")
|
||||
parser.add_argument("--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames")
|
||||
parser.add_argument("--max_num_frames", type=int, default=81, help="Max number of sampled frames")
|
||||
parser.add_argument("--width", type=int, default=720, help="Resized width of the video frames")
|
||||
parser.add_argument("--height", type=int, default=480, help="Resized height of the video frames")
|
||||
parser.add_argument("--fps", type=int, default=8, help="Frame rate of the output videos")
|
||||
parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator")
|
||||
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference")
|
||||
|
||||
args = parser.parse_args()
|
||||
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
|
||||
args.device = torch.device(args.device)
|
||||
|
||||
return DDIMInversionArguments(**vars(args))
|
||||
|
||||
|
||||
class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def calculate_attention(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn: Attention,
|
||||
batch_size: int,
|
||||
image_seq_length: int,
|
||||
text_seq_length: int,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
image_rotary_emb: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Core attention computation with inversion-guided RoPE integration.
|
||||
|
||||
Args:
|
||||
query (`torch.Tensor`): `[batch_size, seq_len, dim]` query tensor
|
||||
key (`torch.Tensor`): `[batch_size, seq_len, dim]` key tensor
|
||||
value (`torch.Tensor`): `[batch_size, seq_len, dim]` value tensor
|
||||
attn (`Attention`): Parent attention module with projection layers
|
||||
batch_size (`int`): Effective batch size (after chunk splitting)
|
||||
image_seq_length (`int`): Length of image feature sequence
|
||||
text_seq_length (`int`): Length of text feature sequence
|
||||
attention_mask (`Optional[torch.Tensor]`): Attention mask tensor
|
||||
image_rotary_emb (`Optional[torch.Tensor]`): Rotary embeddings for image positions
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
(1) hidden_states: [batch_size, image_seq_length, dim] processed image features
|
||||
(2) encoder_hidden_states: [batch_size, text_seq_length, dim] processed text features
|
||||
"""
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
if key.size(2) == query.size(2): # Attention for reference hidden states
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
else: # RoPE should be applied to each group of image tokens
|
||||
key[:, :, text_seq_length : text_seq_length + image_seq_length] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length : text_seq_length + image_seq_length], image_rotary_emb
|
||||
)
|
||||
key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb
|
||||
)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Process the dual-path attention for the inversion-guided denoising procedure.
|
||||
|
||||
Args:
|
||||
attn (`Attention`): Parent attention module
|
||||
hidden_states (`torch.Tensor`): `[batch_size, image_seq_len, dim]` Image tokens
|
||||
encoder_hidden_states (`torch.Tensor`): `[batch_size, text_seq_len, dim]` Text tokens
|
||||
attention_mask (`Optional[torch.Tensor]`): Optional attention mask
|
||||
image_rotary_emb (`Optional[torch.Tensor]`): Rotary embeddings for image tokens
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor]`:
|
||||
(1) Final hidden states: `[batch_size, image_seq_length, dim]` Resulting image tokens
|
||||
(2) Final encoder states: `[batch_size, text_seq_length, dim]` Resulting text tokens
|
||||
"""
|
||||
image_seq_length = hidden_states.size(1)
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query, query_reference = query.chunk(2)
|
||||
key, key_reference = key.chunk(2)
|
||||
value, value_reference = value.chunk(2)
|
||||
batch_size = batch_size // 2
|
||||
|
||||
hidden_states, encoder_hidden_states = self.calculate_attention(
|
||||
query=query,
|
||||
key=torch.cat((key, key_reference), dim=1),
|
||||
value=torch.cat((value, value_reference), dim=1),
|
||||
attn=attn,
|
||||
batch_size=batch_size,
|
||||
image_seq_length=image_seq_length,
|
||||
text_seq_length=text_seq_length,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states_reference, encoder_hidden_states_reference = self.calculate_attention(
|
||||
query=query_reference,
|
||||
key=key_reference,
|
||||
value=value_reference,
|
||||
attn=attn,
|
||||
batch_size=batch_size,
|
||||
image_seq_length=image_seq_length,
|
||||
text_seq_length=text_seq_length,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
return (
|
||||
torch.cat((hidden_states, hidden_states_reference)),
|
||||
torch.cat((encoder_hidden_states, encoder_hidden_states_reference)),
|
||||
)
|
||||
|
||||
|
||||
class OverrideAttnProcessors:
|
||||
r"""
|
||||
Context manager for temporarily overriding attention processors in CogVideo transformer blocks.
|
||||
|
||||
Designed for DDIM inversion process, replaces original attention processors with
|
||||
`CogVideoXAttnProcessor2_0ForDDIMInversion` and restores them upon exit. Uses Python context manager
|
||||
pattern to safely manage processor replacement.
|
||||
|
||||
Typical usage:
|
||||
```python
|
||||
with OverrideAttnProcessors(transformer):
|
||||
# Perform DDIM inversion operations
|
||||
```
|
||||
|
||||
Args:
|
||||
transformer (`CogVideoXTransformer3DModel`):
|
||||
The transformer model containing attention blocks to be modified. Should have
|
||||
`transformer_blocks` attribute containing `CogVideoXBlock` instances.
|
||||
"""
|
||||
|
||||
def __init__(self, transformer: CogVideoXTransformer3DModel):
|
||||
self.transformer = transformer
|
||||
self.original_processors = {}
|
||||
|
||||
def __enter__(self):
|
||||
for block in self.transformer.transformer_blocks:
|
||||
block = cast(CogVideoXBlock, block)
|
||||
self.original_processors[id(block)] = block.attn1.get_processor()
|
||||
block.attn1.set_processor(CogVideoXAttnProcessor2_0ForDDIMInversion())
|
||||
|
||||
def __exit__(self, _0, _1, _2):
|
||||
for block in self.transformer.transformer_blocks:
|
||||
block = cast(CogVideoXBlock, block)
|
||||
block.attn1.set_processor(self.original_processors[id(block)])
|
||||
|
||||
|
||||
def get_video_frames(
|
||||
video_path: str,
|
||||
width: int,
|
||||
height: int,
|
||||
skip_frames_start: int,
|
||||
skip_frames_end: int,
|
||||
max_num_frames: int,
|
||||
frame_sample_step: Optional[int],
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Extract and preprocess video frames from a video file for VAE processing.
|
||||
|
||||
Args:
|
||||
video_path (`str`): Path to input video file
|
||||
width (`int`): Target frame width for decoding
|
||||
height (`int`): Target frame height for decoding
|
||||
skip_frames_start (`int`): Number of frames to skip at video start
|
||||
skip_frames_end (`int`): Number of frames to skip at video end
|
||||
max_num_frames (`int`): Maximum allowed number of output frames
|
||||
frame_sample_step (`Optional[int]`):
|
||||
Frame sampling step size. If None, automatically calculated as:
|
||||
(total_frames - skipped_frames) // max_num_frames
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Preprocessed frames in `[F, C, H, W]` format where:
|
||||
- `F`: Number of frames (adjusted to 4k + 1 for VAE compatibility)
|
||||
- `C`: Channels (3 for RGB)
|
||||
- `H`: Frame height
|
||||
- `W`: Frame width
|
||||
"""
|
||||
with decord.bridge.use_torch():
|
||||
video_reader = decord.VideoReader(uri=video_path, width=width, height=height)
|
||||
video_num_frames = len(video_reader)
|
||||
start_frame = min(skip_frames_start, video_num_frames)
|
||||
end_frame = max(0, video_num_frames - skip_frames_end)
|
||||
|
||||
if end_frame <= start_frame:
|
||||
indices = [start_frame]
|
||||
elif end_frame - start_frame <= max_num_frames:
|
||||
indices = list(range(start_frame, end_frame))
|
||||
else:
|
||||
step = frame_sample_step or (end_frame - start_frame) // max_num_frames
|
||||
indices = list(range(start_frame, end_frame, step))
|
||||
|
||||
frames = video_reader.get_batch(indices=indices)
|
||||
frames = frames[:max_num_frames].float() # ensure that we don't go over the limit
|
||||
|
||||
# Choose first (4k + 1) frames as this is how many is required by the VAE
|
||||
selected_num_frames = frames.size(0)
|
||||
remainder = (3 + selected_num_frames) % 4
|
||||
if remainder != 0:
|
||||
frames = frames[:-remainder]
|
||||
assert frames.size(0) % 4 == 1
|
||||
|
||||
# Normalize the frames
|
||||
transform = T.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
|
||||
frames = torch.stack(tuple(map(transform, frames)), dim=0)
|
||||
|
||||
return frames.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
|
||||
|
||||
|
||||
class CogVideoXDDIMInversionOutput:
|
||||
inverse_latents: torch.FloatTensor
|
||||
recon_latents: torch.FloatTensor
|
||||
|
||||
def __init__(self, inverse_latents: torch.FloatTensor, recon_latents: torch.FloatTensor):
|
||||
self.inverse_latents = inverse_latents
|
||||
self.recon_latents = recon_latents
|
||||
|
||||
|
||||
class CogVideoXPipelineForDDIMInversion(CogVideoXPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: T5Tokenizer,
|
||||
text_encoder: T5EncoderModel,
|
||||
vae: AutoencoderKLCogVideoX,
|
||||
transformer: CogVideoXTransformer3DModel,
|
||||
scheduler: CogVideoXDDIMScheduler,
|
||||
):
|
||||
super().__init__(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.inverse_scheduler = DDIMInverseScheduler(**scheduler.config)
|
||||
|
||||
def encode_video_frames(self, video_frames: torch.FloatTensor) -> torch.FloatTensor:
|
||||
"""
|
||||
Encode video frames into latent space using Variational Autoencoder.
|
||||
|
||||
Args:
|
||||
video_frames (`torch.FloatTensor`):
|
||||
Input frames tensor in `[F, C, H, W]` format from `get_video_frames()`
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: Encoded latents in `[1, F, D, H_latent, W_latent]` format where:
|
||||
- `F`: Number of frames (same as input)
|
||||
- `D`: Latent channel dimension
|
||||
- `H_latent`: Latent space height (H // 2^vae.downscale_factor)
|
||||
- `W_latent`: Latent space width (W // 2^vae.downscale_factor)
|
||||
"""
|
||||
vae: AutoencoderKLCogVideoX = self.vae
|
||||
video_frames = video_frames.to(device=vae.device, dtype=vae.dtype)
|
||||
video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
||||
latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2)
|
||||
return latent_dist * vae.config.scaling_factor
|
||||
|
||||
@torch.no_grad()
|
||||
def export_latents_to_video(self, latents: torch.FloatTensor, video_path: str, fps: int):
|
||||
r"""
|
||||
Decode latent vectors into video and export as video file.
|
||||
|
||||
Args:
|
||||
latents (`torch.FloatTensor`): Encoded latents in `[B, F, D, H_latent, W_latent]` format from
|
||||
`encode_video_frames()`
|
||||
video_path (`str`): Output path for video file
|
||||
fps (`int`): Target frames per second for output video
|
||||
"""
|
||||
video = self.decode_latents(latents)
|
||||
frames = self.video_processor.postprocess_video(video=video, output_type="pil")
|
||||
os.makedirs(os.path.dirname(video_path), exist_ok=True)
|
||||
export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps)
|
||||
|
||||
# Modified from CogVideoXPipeline.__call__
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
latents: torch.FloatTensor,
|
||||
scheduler: Union[DDIMInverseScheduler, CogVideoXDDIMScheduler],
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 6,
|
||||
use_dynamic_cfg: bool = False,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
reference_latents: torch.FloatTensor = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""
|
||||
Execute the core sampling loop for video generation/inversion using CogVideoX.
|
||||
|
||||
Implements the full denoising trajectory recording for both DDIM inversion and
|
||||
generation processes. Supports dynamic classifier-free guidance and reference
|
||||
latent conditioning.
|
||||
|
||||
Args:
|
||||
latents (`torch.FloatTensor`):
|
||||
Initial noise tensor of shape `[B, F, C, H, W]`.
|
||||
scheduler (`Union[DDIMInverseScheduler, CogVideoXDDIMScheduler]`):
|
||||
Scheduling strategy for diffusion process. Use:
|
||||
(1) `DDIMInverseScheduler` for inversion
|
||||
(2) `CogVideoXDDIMScheduler` for generation
|
||||
prompt (`Optional[Union[str, List[str]]]`):
|
||||
Text prompt(s) for conditional generation. Defaults to unconditional.
|
||||
negative_prompt (`Optional[Union[str, List[str]]]`):
|
||||
Negative prompt(s) for guidance. Requires `guidance_scale > 1`.
|
||||
num_inference_steps (`int`):
|
||||
Number of denoising steps. Affects quality/compute trade-off.
|
||||
guidance_scale (`float`):
|
||||
Classifier-free guidance weight. 1.0 = no guidance.
|
||||
use_dynamic_cfg (`bool`):
|
||||
Enable time-varying guidance scale (cosine schedule)
|
||||
eta (`float`):
|
||||
DDIM variance parameter (0 = deterministic process)
|
||||
generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`):
|
||||
Random number generator(s) for reproducibility
|
||||
attention_kwargs (`Optional[Dict[str, Any]]`):
|
||||
Custom parameters for attention modules
|
||||
reference_latents (`torch.FloatTensor`):
|
||||
Reference latent trajectory for conditional sampling. Shape should match
|
||||
`[T, B, F, C, H, W]` where `T` is number of timesteps
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
Full denoising trajectory tensor of shape `[T, B, F, C, H, W]`.
|
||||
"""
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
do_classifier_free_guidance,
|
||||
device=device,
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
if reference_latents is not None:
|
||||
prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latents = latents.to(device=device) * scheduler.init_noise_sigma
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
if isinstance(scheduler, DDIMInverseScheduler): # Inverse scheduler does not accept extra kwargs
|
||||
extra_step_kwargs = {}
|
||||
|
||||
# 7. Create rotary embeds if required
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(
|
||||
height=latents.size(3) * self.vae_scale_factor_spatial,
|
||||
width=latents.size(4) * self.vae_scale_factor_spatial,
|
||||
num_frames=latents.size(1),
|
||||
device=device,
|
||||
)
|
||||
if self.transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
|
||||
|
||||
trajectory = torch.zeros_like(latents).unsqueeze(0).repeat(len(timesteps), 1, 1, 1, 1, 1)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
if reference_latents is not None:
|
||||
reference = reference_latents[i]
|
||||
reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
|
||||
latent_model_input = torch.cat([latent_model_input, reference], dim=0)
|
||||
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
# predict noise model_output
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
|
||||
if reference_latents is not None: # Recover the original batch size
|
||||
noise_pred, _ = noise_pred.chunk(2)
|
||||
|
||||
# perform guidance
|
||||
if use_dynamic_cfg:
|
||||
self._guidance_scale = 1 + guidance_scale * (
|
||||
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
||||
)
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the noisy sample x_t-1 -> x_t
|
||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
latents = latents.to(prompt_embeds.dtype)
|
||||
trajectory[i] = latents
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
return trajectory
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
video_path: str,
|
||||
guidance_scale: float,
|
||||
num_inference_steps: int,
|
||||
skip_frames_start: int,
|
||||
skip_frames_end: int,
|
||||
frame_sample_step: Optional[int],
|
||||
max_num_frames: int,
|
||||
width: int,
|
||||
height: int,
|
||||
seed: int,
|
||||
):
|
||||
"""
|
||||
Performs DDIM inversion on a video to reconstruct it with a new prompt.
|
||||
|
||||
Args:
|
||||
prompt (`str`): The text prompt to guide the reconstruction.
|
||||
video_path (`str`): Path to the input video file.
|
||||
guidance_scale (`float`): Scale for classifier-free guidance.
|
||||
num_inference_steps (`int`): Number of denoising steps.
|
||||
skip_frames_start (`int`): Number of frames to skip from the beginning of the video.
|
||||
skip_frames_end (`int`): Number of frames to skip from the end of the video.
|
||||
frame_sample_step (`Optional[int]`): Step size for sampling frames. If None, all frames are used.
|
||||
max_num_frames (`int`): Maximum number of frames to process.
|
||||
width (`int`): Width of the output video frames.
|
||||
height (`int`): Height of the output video frames.
|
||||
seed (`int`): Random seed for reproducibility.
|
||||
|
||||
Returns:
|
||||
`CogVideoXDDIMInversionOutput`: Contains the inverse latents and reconstructed latents.
|
||||
"""
|
||||
if not self.transformer.config.use_rotary_positional_embeddings:
|
||||
raise NotImplementedError("This script supports CogVideoX 5B model only.")
|
||||
video_frames = get_video_frames(
|
||||
video_path=video_path,
|
||||
width=width,
|
||||
height=height,
|
||||
skip_frames_start=skip_frames_start,
|
||||
skip_frames_end=skip_frames_end,
|
||||
max_num_frames=max_num_frames,
|
||||
frame_sample_step=frame_sample_step,
|
||||
).to(device=self.device)
|
||||
video_latents = self.encode_video_frames(video_frames=video_frames)
|
||||
inverse_latents = self.sample(
|
||||
latents=video_latents,
|
||||
scheduler=self.inverse_scheduler,
|
||||
prompt="",
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
generator=torch.Generator(device=self.device).manual_seed(seed),
|
||||
)
|
||||
with OverrideAttnProcessors(transformer=self.transformer):
|
||||
recon_latents = self.sample(
|
||||
latents=torch.randn_like(video_latents),
|
||||
scheduler=self.scheduler,
|
||||
prompt=prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
generator=torch.Generator(device=self.device).manual_seed(seed),
|
||||
reference_latents=reversed(inverse_latents),
|
||||
)
|
||||
return CogVideoXDDIMInversionOutput(
|
||||
inverse_latents=inverse_latents,
|
||||
recon_latents=recon_latents,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
arguments = get_args()
|
||||
pipeline = CogVideoXPipelineForDDIMInversion.from_pretrained(
|
||||
arguments.pop("model_path"),
|
||||
torch_dtype=arguments.pop("dtype"),
|
||||
).to(device=arguments.pop("device"))
|
||||
|
||||
output_path = arguments.pop("output_path")
|
||||
fps = arguments.pop("fps")
|
||||
inverse_video_path = os.path.join(output_path, f"{arguments.get('video_path')}_inversion.mp4")
|
||||
recon_video_path = os.path.join(output_path, f"{arguments.get('video_path')}_reconstruction.mp4")
|
||||
|
||||
# Run DDIM inversion
|
||||
output = pipeline(**arguments)
|
||||
pipeline.export_latents_to_video(output.inverse_latents[-1], inverse_video_path, fps)
|
||||
pipeline.export_latents_to_video(output.recon_latents[-1], recon_video_path, fps)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1334,9 +1334,7 @@ def main(args):
|
||||
|
||||
# run inference
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
generator = (
|
||||
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -172,7 +172,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
|
||||
if args.validation_images is None:
|
||||
images = []
|
||||
@@ -1119,22 +1119,17 @@ def main(args):
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1151,15 +1146,8 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
|
||||
with autocast_ctx:
|
||||
|
||||
@@ -170,7 +170,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
|
||||
|
||||
@@ -199,7 +199,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
|
||||
@@ -203,17 +203,17 @@ def log_validation(
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
|
||||
@@ -175,7 +175,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
|
||||
@@ -6,4 +6,4 @@ torch==2.2.0
|
||||
torchvision>=0.16
|
||||
ftfy==6.1.1
|
||||
tensorboard==2.14.0
|
||||
Jinja2==3.1.6
|
||||
Jinja2==3.1.5
|
||||
|
||||
@@ -137,7 +137,7 @@ def log_validation(
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
@@ -1241,11 +1241,7 @@ def main(args):
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = (
|
||||
torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
if args.seed is not None
|
||||
else None
|
||||
)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
with autocast_ctx:
|
||||
@@ -1309,9 +1305,7 @@ def main(args):
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = (
|
||||
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
|
||||
with autocast_ctx:
|
||||
images = [
|
||||
|
||||
@@ -3,19 +3,11 @@ from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoTokenizer,
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
LlavaForConditionalGeneration,
|
||||
)
|
||||
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLHunyuanVideo,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
HunyuanVideoImageToVideoPipeline,
|
||||
HunyuanVideoPipeline,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
)
|
||||
@@ -142,46 +134,6 @@ VAE_KEYS_RENAME_DICT = {}
|
||||
VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
TRANSFORMER_CONFIGS = {
|
||||
"HYVideo-T/2-cfgdistill": {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 24,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 20,
|
||||
"num_single_layers": 40,
|
||||
"num_refiner_layers": 2,
|
||||
"mlp_ratio": 4.0,
|
||||
"patch_size": 2,
|
||||
"patch_size_t": 1,
|
||||
"qk_norm": "rms_norm",
|
||||
"guidance_embeds": True,
|
||||
"text_embed_dim": 4096,
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
},
|
||||
"HYVideo-T/2-I2V": {
|
||||
"in_channels": 16 * 2 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 24,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 20,
|
||||
"num_single_layers": 40,
|
||||
"num_refiner_layers": 2,
|
||||
"mlp_ratio": 4.0,
|
||||
"patch_size": 2,
|
||||
"patch_size_t": 1,
|
||||
"qk_norm": "rms_norm",
|
||||
"guidance_embeds": False,
|
||||
"text_embed_dim": 4096,
|
||||
"pooled_projection_dim": 768,
|
||||
"rope_theta": 256.0,
|
||||
"rope_axes_dim": (16, 56, 56),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
@@ -197,12 +149,11 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_transformer(ckpt_path: str, transformer_type: str):
|
||||
def convert_transformer(ckpt_path: str):
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
|
||||
config = TRANSFORMER_CONFIGS[transformer_type]
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = HunyuanVideoTransformer3DModel(**config)
|
||||
transformer = HunyuanVideoTransformer3DModel()
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -254,10 +205,6 @@ def get_args():
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
||||
parser.add_argument(
|
||||
"--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys())
|
||||
)
|
||||
parser.add_argument("--flow_shift", type=float, default=7.0)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -281,7 +228,7 @@ if __name__ == "__main__":
|
||||
assert args.text_encoder_2_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type)
|
||||
transformer = convert_transformer(args.transformer_ckpt_path)
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
@@ -292,41 +239,19 @@ if __name__ == "__main__":
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.save_pipeline:
|
||||
if args.transformer_type == "HYVideo-T/2-cfgdistill":
|
||||
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
|
||||
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
|
||||
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
|
||||
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
|
||||
|
||||
pipe = HunyuanVideoPipeline(
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
else:
|
||||
text_encoder = LlavaForConditionalGeneration.from_pretrained(
|
||||
args.text_encoder_path, torch_dtype=torch.float16
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
|
||||
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
|
||||
image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_path)
|
||||
|
||||
pipe = HunyuanVideoImageToVideoPipeline(
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
scheduler=scheduler,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
pipe = HunyuanVideoPipeline(
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
@@ -1,423 +0,0 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
UniPCMultistepScheduler,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanTransformer3DModel,
|
||||
)
|
||||
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
||||
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
||||
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
||||
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
||||
"time_projection.1": "condition_embedder.time_proj",
|
||||
"head.modulation": "scale_shift_table",
|
||||
"head.head": "proj_out",
|
||||
"modulation": "scale_shift_table",
|
||||
"ffn.0": "ffn.net.0.proj",
|
||||
"ffn.2": "ffn.net.2",
|
||||
# Hack to swap the layer names
|
||||
# The original model calls the norms in following order: norm1, norm3, norm2
|
||||
# We convert it to: norm1, norm2, norm3
|
||||
"norm2": "norm__placeholder",
|
||||
"norm3": "norm2",
|
||||
"norm__placeholder": "norm3",
|
||||
# For the I2V model
|
||||
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
||||
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
||||
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
||||
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def load_sharded_safetensors(dir: pathlib.Path):
|
||||
file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
|
||||
state_dict = {}
|
||||
for path in file_paths:
|
||||
state_dict.update(load_file(path))
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
if model_type == "Wan-T2V-1.3B":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 12,
|
||||
"num_layers": 30,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
elif model_type == "Wan-T2V-14B":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
elif model_type == "Wan-I2V-14B-480p":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
|
||||
"diffusers_config": {
|
||||
"image_dim": 1280,
|
||||
"added_kv_proj_dim": 5120,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 36,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
elif model_type == "Wan-I2V-14B-720p":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
|
||||
"diffusers_config": {
|
||||
"image_dim": 1280,
|
||||
"added_kv_proj_dim": 5120,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 36,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def convert_transformer(model_type: str):
|
||||
config = get_transformer_config(model_type)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
model_id = config["model_id"]
|
||||
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
|
||||
|
||||
original_state_dict = load_sharded_safetensors(model_dir)
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = WanTransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae():
|
||||
vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
|
||||
old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
|
||||
new_state_dict = {}
|
||||
|
||||
# Create mappings for specific components
|
||||
middle_key_mapping = {
|
||||
# Encoder middle block
|
||||
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
|
||||
"encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
|
||||
"encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
|
||||
"encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
|
||||
"encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
|
||||
"encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
|
||||
"encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
|
||||
"encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
|
||||
"encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
|
||||
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
|
||||
"encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
|
||||
"encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
|
||||
# Decoder middle block
|
||||
"decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
|
||||
"decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
|
||||
"decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
|
||||
"decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
|
||||
"decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
|
||||
"decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
|
||||
"decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
|
||||
"decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
|
||||
"decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
|
||||
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
|
||||
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
|
||||
"decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
|
||||
}
|
||||
|
||||
# Create a mapping for attention blocks
|
||||
attention_mapping = {
|
||||
# Encoder middle attention
|
||||
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
|
||||
"encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
|
||||
"encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
|
||||
"encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
|
||||
"encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
|
||||
# Decoder middle attention
|
||||
"decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
|
||||
"decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
|
||||
"decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
|
||||
"decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
|
||||
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
|
||||
}
|
||||
|
||||
# Create a mapping for the head components
|
||||
head_mapping = {
|
||||
# Encoder head
|
||||
"encoder.head.0.gamma": "encoder.norm_out.gamma",
|
||||
"encoder.head.2.bias": "encoder.conv_out.bias",
|
||||
"encoder.head.2.weight": "encoder.conv_out.weight",
|
||||
# Decoder head
|
||||
"decoder.head.0.gamma": "decoder.norm_out.gamma",
|
||||
"decoder.head.2.bias": "decoder.conv_out.bias",
|
||||
"decoder.head.2.weight": "decoder.conv_out.weight",
|
||||
}
|
||||
|
||||
# Create a mapping for the quant components
|
||||
quant_mapping = {
|
||||
"conv1.weight": "quant_conv.weight",
|
||||
"conv1.bias": "quant_conv.bias",
|
||||
"conv2.weight": "post_quant_conv.weight",
|
||||
"conv2.bias": "post_quant_conv.bias",
|
||||
}
|
||||
|
||||
# Process each key in the state dict
|
||||
for key, value in old_state_dict.items():
|
||||
# Handle middle block keys using the mapping
|
||||
if key in middle_key_mapping:
|
||||
new_key = middle_key_mapping[key]
|
||||
new_state_dict[new_key] = value
|
||||
# Handle attention blocks using the mapping
|
||||
elif key in attention_mapping:
|
||||
new_key = attention_mapping[key]
|
||||
new_state_dict[new_key] = value
|
||||
# Handle head keys using the mapping
|
||||
elif key in head_mapping:
|
||||
new_key = head_mapping[key]
|
||||
new_state_dict[new_key] = value
|
||||
# Handle quant keys using the mapping
|
||||
elif key in quant_mapping:
|
||||
new_key = quant_mapping[key]
|
||||
new_state_dict[new_key] = value
|
||||
# Handle encoder conv1
|
||||
elif key == "encoder.conv1.weight":
|
||||
new_state_dict["encoder.conv_in.weight"] = value
|
||||
elif key == "encoder.conv1.bias":
|
||||
new_state_dict["encoder.conv_in.bias"] = value
|
||||
# Handle decoder conv1
|
||||
elif key == "decoder.conv1.weight":
|
||||
new_state_dict["decoder.conv_in.weight"] = value
|
||||
elif key == "decoder.conv1.bias":
|
||||
new_state_dict["decoder.conv_in.bias"] = value
|
||||
# Handle encoder downsamples
|
||||
elif key.startswith("encoder.downsamples."):
|
||||
# Convert to down_blocks
|
||||
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
|
||||
|
||||
# Convert residual block naming but keep the original structure
|
||||
if ".residual.0.gamma" in new_key:
|
||||
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
|
||||
elif ".residual.2.bias" in new_key:
|
||||
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
|
||||
elif ".residual.2.weight" in new_key:
|
||||
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
|
||||
elif ".residual.3.gamma" in new_key:
|
||||
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
|
||||
elif ".residual.6.bias" in new_key:
|
||||
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
|
||||
elif ".residual.6.weight" in new_key:
|
||||
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
|
||||
elif ".shortcut.bias" in new_key:
|
||||
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
|
||||
elif ".shortcut.weight" in new_key:
|
||||
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
|
||||
|
||||
new_state_dict[new_key] = value
|
||||
|
||||
# Handle decoder upsamples
|
||||
elif key.startswith("decoder.upsamples."):
|
||||
# Convert to up_blocks
|
||||
parts = key.split(".")
|
||||
block_idx = int(parts[2])
|
||||
|
||||
# Group residual blocks
|
||||
if "residual" in key:
|
||||
if block_idx in [0, 1, 2]:
|
||||
new_block_idx = 0
|
||||
resnet_idx = block_idx
|
||||
elif block_idx in [4, 5, 6]:
|
||||
new_block_idx = 1
|
||||
resnet_idx = block_idx - 4
|
||||
elif block_idx in [8, 9, 10]:
|
||||
new_block_idx = 2
|
||||
resnet_idx = block_idx - 8
|
||||
elif block_idx in [12, 13, 14]:
|
||||
new_block_idx = 3
|
||||
resnet_idx = block_idx - 12
|
||||
else:
|
||||
# Keep as is for other blocks
|
||||
new_state_dict[key] = value
|
||||
continue
|
||||
|
||||
# Convert residual block naming
|
||||
if ".residual.0.gamma" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
|
||||
elif ".residual.2.bias" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
|
||||
elif ".residual.2.weight" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
|
||||
elif ".residual.3.gamma" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
|
||||
elif ".residual.6.bias" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
|
||||
elif ".residual.6.weight" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
|
||||
else:
|
||||
new_key = key
|
||||
|
||||
new_state_dict[new_key] = value
|
||||
|
||||
# Handle shortcut connections
|
||||
elif ".shortcut." in key:
|
||||
if block_idx == 4:
|
||||
new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
|
||||
new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
|
||||
|
||||
new_state_dict[new_key] = value
|
||||
|
||||
# Handle upsamplers
|
||||
elif ".resample." in key or ".time_conv." in key:
|
||||
if block_idx == 3:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
|
||||
elif block_idx == 7:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
|
||||
elif block_idx == 11:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
|
||||
new_state_dict[new_key] = value
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
new_state_dict[new_key] = value
|
||||
else:
|
||||
# Keep other keys unchanged
|
||||
new_state_dict[key] = value
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLWan()
|
||||
vae.load_state_dict(new_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument("--dtype", default="fp32")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
transformer = convert_transformer(args.model_type).to(dtype=dtype)
|
||||
vae = convert_vae()
|
||||
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
|
||||
)
|
||||
|
||||
if "I2V" in args.model_type:
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
|
||||
)
|
||||
image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
||||
pipe = WanImageToVideoPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
else:
|
||||
pipe = WanPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
@@ -94,10 +94,8 @@ else:
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
"AutoencoderKLMochi",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderKLWan",
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"CacheMixin",
|
||||
@@ -110,7 +108,6 @@ else:
|
||||
"ControlNetUnionModel",
|
||||
"ControlNetXSAdapter",
|
||||
"DiTTransformer2DModel",
|
||||
"EasyAnimateTransformer3DModel",
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
@@ -151,7 +148,6 @@ else:
|
||||
"UNetSpatioTemporalConditionModel",
|
||||
"UVit2DModel",
|
||||
"VQModel",
|
||||
"WanTransformer3DModel",
|
||||
]
|
||||
)
|
||||
_import_structure["optimization"] = [
|
||||
@@ -295,9 +291,6 @@ else:
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimatePipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
@@ -313,7 +306,6 @@ else:
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
"HunyuanSkyreelsImageToVideoPipeline",
|
||||
"HunyuanVideoImageToVideoPipeline",
|
||||
"HunyuanVideoPipeline",
|
||||
"I2VGenXLPipeline",
|
||||
"IFImg2ImgPipeline",
|
||||
@@ -353,7 +345,6 @@ else:
|
||||
"Lumina2Text2ImgPipeline",
|
||||
"LuminaText2ImgPipeline",
|
||||
"MarigoldDepthPipeline",
|
||||
"MarigoldIntrinsicsPipeline",
|
||||
"MarigoldNormalsPipeline",
|
||||
"MochiPipeline",
|
||||
"MusicLDMPipeline",
|
||||
@@ -446,8 +437,6 @@ else:
|
||||
"VersatileDiffusionTextToImagePipeline",
|
||||
"VideoToVideoSDPipeline",
|
||||
"VQDiffusionPipeline",
|
||||
"WanImageToVideoPipeline",
|
||||
"WanPipeline",
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
@@ -626,10 +615,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
CacheMixin,
|
||||
@@ -642,7 +629,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
DiTTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
@@ -682,7 +668,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
VQModel,
|
||||
WanTransformer3DModel,
|
||||
)
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
@@ -806,9 +791,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
@@ -824,7 +806,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiTPAGPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
HunyuanSkyreelsImageToVideoPipeline,
|
||||
HunyuanVideoImageToVideoPipeline,
|
||||
HunyuanVideoPipeline,
|
||||
I2VGenXLPipeline,
|
||||
IFImg2ImgPipeline,
|
||||
@@ -864,7 +845,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Lumina2Text2ImgPipeline,
|
||||
LuminaText2ImgPipeline,
|
||||
MarigoldDepthPipeline,
|
||||
MarigoldIntrinsicsPipeline,
|
||||
MarigoldNormalsPipeline,
|
||||
MochiPipeline,
|
||||
MusicLDMPipeline,
|
||||
@@ -885,7 +865,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableCascadeCombinedPipeline,
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
StableDiffusion3ControlNetInpaintingPipeline,
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
@@ -956,8 +935,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VideoToVideoSDPipeline,
|
||||
VQDiffusionPipeline,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from ..utils import get_logger
|
||||
from ._common import _BATCHED_INPUT_IDENTIFIERS
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CFG_PARALLEL = "cfg_parallel"
|
||||
|
||||
|
||||
class CFGParallelHook(ModelHook):
|
||||
def initialize_hook(self, module):
|
||||
if not dist.is_initialized():
|
||||
raise RuntimeError("Distributed environment not initialized.")
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if len(args) > 0:
|
||||
logger.warning(
|
||||
"CFGParallelHook is an example hook that does not work with batched positional arguments. Please use with caution."
|
||||
)
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
|
||||
assert world_size == 2, "This is an example hook designed to only work with 2 processes."
|
||||
|
||||
for key in list(kwargs.keys()):
|
||||
if key not in _BATCHED_INPUT_IDENTIFIERS or kwargs[key] is None:
|
||||
continue
|
||||
kwargs[key] = torch.chunk(kwargs[key], world_size, dim=0)[rank].contiguous()
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
sample = output[0]
|
||||
sample_list = [torch.empty_like(sample) for _ in range(world_size)]
|
||||
dist.all_gather(sample_list, sample)
|
||||
sample = torch.cat(sample_list, dim=0).contiguous()
|
||||
|
||||
return_dict = kwargs.get("return_dict", False)
|
||||
if not return_dict:
|
||||
return (sample, *output[1:])
|
||||
return output.__class__(sample, *output[1:])
|
||||
|
||||
|
||||
def apply_cfg_parallel(module: torch.nn.Module) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
hook = CFGParallelHook()
|
||||
registry.register_hook(hook, _CFG_PARALLEL)
|
||||
@@ -0,0 +1,26 @@
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
||||
|
||||
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
|
||||
{
|
||||
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
}
|
||||
)
|
||||
|
||||
_BATCHED_INPUT_IDENTIFIERS = (
|
||||
"hidden_states",
|
||||
"encoder_hidden_states",
|
||||
"pooled_projections",
|
||||
"timestep",
|
||||
"attention_mask",
|
||||
"encoder_attention_mask",
|
||||
"guidance",
|
||||
)
|
||||
@@ -20,19 +20,18 @@ import torch
|
||||
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
from ..utils import logging
|
||||
from ._common import (
|
||||
_ATTENTION_CLASSES,
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
)
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PyramidAttentionBroadcastConfig:
|
||||
r"""
|
||||
@@ -76,9 +75,9 @@ class PyramidAttentionBroadcastConfig:
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
|
||||
current_timestep_callback: Callable[[], int] = None
|
||||
|
||||
|
||||
@@ -74,7 +74,6 @@ if is_torch_available():
|
||||
"HunyuanVideoLoraLoaderMixin",
|
||||
"SanaLoraLoaderMixin",
|
||||
"Lumina2LoraLoaderMixin",
|
||||
"WanLoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
@@ -113,7 +112,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SD3LoraLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
WanLoraLoaderMixin,
|
||||
)
|
||||
from .single_file import FromSingleFileMixin
|
||||
from .textual_inversion import TextualInversionLoaderMixin
|
||||
|
||||
@@ -23,9 +23,7 @@ from safetensors import safe_open
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_detailed_type,
|
||||
_get_model_file,
|
||||
_is_valid_type,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
@@ -215,8 +213,7 @@ class IPAdapterMixin:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
torch_dtype=self.dtype,
|
||||
).to(self.device)
|
||||
).to(self.device, dtype=self.dtype)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -527,9 +524,8 @@ class FluxIPAdapterMixin:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
dtype=image_encoder_dtype,
|
||||
)
|
||||
.to(self.device)
|
||||
.to(self.device, dtype=image_encoder_dtype)
|
||||
.eval()
|
||||
)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
@@ -581,36 +577,29 @@ class FluxIPAdapterMixin:
|
||||
pipeline.set_ip_adapter_scale(ip_strengths)
|
||||
```
|
||||
"""
|
||||
|
||||
scale_type = Union[int, float]
|
||||
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
num_layers = self.transformer.config.num_layers
|
||||
|
||||
# Single value for all layers of all IP-Adapters
|
||||
if isinstance(scale, scale_type):
|
||||
scale = [scale for _ in range(num_ip_adapters)]
|
||||
# List of per-layer scales for a single IP-Adapter
|
||||
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
|
||||
transformer = self.transformer
|
||||
if not isinstance(scale, list):
|
||||
scale = [[scale] * transformer.config.num_layers]
|
||||
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
|
||||
if len(scale) != transformer.config.num_layers:
|
||||
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
|
||||
scale = [scale]
|
||||
# Invalid scale type
|
||||
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
|
||||
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
|
||||
|
||||
if len(scale) != num_ip_adapters:
|
||||
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
|
||||
scale_configs = scale
|
||||
|
||||
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
|
||||
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
|
||||
raise ValueError(
|
||||
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
|
||||
)
|
||||
|
||||
# Scalars are transformed to lists with length num_layers
|
||||
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
|
||||
|
||||
# Set scales. zip over scale_configs prevents going into single transformer layers
|
||||
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
|
||||
attn_processor.scale = scale
|
||||
key_id = 0
|
||||
for attn_name, attn_processor in transformer.attn_processors.items():
|
||||
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
f"{len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
for i, scale_config in enumerate(scale_configs):
|
||||
attn_processor.scale[i] = scale_config[key_id]
|
||||
key_id += 1
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
@@ -807,9 +796,9 @@ class SD3IPAdapterMixin:
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(
|
||||
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
|
||||
).to(self.device),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -654,7 +654,6 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
|
||||
_convert(k, diffusers_key, state_dict, new_state_dict)
|
||||
|
||||
remaining_all_unet = False
|
||||
if state_dict:
|
||||
remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
|
||||
if remaining_all_unet:
|
||||
@@ -1277,74 +1276,3 @@ def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
|
||||
# Remove "diffusion_model." prefix from keys.
|
||||
state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
|
||||
converted_state_dict = {}
|
||||
|
||||
def get_num_layers(keys, pattern):
|
||||
layers = set()
|
||||
for key in keys:
|
||||
match = re.search(pattern, key)
|
||||
if match:
|
||||
layers.add(int(match.group(1)))
|
||||
return len(layers)
|
||||
|
||||
def process_block(prefix, index, convert_norm):
|
||||
# Process attention qkv: pop lora_A and lora_B weights.
|
||||
lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
|
||||
lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
|
||||
for attn_key in ["to_q", "to_k", "to_v"]:
|
||||
converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
|
||||
for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
|
||||
converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
|
||||
|
||||
# Process attention out weights.
|
||||
converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.attention.out.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.attention.out.lora_B.weight"
|
||||
)
|
||||
|
||||
# Process feed-forward weights for layers 1, 2, and 3.
|
||||
for layer in range(1, 4):
|
||||
converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
|
||||
)
|
||||
|
||||
if convert_norm:
|
||||
converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
|
||||
)
|
||||
converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
|
||||
f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
|
||||
)
|
||||
|
||||
noise_refiner_pattern = r"noise_refiner\.(\d+)\."
|
||||
num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
|
||||
for i in range(num_noise_refiner_layers):
|
||||
process_block("noise_refiner", i, convert_norm=True)
|
||||
|
||||
context_refiner_pattern = r"context_refiner\.(\d+)\."
|
||||
num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
|
||||
for i in range(num_context_refiner_layers):
|
||||
process_block("context_refiner", i, convert_norm=False)
|
||||
|
||||
core_transformer_pattern = r"layers\.(\d+)\."
|
||||
num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
|
||||
for i in range(num_core_transformer_layers):
|
||||
process_block("layers", i, convert_norm=True)
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -41,7 +41,6 @@ from .lora_conversion_utils import (
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_lumina2_lora_to_diffusers,
|
||||
_convert_xlabs_flux_lora_to_diffusers,
|
||||
_maybe_map_sgm_blocks_to_diffusers,
|
||||
)
|
||||
@@ -3816,6 +3815,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
@@ -3909,11 +3909,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
# conversion.
|
||||
non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if non_diffusers:
|
||||
state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
@@ -4115,311 +4110,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
||||
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`WanTransformer3DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the UNet and text encoder.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -53,7 +53,6 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"WanTransformer3DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
@@ -64,9 +63,6 @@ def _maybe_adjust_config(config):
|
||||
method removes the ambiguity by following what is described here:
|
||||
https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
|
||||
"""
|
||||
# Track keys that have been explicitly removed to prevent re-adding them.
|
||||
deleted_keys = set()
|
||||
|
||||
rank_pattern = config["rank_pattern"].copy()
|
||||
target_modules = config["target_modules"]
|
||||
original_r = config["r"]
|
||||
@@ -84,22 +80,21 @@ def _maybe_adjust_config(config):
|
||||
ambiguous_key = key
|
||||
|
||||
if exact_matches and substring_matches:
|
||||
# if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
|
||||
# if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
|
||||
config["r"] = key_rank
|
||||
# remove the ambiguous key from `rank_pattern` and record it as deleted
|
||||
# remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
|
||||
del config["rank_pattern"][key]
|
||||
deleted_keys.add(key)
|
||||
# For substring matches, add them with the original rank only if they haven't been assigned already
|
||||
for mod in substring_matches:
|
||||
if mod not in config["rank_pattern"] and mod not in deleted_keys:
|
||||
# avoid overwriting if the module already has a specific rank
|
||||
if mod not in config["rank_pattern"]:
|
||||
config["rank_pattern"][mod] = original_r
|
||||
|
||||
# Update the rest of the target modules with the original rank if not already set and not deleted
|
||||
# update the rest of the keys with the `original_r`
|
||||
for mod in target_modules:
|
||||
if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
|
||||
if mod != ambiguous_key and mod not in config["rank_pattern"]:
|
||||
config["rank_pattern"][mod] = original_r
|
||||
|
||||
# Handle alphas to deal with cases like:
|
||||
# handle alphas to deal with cases like
|
||||
# https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
|
||||
has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
|
||||
if has_different_ranks:
|
||||
@@ -192,11 +187,6 @@ class PeftAdapterMixin:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
try:
|
||||
from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
|
||||
except ImportError:
|
||||
FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -261,22 +251,14 @@ class PeftAdapterMixin:
|
||||
# Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
|
||||
# Bias layers in LoRA only have a single dimension
|
||||
if "lora_B" in key and val.ndim > 1:
|
||||
# Support to handle cases where layer patterns are treated as full layer names
|
||||
# was added later in PEFT. So, we handle it accordingly.
|
||||
# TODO: when we fix the minimal PEFT version for Diffusers,
|
||||
# we should remove `_maybe_adjust_config()`.
|
||||
if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
|
||||
rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
|
||||
else:
|
||||
rank[key] = val.shape[1]
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
|
||||
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
|
||||
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
|
||||
@@ -360,17 +360,11 @@ class FromSingleFileMixin:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
is_legacy_loading = False
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
# We shouldn't allow configuring individual models components through a Pipeline creation method
|
||||
# These model kwargs should be deprecated
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
|
||||
@@ -39,8 +39,6 @@ from .single_file_utils import (
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
convert_wan_transformer_to_diffusers,
|
||||
convert_wan_vae_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
create_unet_diffusers_config_from_ldm,
|
||||
create_vae_diffusers_config_from_ldm,
|
||||
@@ -119,14 +117,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"WanTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLWan": {
|
||||
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -250,17 +240,11 @@ class FromOriginalModelMixin:
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
else:
|
||||
|
||||
@@ -117,8 +117,6 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
|
||||
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
|
||||
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
|
||||
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
|
||||
"wan_vae": "decoder.middle.0.residual.0.gamma",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -178,9 +176,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
|
||||
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
|
||||
"lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
|
||||
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
|
||||
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
|
||||
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -402,7 +397,6 @@ def load_single_file_checkpoint(
|
||||
|
||||
else:
|
||||
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
|
||||
user_agent = {"file_type": "single_file", "framework": "pytorch"}
|
||||
pretrained_model_link_or_path = _get_model_file(
|
||||
repo_id,
|
||||
weights_name=weights_name,
|
||||
@@ -412,7 +406,6 @@ def load_single_file_checkpoint(
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
|
||||
@@ -669,21 +662,6 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
|
||||
model_type = "lumina2"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
|
||||
if "model.diffusion_model.patch_embedding.weight" in checkpoint:
|
||||
target_key = "model.diffusion_model.patch_embedding.weight"
|
||||
else:
|
||||
target_key = "patch_embedding.weight"
|
||||
|
||||
if checkpoint[target_key].shape[0] == 1536:
|
||||
model_type = "wan-t2v-1.3B"
|
||||
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
|
||||
model_type = "wan-t2v-14B"
|
||||
else:
|
||||
model_type = "wan-i2v-14B"
|
||||
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
|
||||
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
|
||||
model_type = "wan-t2v-14B"
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -1470,8 +1448,8 @@ def convert_open_clip_checkpoint(
|
||||
|
||||
if text_proj_key in checkpoint:
|
||||
text_proj_dim = int(checkpoint[text_proj_key].shape[0])
|
||||
elif hasattr(text_model.config, "hidden_size"):
|
||||
text_proj_dim = text_model.config.hidden_size
|
||||
elif hasattr(text_model.config, "projection_dim"):
|
||||
text_proj_dim = text_model.config.projection_dim
|
||||
else:
|
||||
text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
|
||||
|
||||
@@ -2490,7 +2468,7 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
|
||||
|
||||
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
new_state_dict = {}
|
||||
|
||||
# Comfy checkpoints add this prefix
|
||||
keys = list(checkpoint.keys())
|
||||
@@ -2499,22 +2477,22 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
# Convert patch_embed
|
||||
converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
|
||||
# Convert time_embed
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
|
||||
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
|
||||
converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
|
||||
converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
|
||||
converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
|
||||
converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
|
||||
converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
|
||||
converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
|
||||
converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
|
||||
new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
|
||||
new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
|
||||
new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
|
||||
new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
|
||||
new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
|
||||
new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
|
||||
new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
|
||||
|
||||
# Convert transformer blocks
|
||||
num_layers = 48
|
||||
@@ -2523,84 +2501,68 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
old_prefix = f"blocks.{i}."
|
||||
|
||||
# norm1
|
||||
converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
|
||||
converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
|
||||
if i < num_layers - 1:
|
||||
converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.bias"
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
|
||||
new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
|
||||
else:
|
||||
converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.bias"
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
|
||||
|
||||
# Visual attention
|
||||
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
converted_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
||||
converted_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
||||
converted_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
||||
converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.q_norm_x.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.k_norm_x.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.proj_x.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
|
||||
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
|
||||
|
||||
# Context attention
|
||||
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
|
||||
converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
|
||||
converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
|
||||
converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
|
||||
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.q_norm_y.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
|
||||
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.k_norm_y.weight"
|
||||
)
|
||||
if i < num_layers - 1:
|
||||
converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.proj_y.weight"
|
||||
)
|
||||
converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(
|
||||
old_prefix + "attn.proj_y.bias"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
|
||||
|
||||
# MLP
|
||||
converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
|
||||
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
|
||||
checkpoint.pop(old_prefix + "mlp_x.w1.weight")
|
||||
)
|
||||
converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
|
||||
new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
|
||||
if i < num_layers - 1:
|
||||
converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
|
||||
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
|
||||
checkpoint.pop(old_prefix + "mlp_y.w1.weight")
|
||||
)
|
||||
converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(
|
||||
old_prefix + "mlp_y.w2.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
|
||||
|
||||
# Output layers
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
|
||||
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
|
||||
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
|
||||
new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
||||
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
||||
|
||||
return converted_state_dict
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
@@ -2895,252 +2857,3 @@ def convert_lumina2_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict[diffusers_key] = checkpoint.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
||||
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
||||
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
||||
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
||||
"time_projection.1": "condition_embedder.time_proj",
|
||||
"cross_attn": "attn2",
|
||||
"self_attn": "attn1",
|
||||
".o.": ".to_out.0.",
|
||||
".q.": ".to_q.",
|
||||
".k.": ".to_k.",
|
||||
".v.": ".to_v.",
|
||||
".k_img.": ".add_k_proj.",
|
||||
".v_img.": ".add_v_proj.",
|
||||
".norm_k_img.": ".norm_added_k.",
|
||||
"head.modulation": "scale_shift_table",
|
||||
"head.head": "proj_out",
|
||||
"modulation": "scale_shift_table",
|
||||
"ffn.0": "ffn.net.0.proj",
|
||||
"ffn.2": "ffn.net.2",
|
||||
# Hack to swap the layer names
|
||||
# The original model calls the norms in following order: norm1, norm3, norm2
|
||||
# We convert it to: norm1, norm2, norm3
|
||||
"norm2": "norm__placeholder",
|
||||
"norm3": "norm2",
|
||||
"norm__placeholder": "norm3",
|
||||
# For the I2V model
|
||||
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
||||
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
||||
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
||||
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
||||
}
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
|
||||
converted_state_dict[new_key] = checkpoint.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
|
||||
# Create mappings for specific components
|
||||
middle_key_mapping = {
|
||||
# Encoder middle block
|
||||
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
|
||||
"encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
|
||||
"encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
|
||||
"encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
|
||||
"encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
|
||||
"encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
|
||||
"encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
|
||||
"encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
|
||||
"encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
|
||||
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
|
||||
"encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
|
||||
"encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
|
||||
# Decoder middle block
|
||||
"decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
|
||||
"decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
|
||||
"decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
|
||||
"decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
|
||||
"decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
|
||||
"decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
|
||||
"decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
|
||||
"decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
|
||||
"decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
|
||||
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
|
||||
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
|
||||
"decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
|
||||
}
|
||||
|
||||
# Create a mapping for attention blocks
|
||||
attention_mapping = {
|
||||
# Encoder middle attention
|
||||
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
|
||||
"encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
|
||||
"encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
|
||||
"encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
|
||||
"encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
|
||||
# Decoder middle attention
|
||||
"decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
|
||||
"decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
|
||||
"decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
|
||||
"decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
|
||||
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
|
||||
}
|
||||
|
||||
# Create a mapping for the head components
|
||||
head_mapping = {
|
||||
# Encoder head
|
||||
"encoder.head.0.gamma": "encoder.norm_out.gamma",
|
||||
"encoder.head.2.bias": "encoder.conv_out.bias",
|
||||
"encoder.head.2.weight": "encoder.conv_out.weight",
|
||||
# Decoder head
|
||||
"decoder.head.0.gamma": "decoder.norm_out.gamma",
|
||||
"decoder.head.2.bias": "decoder.conv_out.bias",
|
||||
"decoder.head.2.weight": "decoder.conv_out.weight",
|
||||
}
|
||||
|
||||
# Create a mapping for the quant components
|
||||
quant_mapping = {
|
||||
"conv1.weight": "quant_conv.weight",
|
||||
"conv1.bias": "quant_conv.bias",
|
||||
"conv2.weight": "post_quant_conv.weight",
|
||||
"conv2.bias": "post_quant_conv.bias",
|
||||
}
|
||||
|
||||
# Process each key in the state dict
|
||||
for key, value in checkpoint.items():
|
||||
# Handle middle block keys using the mapping
|
||||
if key in middle_key_mapping:
|
||||
new_key = middle_key_mapping[key]
|
||||
converted_state_dict[new_key] = value
|
||||
# Handle attention blocks using the mapping
|
||||
elif key in attention_mapping:
|
||||
new_key = attention_mapping[key]
|
||||
converted_state_dict[new_key] = value
|
||||
# Handle head keys using the mapping
|
||||
elif key in head_mapping:
|
||||
new_key = head_mapping[key]
|
||||
converted_state_dict[new_key] = value
|
||||
# Handle quant keys using the mapping
|
||||
elif key in quant_mapping:
|
||||
new_key = quant_mapping[key]
|
||||
converted_state_dict[new_key] = value
|
||||
# Handle encoder conv1
|
||||
elif key == "encoder.conv1.weight":
|
||||
converted_state_dict["encoder.conv_in.weight"] = value
|
||||
elif key == "encoder.conv1.bias":
|
||||
converted_state_dict["encoder.conv_in.bias"] = value
|
||||
# Handle decoder conv1
|
||||
elif key == "decoder.conv1.weight":
|
||||
converted_state_dict["decoder.conv_in.weight"] = value
|
||||
elif key == "decoder.conv1.bias":
|
||||
converted_state_dict["decoder.conv_in.bias"] = value
|
||||
# Handle encoder downsamples
|
||||
elif key.startswith("encoder.downsamples."):
|
||||
# Convert to down_blocks
|
||||
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
|
||||
|
||||
# Convert residual block naming but keep the original structure
|
||||
if ".residual.0.gamma" in new_key:
|
||||
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
|
||||
elif ".residual.2.bias" in new_key:
|
||||
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
|
||||
elif ".residual.2.weight" in new_key:
|
||||
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
|
||||
elif ".residual.3.gamma" in new_key:
|
||||
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
|
||||
elif ".residual.6.bias" in new_key:
|
||||
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
|
||||
elif ".residual.6.weight" in new_key:
|
||||
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
|
||||
elif ".shortcut.bias" in new_key:
|
||||
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
|
||||
elif ".shortcut.weight" in new_key:
|
||||
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
# Handle decoder upsamples
|
||||
elif key.startswith("decoder.upsamples."):
|
||||
# Convert to up_blocks
|
||||
parts = key.split(".")
|
||||
block_idx = int(parts[2])
|
||||
|
||||
# Group residual blocks
|
||||
if "residual" in key:
|
||||
if block_idx in [0, 1, 2]:
|
||||
new_block_idx = 0
|
||||
resnet_idx = block_idx
|
||||
elif block_idx in [4, 5, 6]:
|
||||
new_block_idx = 1
|
||||
resnet_idx = block_idx - 4
|
||||
elif block_idx in [8, 9, 10]:
|
||||
new_block_idx = 2
|
||||
resnet_idx = block_idx - 8
|
||||
elif block_idx in [12, 13, 14]:
|
||||
new_block_idx = 3
|
||||
resnet_idx = block_idx - 12
|
||||
else:
|
||||
# Keep as is for other blocks
|
||||
converted_state_dict[key] = value
|
||||
continue
|
||||
|
||||
# Convert residual block naming
|
||||
if ".residual.0.gamma" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
|
||||
elif ".residual.2.bias" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
|
||||
elif ".residual.2.weight" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
|
||||
elif ".residual.3.gamma" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
|
||||
elif ".residual.6.bias" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
|
||||
elif ".residual.6.weight" in key:
|
||||
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
|
||||
else:
|
||||
new_key = key
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
# Handle shortcut connections
|
||||
elif ".shortcut." in key:
|
||||
if block_idx == 4:
|
||||
new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
|
||||
new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
# Handle upsamplers
|
||||
elif ".resample." in key or ".time_conv." in key:
|
||||
if block_idx == 3:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
|
||||
elif block_idx == 7:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
|
||||
elif block_idx == 11:
|
||||
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
else:
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
converted_state_dict[new_key] = value
|
||||
else:
|
||||
# Keep other keys unchanged
|
||||
converted_state_dict[key] = value
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
Executable → Regular
-8
@@ -33,10 +33,8 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
|
||||
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
@@ -73,7 +71,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
@@ -82,7 +79,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
@@ -111,10 +107,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
@@ -147,7 +141,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ConsisIDTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
FluxTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
@@ -165,7 +158,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
TransformerTemporalModel,
|
||||
WanTransformer3DModel,
|
||||
)
|
||||
from .unets import (
|
||||
I2VGenXLUNet,
|
||||
|
||||
Executable → Regular
+10
-21
@@ -213,9 +213,7 @@ class Attention(nn.Module):
|
||||
self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
|
||||
self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
|
||||
)
|
||||
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
|
||||
|
||||
if cross_attention_norm is None:
|
||||
self.norm_cross = None
|
||||
@@ -274,20 +272,12 @@ class Attention(nn.Module):
|
||||
self.to_add_out = None
|
||||
|
||||
if qk_norm is not None and added_kv_proj_dim is not None:
|
||||
if qk_norm == "layer_norm":
|
||||
self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
elif qk_norm == "fp32_layer_norm":
|
||||
if qk_norm == "fp32_layer_norm":
|
||||
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
elif qk_norm == "rms_norm":
|
||||
self.norm_added_q = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = RMSNorm(dim_head, eps=eps)
|
||||
elif qk_norm == "rms_norm_across_heads":
|
||||
# Wan applies qk norm across all heads
|
||||
# Wan also doesn't apply a q norm
|
||||
self.norm_added_q = None
|
||||
self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
|
||||
@@ -1418,7 +1408,7 @@ class JointAttnProcessor2_0:
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -2788,8 +2778,9 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
|
||||
|
||||
# IP-adapter
|
||||
ip_query = hidden_states_query_proj
|
||||
ip_attn_output = torch.zeros_like(hidden_states)
|
||||
|
||||
ip_attn_output = None
|
||||
# for ip-adapter
|
||||
# TODO: support for multiple adapters
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
|
||||
):
|
||||
@@ -2800,14 +2791,12 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
current_ip_hidden_states = F.scaled_dot_product_attention(
|
||||
ip_attn_output = F.scaled_dot_product_attention(
|
||||
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
|
||||
ip_attn_output += scale * current_ip_hidden_states
|
||||
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_attn_output = scale * ip_attn_output
|
||||
ip_attn_output = ip_attn_output.to(ip_query.dtype)
|
||||
|
||||
return hidden_states, encoder_hidden_states, ip_attn_output
|
||||
else:
|
||||
|
||||
@@ -5,10 +5,8 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_kl_wan import AutoencoderKLWan
|
||||
from .autoencoder_oobleck import AutoencoderOobleck
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,866 +0,0 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
|
||||
class WanCausalConv3d(nn.Conv3d):
|
||||
r"""
|
||||
A custom 3D causal convolution layer with feature caching support.
|
||||
|
||||
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
|
||||
caching for efficient inference.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels in the input image
|
||||
out_channels (int): Number of channels produced by the convolution
|
||||
kernel_size (int or tuple): Size of the convolving kernel
|
||||
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
||||
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Union[int, Tuple[int, int, int]] = 1,
|
||||
padding: Union[int, Tuple[int, int, int]] = 0,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
# Set up causal padding
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
x = F.pad(x, padding)
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class WanRMS_norm(nn.Module):
|
||||
r"""
|
||||
A custom RMS normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The number of dimensions to normalize over.
|
||||
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
|
||||
Default is True.
|
||||
images (bool, optional): Whether the input represents image data. Default is True.
|
||||
bias (bool, optional): Whether to include a learnable bias term. Default is False.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class WanUpsample(nn.Upsample):
|
||||
r"""
|
||||
Perform upsampling while ensuring the output tensor has the same data type as the input.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor to be upsampled.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Upsampled tensor with the same data type as the input.
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class WanResample(nn.Module):
|
||||
r"""
|
||||
A custom resampling module for 2D and 3D data.
|
||||
|
||||
Args:
|
||||
dim (int): The number of input/output channels.
|
||||
mode (str): The resampling mode. Must be one of:
|
||||
- 'none': No resampling (identity operation).
|
||||
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
|
||||
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
|
||||
- 'downsample2d': 2D downsampling with zero-padding and convolution.
|
||||
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, mode: str) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == "upsample2d":
|
||||
self.resample = nn.Sequential(
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
elif mode == "upsample3d":
|
||||
self.resample = nn.Sequential(
|
||||
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
|
||||
)
|
||||
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
elif mode == "downsample2d":
|
||||
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
elif mode == "downsample3d":
|
||||
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == "upsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = "Rep"
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat(
|
||||
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
||||
)
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
|
||||
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
|
||||
if feat_cache[idx] == "Rep":
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.resample(x)
|
||||
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
|
||||
|
||||
if self.mode == "downsample3d":
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
|
||||
class WanResidualBlock(nn.Module):
|
||||
r"""
|
||||
A custom residual block module.
|
||||
|
||||
Args:
|
||||
in_dim (int): Number of input channels.
|
||||
out_dim (int): Number of output channels.
|
||||
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
|
||||
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
dropout: float = 0.0,
|
||||
non_linearity: str = "silu",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
# layers
|
||||
self.norm1 = WanRMS_norm(in_dim, images=False)
|
||||
self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
|
||||
self.norm2 = WanRMS_norm(out_dim, images=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
|
||||
self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# Apply shortcut connection
|
||||
h = self.conv_shortcut(x)
|
||||
|
||||
# First normalization and activation
|
||||
x = self.norm1(x)
|
||||
x = self.nonlinearity(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
# Second normalization and activation
|
||||
x = self.norm2(x)
|
||||
x = self.nonlinearity(x)
|
||||
|
||||
# Dropout
|
||||
x = self.dropout(x)
|
||||
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
|
||||
x = self.conv2(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv2(x)
|
||||
|
||||
# Add residual connection
|
||||
return x + h
|
||||
|
||||
|
||||
class WanAttentionBlock(nn.Module):
|
||||
r"""
|
||||
Causal self-attention with a single head.
|
||||
|
||||
Args:
|
||||
dim (int): The number of channels in the input tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = WanRMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
batch_size, channels, time, height, width = x.size()
|
||||
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
|
||||
x = self.norm(x)
|
||||
|
||||
# compute query, key, value
|
||||
qkv = self.to_qkv(x)
|
||||
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
|
||||
qkv = qkv.permute(0, 1, 3, 2).contiguous()
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
|
||||
# apply attention
|
||||
x = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
|
||||
|
||||
# output projection
|
||||
x = self.proj(x)
|
||||
|
||||
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
|
||||
x = x.view(batch_size, time, channels, height, width)
|
||||
x = x.permute(0, 2, 1, 3, 4)
|
||||
|
||||
return x + identity
|
||||
|
||||
|
||||
class WanMidBlock(nn.Module):
|
||||
"""
|
||||
Middle block for WanVAE encoder and decoder.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input/output channels.
|
||||
dropout (float): Dropout rate.
|
||||
non_linearity (str): Type of non-linearity to use.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# Create the components
|
||||
resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
|
||||
attentions = []
|
||||
for _ in range(num_layers):
|
||||
attentions.append(WanAttentionBlock(dim))
|
||||
resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
# First residual block
|
||||
x = self.resnets[0](x, feat_cache, feat_idx)
|
||||
|
||||
# Process through attention and residual blocks
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
x = attn(x)
|
||||
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WanEncoder3d(nn.Module):
|
||||
r"""
|
||||
A 3D encoder module.
|
||||
|
||||
Args:
|
||||
dim (int): The base number of channels in the first layer.
|
||||
z_dim (int): The dimensionality of the latent space.
|
||||
dim_mult (list of int): Multipliers for the number of channels in each block.
|
||||
num_res_blocks (int): Number of residual blocks in each block.
|
||||
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
||||
temperal_downsample (list of bool): Whether to downsample temporally in each block.
|
||||
dropout (float): Dropout rate for the dropout layers.
|
||||
non_linearity (str): Type of non-linearity to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
self.down_blocks.append(WanAttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
||||
self.down_blocks.append(WanResample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
|
||||
# middle blocks
|
||||
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
|
||||
|
||||
# output blocks
|
||||
self.norm_out = WanRMS_norm(out_dim, images=False)
|
||||
self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_in(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_in(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.down_blocks:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
x = self.nonlinearity(x)
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_out(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class WanUpBlock(nn.Module):
|
||||
"""
|
||||
A block that handles upsampling for the WanVAE decoder.
|
||||
|
||||
Args:
|
||||
in_dim (int): Input dimension
|
||||
out_dim (int): Output dimension
|
||||
num_res_blocks (int): Number of residual blocks
|
||||
dropout (float): Dropout rate
|
||||
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
|
||||
non_linearity (str): Type of non-linearity to use
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
num_res_blocks: int,
|
||||
dropout: float = 0.0,
|
||||
upsample_mode: Optional[str] = None,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# Create layers list
|
||||
resnets = []
|
||||
# Add residual blocks and attention if needed
|
||||
current_dim = in_dim
|
||||
for _ in range(num_res_blocks + 1):
|
||||
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
|
||||
current_dim = out_dim
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
# Add upsampling layer if needed
|
||||
self.upsamplers = None
|
||||
if upsample_mode is not None:
|
||||
self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
"""
|
||||
Forward pass through the upsampling block.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor
|
||||
feat_cache (list, optional): Feature cache for causal convolutions
|
||||
feat_idx (list, optional): Feature index for cache management
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor
|
||||
"""
|
||||
for resnet in self.resnets:
|
||||
if feat_cache is not None:
|
||||
x = resnet(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = resnet(x)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
if feat_cache is not None:
|
||||
x = self.upsamplers[0](x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = self.upsamplers[0](x)
|
||||
return x
|
||||
|
||||
|
||||
class WanDecoder3d(nn.Module):
|
||||
r"""
|
||||
A 3D decoder module.
|
||||
|
||||
Args:
|
||||
dim (int): The base number of channels in the first layer.
|
||||
z_dim (int): The dimensionality of the latent space.
|
||||
dim_mult (list of int): Multipliers for the number of channels in each block.
|
||||
num_res_blocks (int): Number of residual blocks in each block.
|
||||
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
||||
temperal_upsample (list of bool): Whether to upsample temporally in each block.
|
||||
dropout (float): Dropout rate for the dropout layers.
|
||||
non_linearity (str): Type of non-linearity to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0,
|
||||
non_linearity: str = "silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
|
||||
|
||||
# upsample blocks
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i > 0:
|
||||
in_dim = in_dim // 2
|
||||
|
||||
# Determine if we need upsampling
|
||||
upsample_mode = None
|
||||
if i != len(dim_mult) - 1:
|
||||
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
|
||||
|
||||
# Create and add the upsampling block
|
||||
up_block = WanUpBlock(
|
||||
in_dim=in_dim,
|
||||
out_dim=out_dim,
|
||||
num_res_blocks=num_res_blocks,
|
||||
dropout=dropout,
|
||||
upsample_mode=upsample_mode,
|
||||
non_linearity=non_linearity,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
# Update scale for next iteration
|
||||
if upsample_mode is not None:
|
||||
scale *= 2.0
|
||||
|
||||
# output blocks
|
||||
self.norm_out = WanRMS_norm(out_dim, images=False)
|
||||
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_in(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_in(x)
|
||||
|
||||
## middle
|
||||
x = self.mid_block(x, feat_cache, feat_idx)
|
||||
|
||||
## upsamples
|
||||
for up_block in self.up_blocks:
|
||||
x = up_block(x, feat_cache, feat_idx)
|
||||
|
||||
## head
|
||||
x = self.norm_out(x)
|
||||
x = self.nonlinearity(x)
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
||||
x = self.conv_out(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
||||
Introduced in [Wan 2.1].
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
base_dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
dropout: float = 0.0,
|
||||
latents_mean: List[float] = [
|
||||
-0.7571,
|
||||
-0.7089,
|
||||
-0.9113,
|
||||
0.1075,
|
||||
-0.1745,
|
||||
0.9653,
|
||||
-0.1517,
|
||||
1.5508,
|
||||
0.4134,
|
||||
-0.0715,
|
||||
0.5517,
|
||||
-0.3632,
|
||||
-0.1922,
|
||||
-0.9497,
|
||||
0.2503,
|
||||
-0.2921,
|
||||
],
|
||||
latents_std: List[float] = [
|
||||
2.8184,
|
||||
1.4541,
|
||||
2.3275,
|
||||
2.6558,
|
||||
1.2196,
|
||||
1.7708,
|
||||
2.6052,
|
||||
2.0743,
|
||||
3.2687,
|
||||
2.1526,
|
||||
2.8652,
|
||||
1.5579,
|
||||
1.6382,
|
||||
1.1253,
|
||||
2.8251,
|
||||
1.9160,
|
||||
],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Store normalization parameters as tensors
|
||||
self.mean = torch.tensor(latents_mean)
|
||||
self.std = torch.tensor(latents_std)
|
||||
self.scale = torch.stack([self.mean, 1.0 / self.std]) # Shape: [2, C]
|
||||
|
||||
self.z_dim = z_dim
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
self.encoder = WanEncoder3d(
|
||||
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
|
||||
)
|
||||
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
|
||||
|
||||
self.decoder = WanDecoder3d(
|
||||
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
|
||||
)
|
||||
|
||||
def clear_cache(self):
|
||||
def _count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, WanCausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
self._conv_num = _count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# cache encode
|
||||
self._enc_conv_num = _count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
scale = self.scale.type_as(x)
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx,
|
||||
)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
enc = self.quant_conv(out)
|
||||
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
||||
logvar = (logvar - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
|
||||
enc = torch.cat([mu, logvar], dim=1)
|
||||
self.clear_cache()
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded videos. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
|
||||
|
||||
iter_ = z.shape[2]
|
||||
x = self.post_quant_conv(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
|
||||
out = torch.clamp(out, min=-1.0, max=1.0)
|
||||
self.clear_cache()
|
||||
if not return_dict:
|
||||
return (out,)
|
||||
|
||||
return DecoderOutput(sample=out)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
scale = self.scale.type_as(z)
|
||||
decoded = self._decode(z, scale).sample
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, return_dict=return_dict)
|
||||
return dec
|
||||
@@ -40,48 +40,6 @@ class SD3ControlNetOutput(BaseOutput):
|
||||
|
||||
|
||||
class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, defaults to `128`):
|
||||
The width/height of the latents. This is fixed during training since it is used to learn a number of
|
||||
position embeddings.
|
||||
patch_size (`int`, defaults to `2`):
|
||||
Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of latent channels in the input.
|
||||
num_layers (`int`, defaults to `18`):
|
||||
The number of layers of transformer blocks to use.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
num_attention_heads (`int`, defaults to `18`):
|
||||
The number of heads to use for multi-head attention.
|
||||
joint_attention_dim (`int`, defaults to `4096`):
|
||||
The embedding dimension to use for joint text-image attention.
|
||||
caption_projection_dim (`int`, defaults to `1152`):
|
||||
The embedding dimension of caption embeddings.
|
||||
pooled_projection_dim (`int`, defaults to `2048`):
|
||||
The embedding dimension of pooled text projections.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of latent channels in the output.
|
||||
pos_embed_max_size (`int`, defaults to `96`):
|
||||
The maximum latent height/width of positional embeddings.
|
||||
extra_conditioning_channels (`int`, defaults to `0`):
|
||||
The number of extra channels to use for conditioning for patch embedding.
|
||||
dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
|
||||
The number of dual-stream transformer blocks to use.
|
||||
qk_norm (`str`, *optional*, defaults to `None`):
|
||||
The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
|
||||
pos_embed_type (`str`, defaults to `"sincos"`):
|
||||
The type of positional embedding to use. Choose between `"sincos"` and `None`.
|
||||
use_pos_embed (`bool`, defaults to `True`):
|
||||
Whether to use positional embeddings.
|
||||
force_zeros_for_pooled_projection (`bool`, defaults to `True`):
|
||||
Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
|
||||
config value of the ControlNet model.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
@@ -135,7 +93,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
JointTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
context_pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
use_dual_attention=True if i in dual_attention_layers else False,
|
||||
@@ -150,7 +108,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
SD3SingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -339,28 +297,28 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
pooled_projections: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`SD3Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
controlnet_cond (`torch.Tensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
@@ -479,11 +437,11 @@ class SD3MultiControlNetModel(ModelMixin):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
hidden_states: torch.FloatTensor,
|
||||
controlnet_cond: List[torch.tensor],
|
||||
conditioning_scale: List[float],
|
||||
pooled_projections: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
|
||||
@@ -605,13 +605,12 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
controlnet_cond: List[torch.Tensor],
|
||||
control_type: torch.Tensor,
|
||||
control_type_idx: List[int],
|
||||
conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
from_multi: bool = False,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
@@ -648,8 +647,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
from_multi (`bool`, defaults to `False`):
|
||||
Use standard scaling when called from `MultiControlNetUnionModel`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
@@ -661,9 +658,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
if isinstance(conditioning_scale, float):
|
||||
conditioning_scale = [conditioning_scale] * len(controlnet_cond)
|
||||
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
@@ -748,16 +742,12 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
inputs = []
|
||||
condition_list = []
|
||||
|
||||
for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
|
||||
for cond, control_idx in zip(controlnet_cond, control_type_idx):
|
||||
condition = self.controlnet_cond_embedding(cond)
|
||||
feat_seq = torch.mean(condition, dim=(2, 3))
|
||||
feat_seq = feat_seq + self.task_embedding[control_idx]
|
||||
if from_multi:
|
||||
inputs.append(feat_seq.unsqueeze(1))
|
||||
condition_list.append(condition)
|
||||
else:
|
||||
inputs.append(feat_seq.unsqueeze(1) * scale)
|
||||
condition_list.append(condition * scale)
|
||||
inputs.append(feat_seq.unsqueeze(1))
|
||||
condition_list.append(condition)
|
||||
|
||||
condition = sample
|
||||
feat_seq = torch.mean(condition, dim=(2, 3))
|
||||
@@ -769,13 +759,10 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
x = layer(x)
|
||||
|
||||
controlnet_cond_fuser = sample * 0.0
|
||||
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
|
||||
for idx, condition in enumerate(condition_list[:-1]):
|
||||
alpha = self.spatial_ch_projs(x[:, idx])
|
||||
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
|
||||
if from_multi:
|
||||
controlnet_cond_fuser += condition + alpha
|
||||
else:
|
||||
controlnet_cond_fuser += condition + alpha * scale
|
||||
controlnet_cond_fuser += condition + alpha
|
||||
|
||||
sample = sample + controlnet_cond_fuser
|
||||
|
||||
@@ -819,13 +806,12 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
if from_multi:
|
||||
scales = scales * conditioning_scale[0]
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
elif from_multi:
|
||||
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
|
||||
@@ -47,12 +47,9 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
|
||||
zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
|
||||
):
|
||||
if scale == 0.0:
|
||||
continue
|
||||
down_samples, mid_sample = controlnet(
|
||||
sample=sample,
|
||||
timestep=timestep,
|
||||
@@ -66,13 +63,12 @@ class MultiControlNetUnionModel(ModelMixin):
|
||||
attention_mask=attention_mask,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
from_multi=True,
|
||||
guess_mode=guess_mode,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# merge samples
|
||||
if down_block_res_samples is None and mid_block_res_sample is None:
|
||||
if i == 0:
|
||||
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
||||
else:
|
||||
down_block_res_samples = [
|
||||
|
||||
@@ -2583,11 +2583,6 @@ class MultiIPAdapterImageProjection(nn.Module):
|
||||
super().__init__()
|
||||
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
|
||||
|
||||
@property
|
||||
def num_ip_adapters(self) -> int:
|
||||
"""Number of IP-Adapters loaded."""
|
||||
return len(self.image_projection_layers)
|
||||
|
||||
def forward(self, image_embeds: List[torch.Tensor]):
|
||||
projected_image_embeds = []
|
||||
|
||||
|
||||
@@ -166,12 +166,8 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
||||
|
||||
# 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
|
||||
last_dtype = None
|
||||
|
||||
for name, param in parameter.named_parameters():
|
||||
for param in parameter.parameters():
|
||||
last_dtype = param.dtype
|
||||
if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
|
||||
continue
|
||||
|
||||
if param.is_floating_point():
|
||||
return param.dtype
|
||||
|
||||
@@ -870,7 +866,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
max_memory = kwargs.pop("max_memory", None)
|
||||
@@ -883,12 +879,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
|
||||
Executable → Regular
-2
@@ -19,7 +19,6 @@ if is_torch_available():
|
||||
from .transformer_allegro import AllegroTransformer3DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
@@ -28,4 +27,3 @@ if is_torch_available():
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from .transformer_wan import WanTransformer3DModel
|
||||
|
||||
@@ -244,34 +244,30 @@ class CogView4RotaryPosEmbed(nn.Module):
|
||||
def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.rope_axes_dim = rope_axes_dim
|
||||
self.theta = theta
|
||||
|
||||
dim_h, dim_w = dim // 2, dim // 2
|
||||
h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
|
||||
w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
|
||||
h_seq = torch.arange(self.rope_axes_dim[0])
|
||||
w_seq = torch.arange(self.rope_axes_dim[1])
|
||||
self.freqs_h = torch.outer(h_seq, h_inv_freq)
|
||||
self.freqs_w = torch.outer(w_seq, w_inv_freq)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
height, width = height // self.patch_size, width // self.patch_size
|
||||
|
||||
dim_h, dim_w = self.dim // 2, self.dim // 2
|
||||
h_inv_freq = 1.0 / (
|
||||
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
|
||||
)
|
||||
w_inv_freq = 1.0 / (
|
||||
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
|
||||
)
|
||||
h_seq = torch.arange(self.rope_axes_dim[0])
|
||||
w_seq = torch.arange(self.rope_axes_dim[1])
|
||||
freqs_h = torch.outer(h_seq, h_inv_freq)
|
||||
freqs_w = torch.outer(w_seq, w_inv_freq)
|
||||
|
||||
h_idx = torch.arange(height, device=freqs_h.device)
|
||||
w_idx = torch.arange(width, device=freqs_w.device)
|
||||
h_idx = torch.arange(height)
|
||||
w_idx = torch.arange(width)
|
||||
inner_h_idx = h_idx * self.rope_axes_dim[0] // height
|
||||
inner_w_idx = w_idx * self.rope_axes_dim[1] // width
|
||||
|
||||
freqs_h = freqs_h[inner_h_idx]
|
||||
freqs_w = freqs_w[inner_w_idx]
|
||||
self.freqs_h = self.freqs_h.to(hidden_states.device)
|
||||
self.freqs_w = self.freqs_w.to(hidden_states.device)
|
||||
freqs_h = self.freqs_h[inner_h_idx]
|
||||
freqs_w = self.freqs_w[inner_w_idx]
|
||||
|
||||
# Create position matrices for height and width
|
||||
# [height, 1, dim//4] and [1, width, dim//4]
|
||||
|
||||
@@ -1,527 +0,0 @@
|
||||
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class EasyAnimateLayerNormZero(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
conditioning_dim: int,
|
||||
embedding_dim: int,
|
||||
elementwise_affine: bool = True,
|
||||
eps: float = 1e-5,
|
||||
bias: bool = True,
|
||||
norm_type: str = "fp32_layer_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
|
||||
elif norm_type == "fp32_layer_norm":
|
||||
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale.unsqueeze(1)) + enc_shift.unsqueeze(
|
||||
1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states, gate, enc_gate
|
||||
|
||||
|
||||
class EasyAnimateRotaryPosEmbed(nn.Module):
|
||||
def __init__(self, patch_size: int, rope_dim: List[int]) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.rope_dim = rope_dim
|
||||
|
||||
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
bs, c, num_frames, grid_height, grid_width = hidden_states.size()
|
||||
grid_height = grid_height // self.patch_size
|
||||
grid_width = grid_width // self.patch_size
|
||||
base_size_width = 90 // self.patch_size
|
||||
base_size_height = 60 // self.patch_size
|
||||
|
||||
grid_crops_coords = self.get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
image_rotary_emb = get_3d_rotary_pos_embed(
|
||||
self.rope_dim,
|
||||
grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=hidden_states.size(2),
|
||||
use_real=True,
|
||||
)
|
||||
return image_rotary_emb
|
||||
|
||||
|
||||
class EasyAnimateAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the EasyAnimateTransformer3DModel model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"EasyAnimateAttnProcessor2_0 requires PyTorch 2.0 or above. To use it, please install PyTorch 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# 3. Encoder condition QKV projection and normalization
|
||||
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([encoder_query, query], dim=2)
|
||||
key = torch.cat([encoder_key, key], dim=2)
|
||||
value = torch.cat([encoder_value, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
|
||||
query[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
|
||||
)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
|
||||
key[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
|
||||
)
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# 6. Output projection
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
if getattr(attn, "to_out", None) is not None:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if getattr(attn, "to_add_out", None) is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
else:
|
||||
if getattr(attn, "to_out", None) is not None:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class EasyAnimateTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
time_embed_dim: int,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-6,
|
||||
final_dropout: bool = True,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = True,
|
||||
qk_norm: bool = True,
|
||||
after_norm: bool = False,
|
||||
norm_type: str = "fp32_layer_norm",
|
||||
is_mmdit_block: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Attention Part
|
||||
self.norm1 = EasyAnimateLayerNormZero(
|
||||
time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
|
||||
)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
qk_norm="layer_norm" if qk_norm else None,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
added_proj_bias=True,
|
||||
added_kv_proj_dim=dim if is_mmdit_block else None,
|
||||
context_pre_only=False if is_mmdit_block else None,
|
||||
processor=EasyAnimateAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# FFN Part
|
||||
self.norm2 = EasyAnimateLayerNormZero(
|
||||
time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
|
||||
)
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
self.txt_ff = None
|
||||
if is_mmdit_block:
|
||||
self.txt_ff = FeedForward(
|
||||
dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
final_dropout=final_dropout,
|
||||
inner_dim=ff_inner_dim,
|
||||
bias=ff_bias,
|
||||
)
|
||||
|
||||
self.norm3 = None
|
||||
if after_norm:
|
||||
self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Attention
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
|
||||
|
||||
# 2. Feed-forward
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
if self.norm3 is not None:
|
||||
norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
|
||||
if self.txt_ff is not None:
|
||||
norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
|
||||
else:
|
||||
norm_encoder_hidden_states = self.norm3(self.ff(norm_encoder_hidden_states))
|
||||
else:
|
||||
norm_hidden_states = self.ff(norm_hidden_states)
|
||||
if self.txt_ff is not None:
|
||||
norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
|
||||
else:
|
||||
norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states)
|
||||
hidden_states = hidden_states + gate_ff.unsqueeze(1) * norm_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff.unsqueeze(1) * norm_encoder_hidden_states
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate).
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, defaults to `48`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
sample_width (`int`, defaults to `90`):
|
||||
The width of the input latents.
|
||||
sample_height (`int`, defaults to `60`):
|
||||
The height of the input latents.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
||||
Activation function to use when generating the timestep embeddings.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
mmdit_layers (`int`, defaults to `1000`):
|
||||
The number of layers of Multi Modal Transformer blocks to use.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
The epsilon value to use in normalization layers.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether to use elementwise affine in normalization layers.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
time_position_encoding_type (`str`, defaults to `3d_rope`):
|
||||
Type of time position encoding.
|
||||
after_norm (`bool`, defaults to `False`):
|
||||
Flag to apply normalization after.
|
||||
resize_inpaint_mask_directly (`bool`, defaults to `True`):
|
||||
Flag to resize inpaint mask directly.
|
||||
enable_text_attention_mask (`bool`, defaults to `True`):
|
||||
Flag to enable text attention mask.
|
||||
add_noise_in_inpaint_model (`bool`, defaults to `False`):
|
||||
Flag to add noise in inpaint model.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["EasyAnimateTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["^proj$", "norm", "^proj_out$"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 48,
|
||||
attention_head_dim: int = 64,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
sample_width: int = 90,
|
||||
sample_height: int = 60,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
timestep_activation_fn: str = "silu",
|
||||
freq_shift: int = 0,
|
||||
num_layers: int = 48,
|
||||
mmdit_layers: int = 48,
|
||||
dropout: float = 0.0,
|
||||
time_embed_dim: int = 512,
|
||||
add_norm_text_encoder: bool = False,
|
||||
text_embed_dim: int = 3584,
|
||||
text_embed_dim_t5: int = None,
|
||||
norm_eps: float = 1e-5,
|
||||
norm_elementwise_affine: bool = True,
|
||||
flip_sin_to_cos: bool = True,
|
||||
time_position_encoding_type: str = "3d_rope",
|
||||
after_norm=False,
|
||||
resize_inpaint_mask_directly: bool = True,
|
||||
enable_text_attention_mask: bool = True,
|
||||
add_noise_in_inpaint_model: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Timestep embedding
|
||||
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
||||
self.rope_embedding = EasyAnimateRotaryPosEmbed(patch_size, attention_head_dim)
|
||||
|
||||
# 2. Patch embedding
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
|
||||
)
|
||||
|
||||
# 3. Text refined embedding
|
||||
self.text_proj = None
|
||||
self.text_proj_t5 = None
|
||||
if not add_norm_text_encoder:
|
||||
self.text_proj = nn.Linear(text_embed_dim, inner_dim)
|
||||
if text_embed_dim_t5 is not None:
|
||||
self.text_proj_t5 = nn.Linear(text_embed_dim_t5, inner_dim)
|
||||
else:
|
||||
self.text_proj = nn.Sequential(
|
||||
RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim, inner_dim)
|
||||
)
|
||||
if text_embed_dim_t5 is not None:
|
||||
self.text_proj_t5 = nn.Sequential(
|
||||
RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim_t5, inner_dim)
|
||||
)
|
||||
|
||||
# 4. Transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
EasyAnimateTransformerBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=activation_fn,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
after_norm=after_norm,
|
||||
is_mmdit_block=True if _ < mmdit_layers else False,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
|
||||
|
||||
# 5. Output norm & projection
|
||||
self.norm_out = AdaLayerNorm(
|
||||
embedding_dim=time_embed_dim,
|
||||
output_dim=2 * inner_dim,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
chunk_dim=1,
|
||||
)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_t5: Optional[torch.Tensor] = None,
|
||||
inpaint_latents: Optional[torch.Tensor] = None,
|
||||
control_latents: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
batch_size, channels, video_length, height, width = hidden_states.size()
|
||||
p = self.config.patch_size
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
|
||||
# 1. Time embedding
|
||||
temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
|
||||
temb = self.time_embedding(temb, timestep_cond)
|
||||
image_rotary_emb = self.rope_embedding(hidden_states)
|
||||
|
||||
# 2. Patch embedding
|
||||
if inpaint_latents is not None:
|
||||
hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
|
||||
if control_latents is not None:
|
||||
hidden_states = torch.concat([hidden_states, control_latents], 1)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, F, H, W] -> [BF, C, H, W]
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
|
||||
0, 2, 1, 3, 4
|
||||
) # [BF, C, H, W] -> [B, F, C, H, W]
|
||||
hidden_states = hidden_states.flatten(2, 4).transpose(1, 2) # [B, F, C, H, W] -> [B, FHW, C]
|
||||
|
||||
# 3. Text embedding
|
||||
encoder_hidden_states = self.text_proj(encoder_hidden_states)
|
||||
if encoder_hidden_states_t5 is not None:
|
||||
encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous()
|
||||
|
||||
# 4. Transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, image_rotary_emb
|
||||
)
|
||||
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
|
||||
# 5. Output norm & projection
|
||||
hidden_states = self.norm_out(hidden_states, temb=temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 6. Unpatchify
|
||||
p = self.config.patch_size
|
||||
output = hidden_states.reshape(batch_size, video_length, post_patch_height, post_patch_width, channels, p, p)
|
||||
output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -18,6 +18,7 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
@@ -31,7 +32,7 @@ from ...models.attention_processor import (
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -44,7 +45,20 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class FluxSingleTransformerBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
|
||||
r"""
|
||||
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
||||
|
||||
Reference: https://arxiv.org/abs/2403.03206
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
||||
processing of `context` conditions.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
||||
super().__init__()
|
||||
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
||||
@@ -54,15 +68,9 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
if is_torch_npu_available():
|
||||
deprecation_message = (
|
||||
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
|
||||
"should be set explicitly using the `set_attn_processor` method."
|
||||
)
|
||||
deprecate("npu_processor", "0.34.0", deprecation_message)
|
||||
processor = FluxAttnProcessor2_0_NPU()
|
||||
else:
|
||||
processor = FluxAttnProcessor2_0()
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
@@ -105,14 +113,39 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class FluxTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
||||
|
||||
Reference: https://arxiv.org/abs/2403.03206
|
||||
|
||||
Args:
|
||||
dim (`int`):
|
||||
The embedding dimension of the block.
|
||||
num_attention_heads (`int`):
|
||||
The number of attention heads to use.
|
||||
attention_head_dim (`int`):
|
||||
The number of dimensions to use for each attention head.
|
||||
qk_norm (`str`, defaults to `"rms_norm"`):
|
||||
The normalization to use for the query and key tensors.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
The epsilon value to use for the normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = AdaLayerNormZero(dim)
|
||||
|
||||
self.norm1_context = AdaLayerNormZero(dim)
|
||||
|
||||
if hasattr(F, "scaled_dot_product_attention"):
|
||||
processor = FluxAttnProcessor2_0()
|
||||
else:
|
||||
raise ValueError(
|
||||
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
||||
)
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
@@ -122,7 +155,7 @@ class FluxTransformerBlock(nn.Module):
|
||||
out_dim=dim,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=FluxAttnProcessor2_0(),
|
||||
processor=processor,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
)
|
||||
@@ -133,6 +166,10 @@ class FluxTransformerBlock(nn.Module):
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -581,11 +581,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
self.context_embedder = HunyuanVideoTokenRefiner(
|
||||
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
||||
)
|
||||
|
||||
if guidance_embeds:
|
||||
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
|
||||
else:
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
|
||||
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
|
||||
|
||||
# 2. RoPE
|
||||
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
|
||||
@@ -712,11 +708,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Conditional embeddings
|
||||
if self.config.guidance_embeds:
|
||||
temb = self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
else:
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
|
||||
temb = self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
|
||||
@@ -38,6 +39,17 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class SD3SingleTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
|
||||
|
||||
Reference: https://arxiv.org/abs/2403.03206
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
@@ -47,13 +59,21 @@ class SD3SingleTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = AdaLayerNormZero(dim)
|
||||
|
||||
if hasattr(F, "scaled_dot_product_attention"):
|
||||
processor = JointAttnProcessor2_0()
|
||||
else:
|
||||
raise ValueError(
|
||||
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
||||
)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
processor=JointAttnProcessor2_0(),
|
||||
processor=processor,
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
@@ -61,17 +81,23 @@ class SD3SingleTransformerBlock(nn.Module):
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
|
||||
# 1. Attention
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
|
||||
# Attention.
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
# 2. Feed Forward
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
|
||||
return hidden_states
|
||||
@@ -81,40 +107,26 @@ class SD3Transformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
|
||||
The Transformer model introduced in Stable Diffusion 3.
|
||||
|
||||
Reference: https://arxiv.org/abs/2403.03206
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, defaults to `128`):
|
||||
The width/height of the latents. This is fixed during training since it is used to learn a number of
|
||||
position embeddings.
|
||||
patch_size (`int`, defaults to `2`):
|
||||
Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of latent channels in the input.
|
||||
num_layers (`int`, defaults to `18`):
|
||||
The number of layers of transformer blocks to use.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
num_attention_heads (`int`, defaults to `18`):
|
||||
The number of heads to use for multi-head attention.
|
||||
joint_attention_dim (`int`, defaults to `4096`):
|
||||
The embedding dimension to use for joint text-image attention.
|
||||
caption_projection_dim (`int`, defaults to `1152`):
|
||||
The embedding dimension of caption embeddings.
|
||||
pooled_projection_dim (`int`, defaults to `2048`):
|
||||
The embedding dimension of pooled text projections.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of latent channels in the output.
|
||||
pos_embed_max_size (`int`, defaults to `96`):
|
||||
The maximum latent height/width of positional embeddings.
|
||||
dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
|
||||
The number of dual-stream transformer blocks to use.
|
||||
qk_norm (`str`, *optional*, defaults to `None`):
|
||||
The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
|
||||
sample_size (`int`): The width of the latent images. This is fixed during training since
|
||||
it is used to learn a number of position embeddings.
|
||||
patch_size (`int`): Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
||||
num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
|
||||
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
||||
out_channels (`int`, defaults to 16): Number of output channels.
|
||||
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["JointTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
@@ -137,33 +149,36 @@ class SD3Transformer2DModel(
|
||||
qk_norm: Optional[str] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels if out_channels is not None else in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
default_out_channels = in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
height=self.config.sample_size,
|
||||
width=self.config.sample_size,
|
||||
patch_size=self.config.patch_size,
|
||||
in_channels=self.config.in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_max_size=pos_embed_max_size, # hard-code for now.
|
||||
)
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
|
||||
)
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
|
||||
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
|
||||
|
||||
# `attention_head_dim` is doubled to account for the mixing.
|
||||
# It needs to crafted when we get the actual checkpoints.
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
context_pre_only=i == num_layers - 1,
|
||||
qk_norm=qk_norm,
|
||||
use_dual_attention=True if i in dual_attention_layers else False,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
for i in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -316,24 +331,24 @@ class SD3Transformer2DModel(
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
pooled_projections: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
block_controlnet_hidden_states: List = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
skip_layers: Optional[List[int]] = None,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`SD3Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
|
||||
Embeddings projected from the embeddings of input conditions.
|
||||
timestep (`torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
|
||||
@@ -1,460 +0,0 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanAttnProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
encoder_hidden_states_img = None
|
||||
if attn.add_k_proj is not None:
|
||||
encoder_hidden_states_img = encoder_hidden_states[:, :257]
|
||||
encoder_hidden_states = encoder_hidden_states[:, 257:]
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
if rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
|
||||
x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
|
||||
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
|
||||
return x_out.type_as(hidden_states)
|
||||
|
||||
query = apply_rotary_emb(query, rotary_emb)
|
||||
key = apply_rotary_emb(key, rotary_emb)
|
||||
|
||||
# I2V task
|
||||
hidden_states_img = None
|
||||
if encoder_hidden_states_img is not None:
|
||||
key_img = attn.add_k_proj(encoder_hidden_states_img)
|
||||
key_img = attn.norm_added_k(key_img)
|
||||
value_img = attn.add_v_proj(encoder_hidden_states_img)
|
||||
|
||||
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
hidden_states_img = F.scaled_dot_product_attention(
|
||||
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
if hidden_states_img is not None:
|
||||
hidden_states = hidden_states + hidden_states_img
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanImageEmbedding(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = FP32LayerNorm(in_features)
|
||||
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
|
||||
self.norm2 = FP32LayerNorm(out_features)
|
||||
|
||||
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.norm1(encoder_hidden_states_image)
|
||||
hidden_states = self.ff(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanTimeTextImageEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
time_freq_dim: int,
|
||||
time_proj_dim: int,
|
||||
text_embed_dim: int,
|
||||
image_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
|
||||
self.act_fn = nn.SiLU()
|
||||
self.time_proj = nn.Linear(dim, time_proj_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
|
||||
|
||||
self.image_embedder = None
|
||||
if image_embed_dim is not None:
|
||||
self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
||||
):
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
|
||||
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
|
||||
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
|
||||
timestep = timestep.to(time_embedder_dtype)
|
||||
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
|
||||
timestep_proj = self.time_proj(self.act_fn(temb))
|
||||
|
||||
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
|
||||
if encoder_hidden_states_image is not None:
|
||||
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
|
||||
|
||||
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
|
||||
|
||||
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.patch_size = patch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
|
||||
freqs = []
|
||||
for dim in [t_dim, h_dim, w_dim]:
|
||||
freq = get_1d_rotary_pos_embed(
|
||||
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
|
||||
)
|
||||
freqs.append(freq)
|
||||
self.freqs = torch.cat(freqs, dim=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
self.freqs = self.freqs.to(hidden_states.device)
|
||||
freqs = self.freqs.split_with_sizes(
|
||||
[
|
||||
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
|
||||
self.attention_head_dim // 6,
|
||||
self.attention_head_dim // 6,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
|
||||
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
|
||||
return freqs
|
||||
|
||||
|
||||
class WanTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ffn_dim: int,
|
||||
num_heads: int,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
cross_attn_norm: bool = False,
|
||||
eps: float = 1e-6,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Self-attention
|
||||
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 2. Cross-attention
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
added_proj_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
|
||||
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 1. Self-attention
|
||||
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
|
||||
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
|
||||
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
|
||||
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
|
||||
hidden_states
|
||||
)
|
||||
ff_output = self.ffn(norm_hidden_states)
|
||||
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in the Wan model.
|
||||
|
||||
Args:
|
||||
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
|
||||
num_attention_heads (`int`, defaults to `40`):
|
||||
Fixed length for text embeddings.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
text_dim (`int`, defaults to `512`):
|
||||
Input dimension for text embeddings.
|
||||
freq_dim (`int`, defaults to `256`):
|
||||
Dimension for sinusoidal time embeddings.
|
||||
ffn_dim (`int`, defaults to `13824`):
|
||||
Intermediate dimension in feed-forward network.
|
||||
num_layers (`int`, defaults to `40`):
|
||||
The number of layers of transformer blocks to use.
|
||||
window_size (`Tuple[int]`, defaults to `(-1, -1)`):
|
||||
Window size for local attention (-1 indicates global attention).
|
||||
cross_attn_norm (`bool`, defaults to `True`):
|
||||
Enable cross-attention normalization.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Enable query/key normalization.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
add_img_emb (`bool`, defaults to `False`):
|
||||
Whether to use img_emb.
|
||||
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
|
||||
_no_split_modules = ["WanTransformerBlock"]
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
text_dim: int = 4096,
|
||||
freq_dim: int = 256,
|
||||
ffn_dim: int = 13824,
|
||||
num_layers: int = 40,
|
||||
cross_attn_norm: bool = True,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
eps: float = 1e-6,
|
||||
image_dim: Optional[int] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
rope_max_seq_len: int = 1024,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
# 1. Patch & position embedding
|
||||
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
|
||||
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
# 2. Condition embeddings
|
||||
# image_embedding_dim=1280 for I2V model
|
||||
self.condition_embedder = WanTimeTextImageEmbedding(
|
||||
dim=inner_dim,
|
||||
time_freq_dim=freq_dim,
|
||||
time_proj_dim=inner_dim * 6,
|
||||
text_embed_dim=text_dim,
|
||||
image_embed_dim=image_dim,
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
WanTransformerBlock(
|
||||
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
|
||||
rotary_emb = self.rope(hidden_states)
|
||||
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
|
||||
timestep, encoder_hidden_states, encoder_hidden_states_image
|
||||
)
|
||||
timestep_proj = timestep_proj.unflatten(1, (6, -1))
|
||||
|
||||
if encoder_hidden_states_image is not None:
|
||||
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
|
||||
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
else:
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
|
||||
|
||||
# 5. Output norm, projection & unpatchify
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -240,6 +240,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
dropout=dropout,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
|
||||
@@ -216,17 +216,8 @@ else:
|
||||
"IFPipeline",
|
||||
"IFSuperResolutionPipeline",
|
||||
]
|
||||
_import_structure["easyanimate"] = [
|
||||
"EasyAnimatePipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
]
|
||||
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
|
||||
_import_structure["hunyuan_video"] = [
|
||||
"HunyuanVideoPipeline",
|
||||
"HunyuanSkyreelsImageToVideoPipeline",
|
||||
"HunyuanVideoImageToVideoPipeline",
|
||||
]
|
||||
_import_structure["hunyuan_video"] = ["HunyuanVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline"]
|
||||
_import_structure["kandinsky"] = [
|
||||
"KandinskyCombinedPipeline",
|
||||
"KandinskyImg2ImgCombinedPipeline",
|
||||
@@ -270,7 +261,6 @@ else:
|
||||
_import_structure["marigold"].extend(
|
||||
[
|
||||
"MarigoldDepthPipeline",
|
||||
"MarigoldIntrinsicsPipeline",
|
||||
"MarigoldNormalsPipeline",
|
||||
]
|
||||
)
|
||||
@@ -356,7 +346,6 @@ else:
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
]
|
||||
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline"]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -555,11 +544,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
from .easyanimate import (
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
)
|
||||
from .flux import (
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
@@ -574,11 +558,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .hunyuan_video import (
|
||||
HunyuanSkyreelsImageToVideoPipeline,
|
||||
HunyuanVideoImageToVideoPipeline,
|
||||
HunyuanVideoPipeline,
|
||||
)
|
||||
from .hunyuan_video import HunyuanSkyreelsImageToVideoPipeline, HunyuanVideoPipeline
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .i2vgen_xl import I2VGenXLPipeline
|
||||
from .kandinsky import (
|
||||
@@ -623,7 +603,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .lumina2 import Lumina2Text2ImgPipeline
|
||||
from .marigold import (
|
||||
MarigoldDepthPipeline,
|
||||
MarigoldIntrinsicsPipeline,
|
||||
MarigoldNormalsPipeline,
|
||||
)
|
||||
from .mochi import MochiPipeline
|
||||
@@ -709,7 +688,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline
|
||||
from .wuerstchen import (
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
@@ -83,7 +83,6 @@ class AnimateDiffPipeline(
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
@@ -125,7 +125,6 @@ class AnimateDiffControlNetPipeline(
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-video generation with ControlNet guidance.
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
@@ -136,7 +136,6 @@ class AnimateDiffSparseControlNetPipeline(
|
||||
IPAdapterMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...models.unets.unet_motion_model import MotionAdapter
|
||||
@@ -186,7 +186,6 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for video-to-video generation.
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import (
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
@@ -204,7 +204,6 @@ class AnimateDiffVideoToVideoControlNetPipeline(
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FreeInitMixin,
|
||||
AnimateDiffFreeNoiseMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for video-to-video generation with ControlNet guidance.
|
||||
|
||||
@@ -34,10 +34,6 @@ from .controlnet import (
|
||||
StableDiffusionXLControlNetUnionInpaintPipeline,
|
||||
StableDiffusionXLControlNetUnionPipeline,
|
||||
)
|
||||
from .controlnet_sd3 import (
|
||||
StableDiffusion3ControlNetInpaintingPipeline,
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
|
||||
from .flux import (
|
||||
FluxControlImg2ImgPipeline,
|
||||
@@ -124,7 +120,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
|
||||
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
|
||||
("stable-diffusion-3-controlnet", StableDiffusion3ControlNetPipeline),
|
||||
("wuerstchen", WuerstchenCombinedPipeline),
|
||||
("cascade", StableCascadeCombinedPipeline),
|
||||
("lcm", LatentConsistencyModelPipeline),
|
||||
@@ -183,7 +178,6 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
|
||||
("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline),
|
||||
("stable-diffusion-3-controlnet", StableDiffusion3ControlNetInpaintingPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
|
||||
("flux", FluxInpaintPipeline),
|
||||
("flux-controlnet", FluxControlNetInpaintPipeline),
|
||||
|
||||
@@ -143,11 +143,13 @@ class CogView4Pipeline(DiffusionPipeline):
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`GLMModel`]):
|
||||
Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
|
||||
tokenizer (`PreTrainedTokenizer`):
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. CogView4 uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
||||
tokenizer (`T5Tokenizer`):
|
||||
Tokenizer of class
|
||||
[PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
transformer ([`CogView4Transformer2DModel`]):
|
||||
A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
@@ -213,7 +215,7 @@ class CogView4Pipeline(DiffusionPipeline):
|
||||
)
|
||||
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(self.text_encoder.device), output_hidden_states=True
|
||||
text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
@@ -360,16 +362,10 @@ class CogView4Pipeline(DiffusionPipeline):
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
|
||||
@@ -207,7 +207,7 @@ class StableDiffusionControlNetPipeline(
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "image"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1323,7 +1323,6 @@ class StableDiffusionControlNetPipeline(
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
image = callback_outputs.pop("image", image)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
|
||||
@@ -185,7 +185,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1294,7 +1294,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
control_image = callback_outputs.pop("control_image", control_image)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
|
||||
@@ -184,7 +184,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1476,7 +1476,6 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
control_image = callback_outputs.pop("control_image", control_image)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
|
||||
@@ -237,7 +237,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
"add_neg_time_ids",
|
||||
"mask",
|
||||
"masked_image_latents",
|
||||
"control_image",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -744,7 +743,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -752,7 +751,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
@@ -1645,7 +1644,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
@@ -1836,7 +1835,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
control_image = callback_outputs.pop("control_image", control_image)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
|
||||
@@ -242,7 +242,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
"add_time_ids",
|
||||
"negative_pooled_prompt_embeds",
|
||||
"add_neg_time_ids",
|
||||
"control_image",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -1615,7 +1614,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
)
|
||||
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
||||
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
|
||||
control_image = callback_outputs.pop("control_image", control_image)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
|
||||
@@ -219,7 +219,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
"add_time_ids",
|
||||
"mask",
|
||||
"masked_image_latents",
|
||||
"control_image",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@@ -727,7 +726,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
if padding_mask_crop is not None:
|
||||
if not isinstance(image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
|
||||
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
|
||||
)
|
||||
if not isinstance(mask_image, PIL.Image.Image):
|
||||
raise ValueError(
|
||||
@@ -735,7 +734,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
f" {type(mask_image)}."
|
||||
)
|
||||
if output_type != "pil":
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
|
||||
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
@@ -1744,7 +1743,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
control_image = callback_outputs.pop("control_image", control_image)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
|
||||
@@ -757,9 +757,15 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
for images_ in image:
|
||||
for image_ in images_:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Check `controlnet_conditioning_scale`
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
# TODO Update for https://github.com/huggingface/diffusers/pull/10723
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
elif isinstance(controlnet, MultiControlNetUnionModel):
|
||||
if isinstance(controlnet_conditioning_scale, list):
|
||||
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
|
||||
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
|
||||
@@ -770,6 +776,8 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
|
||||
" the same length as the number of controlnets"
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
|
||||
if len(control_guidance_start) != len(control_guidance_end):
|
||||
raise ValueError(
|
||||
@@ -800,6 +808,8 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
|
||||
if max(_control_mode) >= _controlnet.config.num_control_type:
|
||||
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Equal number of `image` and `control_mode` elements
|
||||
if isinstance(controlnet, ControlNetUnionModel):
|
||||
@@ -813,6 +823,8 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
|
||||
elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
|
||||
raise ValueError("Expected len(control_image) == len(control_mode)")
|
||||
else:
|
||||
assert False
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
@@ -1189,6 +1201,18 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else 1
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
if not isinstance(control_image, list):
|
||||
control_image = [control_image]
|
||||
else:
|
||||
@@ -1197,25 +1221,8 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
if not isinstance(control_mode, list):
|
||||
control_mode = [control_mode]
|
||||
|
||||
if isinstance(controlnet, MultiControlNetUnionModel):
|
||||
control_image = [[item] for item in control_image]
|
||||
control_mode = [[item] for item in control_mode]
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
|
||||
control_guidance_start, control_guidance_end = (
|
||||
mult * [control_guidance_start],
|
||||
mult * [control_guidance_end],
|
||||
)
|
||||
|
||||
if isinstance(controlnet_conditioning_scale, float):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
|
||||
if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
@@ -1350,6 +1357,9 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
control_image = control_images
|
||||
height, width = control_image[0][0].shape[-2:]
|
||||
|
||||
else:
|
||||
assert False
|
||||
|
||||
# 5. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
@@ -1387,7 +1397,7 @@ class StableDiffusionXLControlNetUnionPipeline(
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps)
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetUnionModel) else keeps)
|
||||
|
||||
# 7.2 Prepare added time ids & embeddings
|
||||
original_size = original_size or (height, width)
|
||||
|
||||
@@ -252,7 +252,12 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
"feature_extractor",
|
||||
"image_encoder",
|
||||
]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "add_text_embeds", "add_time_ids", "control_image"]
|
||||
_callback_tensor_inputs = [
|
||||
"latents",
|
||||
"prompt_embeds",
|
||||
"add_text_embeds",
|
||||
"add_time_ids",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1557,7 +1562,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
||||
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
||||
control_image = callback_outputs.pop("control_image", control_image)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_easyanimate"] = ["EasyAnimatePipeline"]
|
||||
_import_structure["pipeline_easyanimate_control"] = ["EasyAnimateControlPipeline"]
|
||||
_import_structure["pipeline_easyanimate_inpaint"] = ["EasyAnimateInpaintPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_easyanimate import EasyAnimatePipeline
|
||||
from .pipeline_easyanimate_control import EasyAnimateControlPipeline
|
||||
from .pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,770 +0,0 @@
|
||||
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
BertModel,
|
||||
BertTokenizer,
|
||||
Qwen2Tokenizer,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
)
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from .pipeline_output import EasyAnimatePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import EasyAnimatePipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh"
|
||||
>>> pipe = EasyAnimatePipeline.from_pretrained(
|
||||
... "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
>>> prompt = (
|
||||
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
||||
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
||||
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
||||
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
||||
... "atmosphere of this unique musical performance."
|
||||
... )
|
||||
>>> sample_size = (512, 512)
|
||||
>>> video = pipe(
|
||||
... prompt=prompt,
|
||||
... guidance_scale=6,
|
||||
... negative_prompt="bad detailed",
|
||||
... height=sample_size[0],
|
||||
... width=sample_size[1],
|
||||
... num_inference_steps=50,
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
The predicted noise tensor for the guided diffusion process.
|
||||
noise_pred_text (`torch.Tensor`):
|
||||
The predicted noise tensor for the text-guided diffusion process.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
A rescale factor applied to the noise predictions.
|
||||
|
||||
Returns:
|
||||
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class EasyAnimatePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using EasyAnimate.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKLMagvit`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
||||
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
|
||||
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
||||
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
|
||||
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
|
||||
transformer ([`EasyAnimateTransformer3DModel`]):
|
||||
The EasyAnimate model designed by EasyAnimate Team.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLMagvit,
|
||||
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
|
||||
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
|
||||
transformer: EasyAnimateTransformer3DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.enable_text_attention_mask = (
|
||||
self.transformer.config.enable_text_attention_mask
|
||||
if getattr(self, "transformer", None) is not None
|
||||
else True
|
||||
)
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
|
||||
)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
dtype (`torch.dtype`):
|
||||
torch dtype
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
|
||||
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
|
||||
"""
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
device = device or self.text_encoder.device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
if isinstance(prompt, str):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": prompt}],
|
||||
}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": _prompt}],
|
||||
}
|
||||
for _prompt in prompt
|
||||
]
|
||||
text = [
|
||||
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
|
||||
]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text=text,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
padding_side="right",
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs.to(self.text_encoder.device)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
if self.enable_text_attention_mask:
|
||||
# Inference: Generation of the output
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
else:
|
||||
raise ValueError("LLM needs attention_mask")
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
if negative_prompt is not None and isinstance(negative_prompt, str):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": negative_prompt}],
|
||||
}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": _negative_prompt}],
|
||||
}
|
||||
for _negative_prompt in negative_prompt
|
||||
]
|
||||
text = [
|
||||
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
|
||||
]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text=text,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
padding_side="right",
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs.to(self.text_encoder.device)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
negative_prompt_attention_mask = text_inputs.attention_mask
|
||||
if self.enable_text_attention_mask:
|
||||
# Inference: Generation of the output
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=negative_prompt_attention_mask,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
else:
|
||||
raise ValueError("LLM needs attention_mask")
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
(num_frames - 1) // self.vae_temporal_compression_ratio + 1,
|
||||
height // self.vae_spatial_compression_ratio,
|
||||
width // self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
if hasattr(self.scheduler, "init_noise_sigma"):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_frames: Optional[int] = 49,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
guidance_rescale: float = 0.0,
|
||||
):
|
||||
r"""
|
||||
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
|
||||
|
||||
Examples:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
|
||||
num_frames (`int`, *optional*):
|
||||
Length of the generated video (in frames).
|
||||
height (`int`, *optional*):
|
||||
Height of the generated image in pixels.
|
||||
width (`int`, *optional*):
|
||||
Width of the generated image in pixels.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
Number of denoising steps during generation. More steps generally yield higher quality images but slow
|
||||
down inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Encourages the model to align outputs with prompts. A higher value may decrease image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of images to generate for each prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A generator to ensure reproducibility in image generation.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Predefined latent tensors to condition generation.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Embeddings for negative prompts. Overrides string inputs if defined.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the primary prompt embeddings.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for negative prompt embeddings.
|
||||
output_type (`str`, *optional*, defaults to "latent"):
|
||||
Format of the generated output, either as a PIL image or as a NumPy array.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
If `True`, returns a structured output. Otherwise returns a simple tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
Functions called at the end of each denoising step.
|
||||
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
||||
Tensor names to be included in callback function calls.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Adjusts noise levels based on guidance scale.
|
||||
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
|
||||
Original dimensions of the output.
|
||||
target_size (`Tuple[int, int]`, *optional*):
|
||||
Desired output dimensions for calculations.
|
||||
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
|
||||
Coordinates for cropping.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 0. default height and width
|
||||
height = int((height // 16) * 16)
|
||||
width = int((width // 16) * 16)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = self.transformer.dtype
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, mu=1
|
||||
)
|
||||
else:
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
if hasattr(self.scheduler, "scale_model_input"):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
||||
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
||||
dtype=latent_model_input.dtype
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
t_expand,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
||||
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return EasyAnimatePipelineOutput(frames=video)
|
||||
@@ -1,994 +0,0 @@
|
||||
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
BertModel,
|
||||
BertTokenizer,
|
||||
Qwen2Tokenizer,
|
||||
Qwen2VLForConditionalGeneration,
|
||||
)
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from .pipeline_output import EasyAnimatePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import EasyAnimateControlPipeline
|
||||
>>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_control import get_video_to_video_latent
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> pipe = EasyAnimateControlPipeline.from_pretrained(
|
||||
... "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> control_video = load_video(
|
||||
... "https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control/blob/main/asset/pose.mp4"
|
||||
... )
|
||||
>>> prompt = (
|
||||
... "In this sunlit outdoor garden, a beautiful woman is dressed in a knee-length, sleeveless white dress. "
|
||||
... "The hem of her dress gently sways with her graceful dance, much like a butterfly fluttering in the breeze. "
|
||||
... "Sunlight filters through the leaves, casting dappled shadows that highlight her soft features and clear eyes, "
|
||||
... "making her appear exceptionally elegant. It seems as if every movement she makes speaks of youth and vitality. "
|
||||
... "As she twirls on the grass, her dress flutters, as if the entire garden is rejoicing in her dance. "
|
||||
... "The colorful flowers around her sway in the gentle breeze, with roses, chrysanthemums, and lilies each "
|
||||
... "releasing their fragrances, creating a relaxed and joyful atmosphere."
|
||||
... )
|
||||
>>> sample_size = (672, 384)
|
||||
>>> num_frames = 49
|
||||
|
||||
>>> input_video, _, _ = get_video_to_video_latent(control_video, num_frames, sample_size)
|
||||
>>> video = pipe(
|
||||
... prompt,
|
||||
... num_frames=num_frames,
|
||||
... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
|
||||
... height=sample_size[0],
|
||||
... width=sample_size[1],
|
||||
... control_video=input_video,
|
||||
... ).frames[0]
|
||||
>>> export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def preprocess_image(image, sample_size):
|
||||
"""
|
||||
Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor.
|
||||
"""
|
||||
if isinstance(image, torch.Tensor):
|
||||
# If input is a tensor, assume it's in CHW format and resize using interpolation
|
||||
image = torch.nn.functional.interpolate(
|
||||
image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False
|
||||
).squeeze(0)
|
||||
elif isinstance(image, Image.Image):
|
||||
# If input is a PIL image, resize and convert to numpy array
|
||||
image = image.resize((sample_size[1], sample_size[0]))
|
||||
image = np.array(image)
|
||||
elif isinstance(image, np.ndarray):
|
||||
# If input is a numpy array, resize using PIL
|
||||
image = Image.fromarray(image).resize((sample_size[1], sample_size[0]))
|
||||
image = np.array(image)
|
||||
else:
|
||||
raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.")
|
||||
|
||||
# Convert to tensor if not already
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1]
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def get_video_to_video_latent(input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None):
|
||||
if input_video is not None:
|
||||
# Convert each frame in the list to tensor
|
||||
input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video]
|
||||
|
||||
# Stack all frames into a single tensor (F, C, H, W)
|
||||
input_video = torch.stack(input_video)[:num_frames]
|
||||
|
||||
# Add batch dimension (B, F, C, H, W)
|
||||
input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0)
|
||||
|
||||
if validation_video_mask is not None:
|
||||
# Handle mask input
|
||||
validation_video_mask = preprocess_image(validation_video_mask, size=sample_size)
|
||||
input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255)
|
||||
|
||||
# Adjust mask dimensions to match video
|
||||
input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
|
||||
input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
|
||||
input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
|
||||
else:
|
||||
input_video_mask = torch.zeros_like(input_video[:, :1])
|
||||
input_video_mask[:, :, :] = 255
|
||||
else:
|
||||
input_video, input_video_mask = None, None
|
||||
|
||||
if ref_image is not None:
|
||||
# Convert reference image to tensor
|
||||
ref_image = preprocess_image(ref_image, size=sample_size)
|
||||
ref_image = ref_image.permute(1, 0, 2, 3).unsqueeze(0) # Add batch dimension (B, C, H, W)
|
||||
else:
|
||||
ref_image = None
|
||||
|
||||
return input_video, input_video_mask, ref_image
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
The predicted noise tensor for the guided diffusion process.
|
||||
noise_pred_text (`torch.Tensor`):
|
||||
The predicted noise tensor for the text-guided diffusion process.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
A rescale factor applied to the noise predictions.
|
||||
|
||||
Returns:
|
||||
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Resize mask information in magvit
|
||||
def resize_mask(mask, latent, process_first_frame_only=True):
|
||||
latent_size = latent.size()
|
||||
|
||||
if process_first_frame_only:
|
||||
target_size = list(latent_size[2:])
|
||||
target_size[0] = 1
|
||||
first_frame_resized = F.interpolate(
|
||||
mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False
|
||||
)
|
||||
|
||||
target_size = list(latent_size[2:])
|
||||
target_size[0] = target_size[0] - 1
|
||||
if target_size[0] != 0:
|
||||
remaining_frames_resized = F.interpolate(
|
||||
mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False
|
||||
)
|
||||
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
|
||||
else:
|
||||
resized_mask = first_frame_resized
|
||||
else:
|
||||
target_size = list(latent_size[2:])
|
||||
resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False)
|
||||
return resized_mask
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class EasyAnimateControlPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using EasyAnimate.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKLMagvit`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
|
||||
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
|
||||
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
|
||||
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
|
||||
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
|
||||
transformer ([`EasyAnimateTransformer3DModel`]):
|
||||
The EasyAnimate model designed by EasyAnimate Team.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLMagvit,
|
||||
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
|
||||
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
|
||||
transformer: EasyAnimateTransformer3DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.enable_text_attention_mask = (
|
||||
self.transformer.config.enable_text_attention_mask
|
||||
if getattr(self, "transformer", None) is not None
|
||||
else True
|
||||
)
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_spatial_compression_ratio,
|
||||
do_normalize=False,
|
||||
do_binarize=True,
|
||||
do_convert_grayscale=True,
|
||||
)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
|
||||
# Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
dtype (`torch.dtype`):
|
||||
torch dtype
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
|
||||
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
|
||||
"""
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
device = device or self.text_encoder.device
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
if isinstance(prompt, str):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": prompt}],
|
||||
}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": _prompt}],
|
||||
}
|
||||
for _prompt in prompt
|
||||
]
|
||||
text = [
|
||||
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
|
||||
]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text=text,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
padding_side="right",
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs.to(self.text_encoder.device)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
if self.enable_text_attention_mask:
|
||||
# Inference: Generation of the output
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
else:
|
||||
raise ValueError("LLM needs attention_mask")
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
if negative_prompt is not None and isinstance(negative_prompt, str):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": negative_prompt}],
|
||||
}
|
||||
]
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": _negative_prompt}],
|
||||
}
|
||||
for _negative_prompt in negative_prompt
|
||||
]
|
||||
text = [
|
||||
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
|
||||
]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text=text,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
padding_side="right",
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs.to(self.text_encoder.device)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
negative_prompt_attention_mask = text_inputs.attention_mask
|
||||
if self.enable_text_attention_mask:
|
||||
# Inference: Generation of the output
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=negative_prompt_attention_mask,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
else:
|
||||
raise ValueError("LLM needs attention_mask")
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
|
||||
):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
(num_frames - 1) // self.vae_temporal_compression_ratio + 1,
|
||||
height // self.vae_spatial_compression_ratio,
|
||||
width // self.vae_spatial_compression_ratio,
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
if hasattr(self.scheduler, "init_noise_sigma"):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_control_latents(
|
||||
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
||||
):
|
||||
# resize the control to latents shape as we concatenate the control to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
|
||||
if control is not None:
|
||||
control = control.to(device=device, dtype=dtype)
|
||||
bs = 1
|
||||
new_control = []
|
||||
for i in range(0, control.shape[0], bs):
|
||||
control_bs = control[i : i + bs]
|
||||
control_bs = self.vae.encode(control_bs)[0]
|
||||
control_bs = control_bs.mode()
|
||||
new_control.append(control_bs)
|
||||
control = torch.cat(new_control, dim=0)
|
||||
control = control * self.vae.config.scaling_factor
|
||||
|
||||
if control_image is not None:
|
||||
control_image = control_image.to(device=device, dtype=dtype)
|
||||
bs = 1
|
||||
new_control_pixel_values = []
|
||||
for i in range(0, control_image.shape[0], bs):
|
||||
control_pixel_values_bs = control_image[i : i + bs]
|
||||
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
|
||||
control_pixel_values_bs = control_pixel_values_bs.mode()
|
||||
new_control_pixel_values.append(control_pixel_values_bs)
|
||||
control_image_latents = torch.cat(new_control_pixel_values, dim=0)
|
||||
control_image_latents = control_image_latents * self.vae.config.scaling_factor
|
||||
else:
|
||||
control_image_latents = None
|
||||
|
||||
return control, control_image_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_frames: Optional[int] = 49,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
control_video: Union[torch.FloatTensor] = None,
|
||||
control_camera_video: Union[torch.FloatTensor] = None,
|
||||
ref_image: Union[torch.FloatTensor] = None,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
guidance_rescale: float = 0.0,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
):
|
||||
r"""
|
||||
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
|
||||
|
||||
Examples:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
|
||||
num_frames (`int`, *optional*):
|
||||
Length of the generated video (in frames).
|
||||
height (`int`, *optional*):
|
||||
Height of the generated image in pixels.
|
||||
width (`int`, *optional*):
|
||||
Width of the generated image in pixels.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
Number of denoising steps during generation. More steps generally yield higher quality images but slow
|
||||
down inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Encourages the model to align outputs with prompts. A higher value may decrease image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of images to generate for each prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A generator to ensure reproducibility in image generation.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Predefined latent tensors to condition generation.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Embeddings for negative prompts. Overrides string inputs if defined.
|
||||
prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for the primary prompt embeddings.
|
||||
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for negative prompt embeddings.
|
||||
output_type (`str`, *optional*, defaults to "latent"):
|
||||
Format of the generated output, either as a PIL image or as a NumPy array.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
If `True`, returns a structured output. Otherwise returns a simple tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
Functions called at the end of each denoising step.
|
||||
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
||||
Tensor names to be included in callback function calls.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Adjusts noise levels based on guidance scale.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
||||
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 0. default height and width
|
||||
height = int((height // 16) * 16)
|
||||
width = int((width // 16) * 16)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
if self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = self.transformer.dtype
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
text_encoder_index=0,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, mu=1
|
||||
)
|
||||
else:
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if control_camera_video is not None:
|
||||
control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True)
|
||||
control_video_latents = control_video_latents * 6
|
||||
control_latents = (
|
||||
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
|
||||
).to(device, dtype)
|
||||
elif control_video is not None:
|
||||
batch_size, channels, num_frames, height_video, width_video = control_video.shape
|
||||
control_video = self.image_processor.preprocess(
|
||||
control_video.permute(0, 2, 1, 3, 4).reshape(
|
||||
batch_size * num_frames, channels, height_video, width_video
|
||||
),
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
control_video = control_video.to(dtype=torch.float32)
|
||||
control_video = control_video.reshape(batch_size, num_frames, channels, height, width).permute(
|
||||
0, 2, 1, 3, 4
|
||||
)
|
||||
control_video_latents = self.prepare_control_latents(
|
||||
None,
|
||||
control_video,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
self.do_classifier_free_guidance,
|
||||
)[1]
|
||||
control_latents = (
|
||||
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
|
||||
).to(device, dtype)
|
||||
else:
|
||||
control_video_latents = torch.zeros_like(latents).to(device, dtype)
|
||||
control_latents = (
|
||||
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
|
||||
).to(device, dtype)
|
||||
|
||||
if ref_image is not None:
|
||||
batch_size, channels, num_frames, height_video, width_video = ref_image.shape
|
||||
ref_image = self.image_processor.preprocess(
|
||||
ref_image.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video),
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
ref_image = ref_image.to(dtype=torch.float32)
|
||||
ref_image = ref_image.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
|
||||
|
||||
ref_image_latents = self.prepare_control_latents(
|
||||
None,
|
||||
ref_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
self.do_classifier_free_guidance,
|
||||
)[1]
|
||||
|
||||
ref_image_latents_conv_in = torch.zeros_like(latents)
|
||||
if latents.size()[2] != 1:
|
||||
ref_image_latents_conv_in[:, :, :1] = ref_image_latents
|
||||
ref_image_latents_conv_in = (
|
||||
torch.cat([ref_image_latents_conv_in] * 2)
|
||||
if self.do_classifier_free_guidance
|
||||
else ref_image_latents_conv_in
|
||||
).to(device, dtype)
|
||||
control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
|
||||
else:
|
||||
ref_image_latents_conv_in = torch.zeros_like(latents)
|
||||
ref_image_latents_conv_in = (
|
||||
torch.cat([ref_image_latents_conv_in] * 2)
|
||||
if self.do_classifier_free_guidance
|
||||
else ref_image_latents_conv_in
|
||||
).to(device, dtype)
|
||||
control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
|
||||
|
||||
# To latents.device
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device=device)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
if hasattr(self.scheduler, "scale_model_input"):
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
|
||||
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
|
||||
dtype=latent_model_input.dtype
|
||||
)
|
||||
# predict the noise residual
|
||||
noise_pred = self.transformer(
|
||||
latent_model_input,
|
||||
t_expand,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
control_latents=control_latents,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if noise_pred.size()[1] != self.vae.config.latent_channels:
|
||||
noise_pred, _ = noise_pred.chunk(2, dim=1)
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
# Convert to tensor
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return EasyAnimatePipelineOutput(frames=video)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,20 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class EasyAnimatePipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for EasyAnimate pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
@@ -405,28 +405,23 @@ class FluxPipeline(
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
|
||||
for single_ip_adapter_image in ip_adapter_image:
|
||||
for single_ip_adapter_image, image_proj_layer in zip(
|
||||
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
||||
|
||||
image_embeds.append(single_image_embeds[None, :])
|
||||
else:
|
||||
if not isinstance(ip_adapter_image_embeds, list):
|
||||
ip_adapter_image_embeds = [ip_adapter_image_embeds]
|
||||
|
||||
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
|
||||
)
|
||||
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
ip_adapter_image_embeds = []
|
||||
for single_image_embeds in image_embeds:
|
||||
for i, single_image_embeds in enumerate(image_embeds):
|
||||
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_image_embeds = single_image_embeds.to(device=device)
|
||||
ip_adapter_image_embeds.append(single_image_embeds)
|
||||
@@ -694,7 +689,7 @@ class FluxPipeline(
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
@@ -877,13 +872,10 @@ class FluxPipeline(
|
||||
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
||||
):
|
||||
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
|
||||
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
||||
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
||||
):
|
||||
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
||||
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
|
||||
if self.joint_attention_kwargs is None:
|
||||
self._joint_attention_kwargs = {}
|
||||
|
||||
@@ -660,7 +660,7 @@ class FluxControlPipeline(
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user