Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f319e27318 | |||
| 30c977d1f5 | |||
| f0fa17dd8e | |||
| c726d02beb | |||
| a68503f221 | |||
| 9d50f7eec1 | |||
| fda1531d8a | |||
| cf6e0407e0 | |||
| 1c000d46e1 | |||
| 08bf754507 | |||
| 2f23437618 | |||
| 2523390c26 | |||
| 279de3c3ff | |||
| 8e14535708 | |||
| 0bee4d336b | |||
| 42f25d601a | |||
| 33c5d125cb | |||
| aa1f00fd01 | |||
| d95b993427 | |||
| 1d480298c1 | |||
| b2323aa2b7 | |||
| 37e9d695af | |||
| a402431de0 | |||
| b99b1617cf | |||
| 3e4a6bd2d4 | |||
| c827e94da0 | |||
| 44f6b859bf | |||
| ac7ff7d4a3 |
@@ -31,7 +31,6 @@ jobs:
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install pandas peft
|
||||
|
||||
@@ -20,7 +20,7 @@ env:
|
||||
|
||||
jobs:
|
||||
test-build-docker-images:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
if: github.event_name == 'pull_request'
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
if: steps.file_changes.outputs.all != ''
|
||||
|
||||
build-and-push-docker-images:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
if: github.event_name != 'pull_request'
|
||||
|
||||
permissions:
|
||||
@@ -73,13 +73,13 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ env.REGISTRY }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
|
||||
@@ -70,7 +70,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
@@ -131,7 +130,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
@@ -202,7 +200,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
@@ -262,7 +259,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
@@ -32,7 +32,6 @@ jobs:
|
||||
fetch-depth: 0
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
- name: Environment
|
||||
@@ -89,7 +88,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pip install -e [quality,test]
|
||||
python -m pip install accelerate
|
||||
@@ -147,7 +145,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pip install -e [quality,test]
|
||||
|
||||
|
||||
@@ -89,11 +89,10 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
if [ "${{ matrix.lib-versions }}" == "main" ]; then
|
||||
python -m uv pip install -U peft@git+https://github.com/huggingface/peft.git
|
||||
python -m pip install -U peft@git+https://github.com/huggingface/peft.git
|
||||
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
else
|
||||
|
||||
@@ -116,7 +116,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate
|
||||
@@ -205,7 +204,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
|
||||
|
||||
@@ -71,7 +71,6 @@ jobs:
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
@@ -121,7 +120,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
@@ -171,11 +169,10 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
python -m pip install -U peft@git+https://github.com/huggingface/peft.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -222,7 +219,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
@@ -270,7 +266,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
@@ -68,7 +68,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
|
||||
|
||||
@@ -25,6 +25,6 @@ jobs:
|
||||
|
||||
- name: Update metadata
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.SAYAK_HF_TOKEN }}
|
||||
run: |
|
||||
python utils/update_metadata.py --commit_sha ${{ github.sha }}
|
||||
|
||||
@@ -12,6 +12,7 @@ RUN apt update && \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
|
||||
@@ -12,6 +12,7 @@ RUN apt update && \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
|
||||
@@ -12,6 +12,7 @@ RUN apt update && \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
|
||||
@@ -12,6 +12,7 @@ RUN apt update && \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
|
||||
@@ -71,7 +71,7 @@
|
||||
- local: using-diffusers/control_brightness
|
||||
title: Control image brightness
|
||||
- local: using-diffusers/weighted_prompts
|
||||
title: Prompt weighting
|
||||
title: Prompt techniques
|
||||
- local: using-diffusers/freeu
|
||||
title: Improve generation quality with FreeU
|
||||
title: Techniques
|
||||
@@ -86,6 +86,8 @@
|
||||
title: Kandinsky
|
||||
- local: using-diffusers/controlnet
|
||||
title: ControlNet
|
||||
- local: using-diffusers/t2i_adapter
|
||||
title: T2I-Adapter
|
||||
- local: using-diffusers/shap-e
|
||||
title: Shap-E
|
||||
- local: using-diffusers/diffedit
|
||||
@@ -170,6 +172,8 @@
|
||||
title: Token merging
|
||||
- local: optimization/deepcache
|
||||
title: DeepCache
|
||||
- local: optimization/tgate
|
||||
title: TGATE
|
||||
title: General optimizations
|
||||
- sections:
|
||||
- local: using-diffusers/stable_diffusion_jax_how_to
|
||||
@@ -280,6 +284,10 @@
|
||||
title: ControlNet
|
||||
- local: api/pipelines/controlnet_sdxl
|
||||
title: ControlNet with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnetxs
|
||||
title: ControlNet-XS
|
||||
- local: api/pipelines/controlnetxs_sdxl
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: Dance Diffusion
|
||||
- local: api/pipelines/ddim
|
||||
@@ -358,7 +366,7 @@
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
|
||||
- local: api/pipelines/stable_diffusion/adapter
|
||||
title: Stable Diffusion T2I-Adapter
|
||||
title: T2I-Adapter
|
||||
- local: api/pipelines/stable_diffusion/gligen
|
||||
title: GLIGEN (Grounded Language-to-Image Generation)
|
||||
title: Stable Diffusion
|
||||
|
||||
+24
-1
@@ -1,3 +1,15 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNet-XS
|
||||
|
||||
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
|
||||
@@ -12,5 +24,16 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
|
||||
|
||||
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
|
||||
|
||||
<Tip>
|
||||
|
||||
> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## StableDiffusionControlNetXSPipeline
|
||||
[[autodoc]] StableDiffusionControlNetXSPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
+31
-1
@@ -1,3 +1,15 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNet-XS with Stable Diffusion XL
|
||||
|
||||
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
|
||||
@@ -12,4 +24,22 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
|
||||
|
||||
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
|
||||
|
||||
> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
<Tip warning={true}>
|
||||
|
||||
🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## StableDiffusionXLControlNetXSPipeline
|
||||
[[autodoc]] StableDiffusionXLControlNetXSPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
|
||||
@@ -10,9 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Text-to-Image Generation with Adapter Conditioning
|
||||
|
||||
## Overview
|
||||
# T2I-Adapter
|
||||
|
||||
[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.08453) by Chong Mou, Xintao Wang, Liangbin Xie, Jian Zhang, Zhongang Qi, Ying Shan, Xiaohu Qie.
|
||||
|
||||
@@ -24,236 +22,26 @@ The abstract of the paper is the following:
|
||||
|
||||
This model was contributed by the community contributor [HimariO](https://github.com/HimariO) ❤️ .
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks | Demo
|
||||
|---|---|:---:|
|
||||
| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
|
||||
| [StableDiffusionXLAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning on StableDiffusion-XL* | -
|
||||
|
||||
## Usage example with the base model of StableDiffusion-1.4/1.5
|
||||
|
||||
In the following we give a simple example of how to use a *T2I-Adapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
|
||||
All adapters use the same pipeline.
|
||||
|
||||
1. Images are first converted into the appropriate *control image* format.
|
||||
2. The *control image* and *prompt* are passed to the [`StableDiffusionAdapterPipeline`].
|
||||
|
||||
Let's have a look at a simple example using the [Color Adapter](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1).
|
||||
|
||||
```python
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png")
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
Then we can create our color palette by simply resizing it to 8 by 8 pixels and then scaling it back to original size.
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
|
||||
color_palette = image.resize((8, 8))
|
||||
color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST)
|
||||
```
|
||||
|
||||
Let's take a look at the processed image.
|
||||
|
||||

|
||||
|
||||
|
||||
Next, create the adapter pipeline
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionAdapterPipeline, T2IAdapter
|
||||
|
||||
adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionAdapterPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
adapter=adapter,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
Finally, pass the prompt and control image to the pipeline
|
||||
|
||||
```py
|
||||
# fix the random seed, so you will get the same result as the example
|
||||
generator = torch.Generator("cuda").manual_seed(7)
|
||||
|
||||
out_image = pipe(
|
||||
"At night, glowing cubes in front of the beach",
|
||||
image=color_palette,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
make_image_grid([image, color_palette, out_image], rows=1, cols=3)
|
||||
```
|
||||
|
||||

|
||||
|
||||
## Usage example with the base model of StableDiffusion-XL
|
||||
|
||||
In the following we give a simple example of how to use a *T2I-Adapter* checkpoint with Diffusers for inference based on StableDiffusion-XL.
|
||||
All adapters use the same pipeline.
|
||||
|
||||
1. Images are first downloaded into the appropriate *control image* format.
|
||||
2. The *control image* and *prompt* are passed to the [`StableDiffusionXLAdapterPipeline`].
|
||||
|
||||
Let's have a look at a simple example using the [Sketch Adapter](https://huggingface.co/Adapter/t2iadapter/tree/main/sketch_sdxl_1.0).
|
||||
|
||||
```python
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
sketch_image = load_image("https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png").convert("L")
|
||||
```
|
||||
|
||||

|
||||
|
||||
Then, create the adapter pipeline
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import (
|
||||
T2IAdapter,
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
DDPMScheduler
|
||||
)
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
adapter = T2IAdapter.from_pretrained("Adapter/t2iadapter", subfolder="sketch_sdxl_1.0", torch_dtype=torch.float16, adapter_type="full_adapter_xl")
|
||||
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
|
||||
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
|
||||
model_id, adapter=adapter, safety_checker=None, torch_dtype=torch.float16, variant="fp16", scheduler=scheduler
|
||||
)
|
||||
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
Finally, pass the prompt and control image to the pipeline
|
||||
|
||||
```py
|
||||
# fix the random seed, so you will get the same result as the example
|
||||
generator = torch.Generator().manual_seed(42)
|
||||
|
||||
sketch_image_out = pipe(
|
||||
prompt="a photo of a dog in real world, high quality",
|
||||
negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
|
||||
image=sketch_image,
|
||||
generator=generator,
|
||||
guidance_scale=7.5
|
||||
).images[0]
|
||||
make_image_grid([sketch_image, sketch_image_out], rows=1, cols=2)
|
||||
```
|
||||
|
||||

|
||||
|
||||
## Available checkpoints
|
||||
|
||||
Non-diffusers checkpoints can be found under [TencentARC/T2I-Adapter](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models).
|
||||
|
||||
### T2I-Adapter with Stable Diffusion 1.4
|
||||
|
||||
| Model Name | Control Image Overview| Control Image Example | Generated Image Example |
|
||||
|---|---|---|---|
|
||||
|[TencentARC/t2iadapter_color_sd14v1](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1)<br/> *Trained with spatial color palette* | An image with 8x8 color palette.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_sample_output.png"/></a>|
|
||||
|[TencentARC/t2iadapter_canny_sd14v1](https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1)<br/> *Trained with canny edge detection* | A monochrome image with white edges on a black background.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/canny_sample_output.png"/></a>|
|
||||
|[TencentARC/t2iadapter_sketch_sd14v1](https://huggingface.co/TencentARC/t2iadapter_sketch_sd14v1)<br/> *Trained with [PidiNet](https://github.com/zhuoinoulu/pidinet) edge detection* | A hand-drawn monochrome image with white outlines on a black background.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_input.png"><img width="64" style="margin:0;padding:0;" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/sketch_sample_output.png"/></a>|
|
||||
|[TencentARC/t2iadapter_depth_sd14v1](https://huggingface.co/TencentARC/t2iadapter_depth_sd14v1)<br/> *Trained with Midas depth estimation* | A grayscale image with black representing deep areas and white representing shallow areas.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_output.png"/></a>|
|
||||
|[TencentARC/t2iadapter_openpose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_openpose_sd14v1)<br/> *Trained with OpenPose bone image* | A [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/openpose_sample_output.png"/></a>|
|
||||
|[TencentARC/t2iadapter_keypose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_keypose_sd14v1)<br/> *Trained with mmpose skeleton image* | A [mmpose skeleton](https://github.com/open-mmlab/mmpose) image.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_output.png"/></a>|
|
||||
|[TencentARC/t2iadapter_seg_sd14v1](https://huggingface.co/TencentARC/t2iadapter_seg_sd14v1)<br/>*Trained with semantic segmentation* | An [custom](https://github.com/TencentARC/T2I-Adapter/discussions/25) segmentation protocol image.|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_input.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_input.png"/></a>|<a href="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_output.png"><img width="64" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/seg_sample_output.png"/></a> |
|
||||
|[TencentARC/t2iadapter_canny_sd15v2](https://huggingface.co/TencentARC/t2iadapter_canny_sd15v2)||
|
||||
|[TencentARC/t2iadapter_depth_sd15v2](https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2)||
|
||||
|[TencentARC/t2iadapter_sketch_sd15v2](https://huggingface.co/TencentARC/t2iadapter_sketch_sd15v2)||
|
||||
|[TencentARC/t2iadapter_zoedepth_sd15v1](https://huggingface.co/TencentARC/t2iadapter_zoedepth_sd15v1)||
|
||||
|[Adapter/t2iadapter, subfolder='sketch_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/sketch_sdxl_1.0)||
|
||||
|[Adapter/t2iadapter, subfolder='canny_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/canny_sdxl_1.0)||
|
||||
|[Adapter/t2iadapter, subfolder='openpose_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/openpose_sdxl_1.0)||
|
||||
|
||||
## Combining multiple adapters
|
||||
|
||||
[`MultiAdapter`] can be used for applying multiple conditionings at once.
|
||||
|
||||
Here we use the keypose adapter for the character posture and the depth adapter for creating the scene.
|
||||
|
||||
```py
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
|
||||
cond_keypose = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"
|
||||
)
|
||||
cond_depth = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"
|
||||
)
|
||||
cond = [cond_keypose, cond_depth]
|
||||
|
||||
prompt = ["A man walking in an office room with a nice view"]
|
||||
```
|
||||
|
||||
The two control images look as such:
|
||||
|
||||

|
||||

|
||||
|
||||
|
||||
`MultiAdapter` combines keypose and depth adapters.
|
||||
|
||||
`adapter_conditioning_scale` balances the relative influence of the different adapters.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionAdapterPipeline, MultiAdapter, T2IAdapter
|
||||
|
||||
adapters = MultiAdapter(
|
||||
[
|
||||
T2IAdapter.from_pretrained("TencentARC/t2iadapter_keypose_sd14v1"),
|
||||
T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1"),
|
||||
]
|
||||
)
|
||||
adapters = adapters.to(torch.float16)
|
||||
|
||||
pipe = StableDiffusionAdapterPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
torch_dtype=torch.float16,
|
||||
adapter=adapters,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe(prompt, cond, adapter_conditioning_scale=[0.8, 0.8]).images[0]
|
||||
make_image_grid([cond_keypose, cond_depth, image], rows=1, cols=3)
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
## T2I-Adapter vs ControlNet
|
||||
|
||||
T2I-Adapter is similar to [ControlNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet).
|
||||
T2I-Adapter uses a smaller auxiliary network which is only run once for the entire diffusion process.
|
||||
However, T2I-Adapter performs slightly worse than ControlNet.
|
||||
|
||||
## StableDiffusionAdapterPipeline
|
||||
|
||||
[[autodoc]] StableDiffusionAdapterPipeline
|
||||
- all
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_vae_slicing
|
||||
- disable_vae_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- all
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_vae_slicing
|
||||
- disable_vae_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
## StableDiffusionXLAdapterPipeline
|
||||
|
||||
[[autodoc]] StableDiffusionXLAdapterPipeline
|
||||
- all
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_vae_slicing
|
||||
- disable_vae_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
- all
|
||||
- __call__
|
||||
- enable_attention_slicing
|
||||
- disable_attention_slicing
|
||||
- enable_vae_slicing
|
||||
- disable_vae_slicing
|
||||
- enable_xformers_memory_efficient_attention
|
||||
- disable_xformers_memory_efficient_attention
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
# T-GATE
|
||||
|
||||
[T-GATE](https://github.com/HaozheLiu-ST/T-GATE/tree/main) accelerates inference for [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [PixArt](../api/pipelines/pixart), and [Latency Consistency Model](../api/pipelines/latent_consistency_models.md) pipelines by skipping the cross-attention calculation once it converges. This method doesn't require any additional training and it can speed up inference from 10-50%. T-GATE is also compatible with other optimization methods like [DeepCache](./deepcache).
|
||||
|
||||
Before you begin, make sure you install T-GATE.
|
||||
|
||||
```bash
|
||||
pip install tgate
|
||||
pip install -U pytorch diffusers transformers accelerate DeepCache
|
||||
```
|
||||
|
||||
|
||||
To use T-GATE with a pipeline, you need to use its corresponding loader.
|
||||
|
||||
| Pipeline | T-GATE Loader |
|
||||
|---|---|
|
||||
| PixArt | TgatePixArtLoader |
|
||||
| Stable Diffusion XL | TgateSDXLLoader |
|
||||
| Stable Diffusion XL + DeepCache | TgateSDXLDeepCacheLoader |
|
||||
| Stable Diffusion | TgateSDLoader |
|
||||
| Stable Diffusion + DeepCache | TgateSDDeepCacheLoader |
|
||||
|
||||
Next, create a `TgateLoader` with a pipeline, the gate step (the time step to stop calculating the cross attention), and the number of inference steps. Then call the `tgate` method on the pipeline with a prompt, gate step, and the number of inference steps.
|
||||
|
||||
Let's see how to enable this for several different pipelines.
|
||||
|
||||
<hfoptions id="pipelines">
|
||||
<hfoption id="PixArt">
|
||||
|
||||
Accelerate `PixArtAlphaPipeline` with T-GATE:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import PixArtAlphaPipeline
|
||||
from tgate import TgatePixArtLoader
|
||||
|
||||
pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
|
||||
pipe = TgatePixArtLoader(
|
||||
pipe,
|
||||
gate_step=8,
|
||||
num_inference_steps=25,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe.tgate(
|
||||
"An alpaca made of colorful building blocks, cyberpunk.",
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step,
|
||||
).images[0]
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="Stable Diffusion XL">
|
||||
|
||||
Accelerate `StableDiffusionXLPipeline` with T-GATE:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
)
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
from tgate import TgateSDXLLoader
|
||||
gate_step = 10
|
||||
inference_step = 25
|
||||
pipe = TgateSDXLLoader(
|
||||
pipe,
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe.tgate(
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step
|
||||
).images[0]
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="StableDiffusionXL with DeepCache">
|
||||
|
||||
Accelerate `StableDiffusionXLPipeline` with [DeepCache](https://github.com/horseee/DeepCache) and T-GATE:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
)
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
from tgate import TgateSDXLDeepCacheLoader
|
||||
gate_step = 10
|
||||
inference_step = 25
|
||||
pipe = TgateSDXLDeepCacheLoader(
|
||||
pipe,
|
||||
cache_interval=3,
|
||||
cache_branch_id=0,
|
||||
).to("cuda")
|
||||
|
||||
image = pipe.tgate(
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step
|
||||
).images[0]
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="Latent Consistency Model">
|
||||
|
||||
Accelerate `latent-consistency/lcm-sdxl` with T-GATE:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import UNet2DConditionModel, LCMScheduler
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"latent-consistency/lcm-sdxl",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
)
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
unet=unet,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
)
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
from tgate import TgateSDXLLoader
|
||||
gate_step = 1
|
||||
inference_step = 4
|
||||
pipe = TgateSDXLLoader(
|
||||
pipe,
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step,
|
||||
lcm=True
|
||||
).to("cuda")
|
||||
|
||||
image = pipe.tgate(
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
|
||||
gate_step=gate_step,
|
||||
num_inference_steps=inference_step
|
||||
).images[0]
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
T-GATE also supports [`StableDiffusionPipeline`] and [PixArt-alpha/PixArt-LCM-XL-2-1024-MS](https://hf.co/PixArt-alpha/PixArt-LCM-XL-2-1024-MS).
|
||||
|
||||
## Benchmarks
|
||||
| Model | MACs | Param | Latency | Zero-shot 10K-FID on MS-COCO |
|
||||
|-----------------------|----------|-----------|---------|---------------------------|
|
||||
| SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 |
|
||||
| SD-1.5 w/ T-GATE | 9.875T | 815.557M | 4.313s | 20.789 |
|
||||
| SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 |
|
||||
| SD-2.1 w/ T-GATE | 22.208T | 815.433 M | 9.878s | 19.940 |
|
||||
| SD-XL | 149.438T | 2.570B | 53.187s | 24.628 |
|
||||
| SD-XL w/ T-GATE | 84.438T | 2.024B | 27.932s | 22.738 |
|
||||
| Pixart-Alpha | 107.031T | 611.350M | 61.502s | 38.669 |
|
||||
| Pixart-Alpha w/ T-GATE | 65.318T | 462.585M | 37.867s | 35.825 |
|
||||
| DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 |
|
||||
| DeepCache w/ T-GATE | 43.868T | - | 14.666s | 23.999 |
|
||||
| LCM (SD-XL) | 11.955T | 2.570B | 3.805s | 25.044 |
|
||||
| LCM w/ T-GATE | 11.171T | 2.024B | 3.533s | 25.028 |
|
||||
| LCM (Pixart-Alpha) | 8.563T | 611.350M | 4.733s | 36.086 |
|
||||
| LCM w/ T-GATE | 7.623T | 462.585M | 4.543s | 37.048 |
|
||||
|
||||
The latency is tested on an NVIDIA 1080TI, MACs and Params are calculated with [calflops](https://github.com/MrYxJ/calculate-flops.pytorch), and the FID is calculated with [PytorchFID](https://github.com/mseitzer/pytorch-fid).
|
||||
@@ -52,6 +52,76 @@ To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](h
|
||||
|
||||
</Tip>
|
||||
|
||||
### Device placement
|
||||
|
||||
> [!WARNING]
|
||||
> This feature is experimental and its APIs might change in the future.
|
||||
|
||||
With Accelerate, you can use the `device_map` to determine how to distribute the models of a pipeline across multiple devices. This is useful in situations where you have more than one GPU.
|
||||
|
||||
For example, if you have two 8GB GPUs, then using [`~DiffusionPipeline.enable_model_cpu_offload`] may not work so well because:
|
||||
|
||||
* it only works on a single GPU
|
||||
* a single model might not fit on a single GPU ([`~DiffusionPipeline.enable_sequential_cpu_offload`] might work but it will be extremely slow and it is also limited to a single GPU)
|
||||
|
||||
To make use of both GPUs, you can use the "balanced" device placement strategy which splits the models across all available GPUs.
|
||||
|
||||
> [!WARNING]
|
||||
> Only the "balanced" strategy is supported at the moment, and we plan to support additional mapping strategies in the future.
|
||||
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
|
||||
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, device_map="balanced"
|
||||
)
|
||||
image = pipeline("a dog").images[0]
|
||||
image
|
||||
```
|
||||
|
||||
You can also pass a dictionary to enforce the maximum GPU memory that can be used on each device:
|
||||
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
max_memory = {0:"1GB", 1:"1GB"}
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
device_map="balanced",
|
||||
+ max_memory=max_memory
|
||||
)
|
||||
image = pipeline("a dog").images[0]
|
||||
image
|
||||
```
|
||||
|
||||
If a device is not present in `max_memory`, then it will be completely ignored and will not participate in the device placement.
|
||||
|
||||
By default, Diffusers uses the maximum memory of all devices. If the models don't fit on the GPUs, they are offloaded to the CPU. If the CPU doesn't have enough memory, then you might see an error. In that case, you could defer to using [`~DiffusionPipeline.enable_sequential_cpu_offload`] and [`~DiffusionPipeline.enable_model_cpu_offload`].
|
||||
|
||||
Call [`~DiffusionPipeline.reset_device_map`] to reset the `device_map` of a pipeline. This is also necessary if you want to use methods like `to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
|
||||
|
||||
```py
|
||||
pipeline.reset_device_map()
|
||||
```
|
||||
|
||||
Once a pipeline has been device-mapped, you can also access its device map via `hf_device_map`:
|
||||
|
||||
```py
|
||||
print(pipeline.hf_device_map)
|
||||
```
|
||||
|
||||
An example device map would look like so:
|
||||
|
||||
|
||||
```bash
|
||||
{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
|
||||
```
|
||||
|
||||
## PyTorch Distributed
|
||||
|
||||
PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism.
|
||||
|
||||
@@ -0,0 +1,219 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# T2I-Adapter
|
||||
|
||||
[T2I-Adapter](https://hf.co/papers/2302.08453) is a lightweight adapter for controlling and providing more accurate
|
||||
structure guidance for text-to-image models. It works by learning an alignment between the internal knowledge of the
|
||||
text-to-image model and an external control signal, such as edge detection or depth estimation.
|
||||
|
||||
The T2I-Adapter design is simple, the condition is passed to four feature extraction blocks and three downsample
|
||||
blocks. This makes it fast and easy to train different adapters for different conditions which can be plugged into the
|
||||
text-to-image model. T2I-Adapter is similar to [ControlNet](controlnet) except it is smaller (~77M parameters) and
|
||||
faster because it only runs once during the diffusion process. The downside is that performance may be slightly worse
|
||||
than ControlNet.
|
||||
|
||||
This guide will show you how to use T2I-Adapter with different Stable Diffusion models and how you can compose multiple
|
||||
T2I-Adapters to impose more than one condition.
|
||||
|
||||
> [!TIP]
|
||||
> There are several T2I-Adapters available for different conditions, such as color palette, depth, sketch, pose, and
|
||||
> segmentation. Check out the [TencentARC](https://hf.co/TencentARC) repository to try them out!
|
||||
|
||||
Before you begin, make sure you have the following libraries installed.
|
||||
|
||||
```py
|
||||
# uncomment to install the necessary libraries in Colab
|
||||
#!pip install -q diffusers accelerate controlnet-aux==0.0.7
|
||||
```
|
||||
|
||||
## Text-to-image
|
||||
|
||||
Text-to-image models rely on a prompt to generate an image, but sometimes, text alone may not be enough to provide more
|
||||
accurate structural guidance. T2I-Adapter allows you to provide an additional control image to guide the generation
|
||||
process. For example, you can provide a canny image (a white outline of an image on a black background) to guide the
|
||||
model to generate an image with a similar structure.
|
||||
|
||||
<hfoptions id="stablediffusion">
|
||||
<hfoption id="Stable Diffusion 1.5">
|
||||
|
||||
Create a canny image with the [opencv-library](https://github.com/opencv/opencv-python).
|
||||
|
||||
```py
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from diffusers.utils import load_image
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png")
|
||||
image = np.array(image)
|
||||
|
||||
low_threshold = 100
|
||||
high_threshold = 200
|
||||
|
||||
image = cv2.Canny(image, low_threshold, high_threshold)
|
||||
image = Image.fromarray(image)
|
||||
```
|
||||
|
||||
Now load a T2I-Adapter conditioned on [canny images](https://hf.co/TencentARC/t2iadapter_canny_sd15v2) and pass it to
|
||||
the [`StableDiffusionAdapterPipeline`].
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionAdapterPipeline, T2IAdapter
|
||||
|
||||
adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd15v2", torch_dtype=torch.float16)
|
||||
pipeline = StableDiffusionAdapterPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
adapter=adapter,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
```
|
||||
|
||||
Finally, pass your prompt and control image to the pipeline.
|
||||
|
||||
```py
|
||||
generator = torch.Generator("cuda").manual_seed(0)
|
||||
|
||||
image = pipeline(
|
||||
prompt="cinematic photo of a plush and soft midcentury style rug on a wooden floor, 35mm photograph, film, professional, 4k, highly detailed",
|
||||
image=image,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2i-sd1.5.png"/>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Stable Diffusion XL">
|
||||
|
||||
Create a canny image with the [controlnet-aux](https://github.com/huggingface/controlnet_aux) library.
|
||||
|
||||
```py
|
||||
from controlnet_aux.canny import CannyDetector
|
||||
from diffusers.utils import load_image
|
||||
|
||||
canny_detector = CannyDetector()
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png")
|
||||
image = canny_detector(image, detect_resolution=384, image_resolution=1024)
|
||||
```
|
||||
|
||||
Now load a T2I-Adapter conditioned on [canny images](https://hf.co/TencentARC/t2i-adapter-canny-sdxl-1.0) and pass it
|
||||
to the [`StableDiffusionXLAdapterPipeline`].
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteScheduler, AutoencoderKL
|
||||
|
||||
scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
|
||||
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
||||
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16)
|
||||
pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
adapter=adapter,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
```
|
||||
|
||||
Finally, pass your prompt and control image to the pipeline.
|
||||
|
||||
```py
|
||||
generator = torch.Generator("cuda").manual_seed(0)
|
||||
|
||||
image = pipeline(
|
||||
prompt="cinematic photo of a plush and soft midcentury style rug on a wooden floor, 35mm photograph, film, professional, 4k, highly detailed",
|
||||
image=image,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2i-sdxl.png"/>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## MultiAdapter
|
||||
|
||||
T2I-Adapters are also composable, allowing you to use more than one adapter to impose multiple control conditions on an
|
||||
image. For example, you can use a pose map to provide structural control and a depth map for depth control. This is
|
||||
enabled by the [`MultiAdapter`] class.
|
||||
|
||||
Let's condition a text-to-image model with a pose and depth adapter. Create and place your depth and pose image and in a list.
|
||||
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pose_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"
|
||||
)
|
||||
depth_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"
|
||||
)
|
||||
cond = [pose_image, depth_image]
|
||||
prompt = ["Santa Claus walking into an office room with a beautiful city view"]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">depth image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">pose image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Load the corresponding pose and depth adapters as a list in the [`MultiAdapter`] class.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionAdapterPipeline, MultiAdapter, T2IAdapter
|
||||
|
||||
adapters = MultiAdapter(
|
||||
[
|
||||
T2IAdapter.from_pretrained("TencentARC/t2iadapter_keypose_sd14v1"),
|
||||
T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1"),
|
||||
]
|
||||
)
|
||||
adapters = adapters.to(torch.float16)
|
||||
```
|
||||
|
||||
Finally, load a [`StableDiffusionAdapterPipeline`] with the adapters, and pass your prompt and conditioned images to
|
||||
it. Use the [`adapter_conditioning_scale`] to adjust the weight of each adapter on the image.
|
||||
|
||||
```py
|
||||
pipeline = StableDiffusionAdapterPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
torch_dtype=torch.float16,
|
||||
adapter=adapters,
|
||||
).to("cuda")
|
||||
|
||||
image = pipeline(prompt, cond, adapter_conditioning_scale=[0.7, 0.7]).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/t2i-multi.png"/>
|
||||
</div>
|
||||
@@ -10,10 +10,209 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Prompt weighting
|
||||
# Prompt techniques
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
Prompts are important because they describe what you want a diffusion model to generate. The best prompts are detailed, specific, and well-structured to help the model realize your vision. But crafting a great prompt takes time and effort and sometimes it may not be enough because language and words can be imprecise. This is where you need to boost your prompt with other techniques, such as prompt enhancing and prompt weighting, to get the results you want.
|
||||
|
||||
This guide will show you how you can use these prompt techniques to generate high-quality images with lower effort and adjust the weight of certain keywords in a prompt.
|
||||
|
||||
## Prompt engineering
|
||||
|
||||
> [!TIP]
|
||||
> This is not an exhaustive guide on prompt engineering, but it will help you understand the necessary parts of a good prompt. We encourage you to continue experimenting with different prompts and combine them in new ways to see what works best. As you write more prompts, you'll develop an intuition for what works and what doesn't!
|
||||
|
||||
New diffusion models do a pretty good job of generating high-quality images from a basic prompt, but it is still important to create a well-written prompt to get the best results. Here are a few tips for writing a good prompt:
|
||||
|
||||
1. What is the image *medium*? Is it a photo, a painting, a 3D illustration, or something else?
|
||||
2. What is the image *subject*? Is it a person, animal, object, or scene?
|
||||
3. What *details* would you like to see in the image? This is where you can get really creative and have a lot of fun experimenting with different words to bring your image to life. For example, what is the lighting like? What is the vibe and aesthetic? What kind of art or illustration style are you looking for? The more specific and precise words you use, the better the model will understand what you want to generate.
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/plain-prompt.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">"A photo of a banana-shaped couch in a living room"</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">"A vibrant yellow banana-shaped couch sits in a cozy living room, its curve cradling a pile of colorful cushions. on the wooden floor, a patterned rug adds a touch of eclectic charm, and a potted plant sits in the corner, reaching towards the sunlight filtering through the windows"</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Prompt enhancing with GPT2
|
||||
|
||||
Prompt enhancing is a technique for quickly improving prompt quality without spending too much effort constructing one. It uses a model like GPT2 pretrained on Stable Diffusion text prompts to automatically enrich a prompt with additional important keywords to generate high-quality images.
|
||||
|
||||
The technique works by curating a list of specific keywords and forcing the model to generate those words to enhance the original prompt. This way, your prompt can be "a cat" and GPT2 can enhance the prompt to "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain quality sharp focus beautiful detailed intricate stunning amazing epic".
|
||||
|
||||
> [!TIP]
|
||||
> You should also use a [*offset noise*](https://www.crosslabs.org//blog/diffusion-with-offset-noise) LoRA to improve the contrast in bright and dark images and create better lighting overall. This [LoRA](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors) is available from [stabilityai/stable-diffusion-xl-base-1.0](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0).
|
||||
|
||||
Start by defining certain styles and a list of words (you can check out a more comprehensive list of [words](https://hf.co/LykosAI/GPT-Prompt-Expansion-Fooocus-v2/blob/main/positive.txt) and [styles](https://github.com/lllyasviel/Fooocus/tree/main/sdxl_styles) used by Fooocus) to enhance a prompt with.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import GenerationConfig, GPT2LMHeadModel, GPT2Tokenizer, LogitsProcessor, LogitsProcessorList
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
styles = {
|
||||
"cinematic": "cinematic film still of {prompt}, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
|
||||
"anime": "anime artwork of {prompt}, anime style, key visual, vibrant, studio anime, highly detailed",
|
||||
"photographic": "cinematic photo of {prompt}, 35mm photograph, film, professional, 4k, highly detailed",
|
||||
"comic": "comic of {prompt}, graphic illustration, comic art, graphic novel art, vibrant, highly detailed",
|
||||
"lineart": "line art drawing {prompt}, professional, sleek, modern, minimalist, graphic, line art, vector graphics",
|
||||
"pixelart": " pixel-art {prompt}, low-res, blocky, pixel art style, 8-bit graphics",
|
||||
}
|
||||
|
||||
words = [
|
||||
"aesthetic", "astonishing", "beautiful", "breathtaking", "composition", "contrasted", "epic", "moody", "enhanced",
|
||||
"exceptional", "fascinating", "flawless", "glamorous", "glorious", "illumination", "impressive", "improved",
|
||||
"inspirational", "magnificent", "majestic", "hyperrealistic", "smooth", "sharp", "focus", "stunning", "detailed",
|
||||
"intricate", "dramatic", "high", "quality", "perfect", "light", "ultra", "highly", "radiant", "satisfying",
|
||||
"soothing", "sophisticated", "stylish", "sublime", "terrific", "touching", "timeless", "wonderful", "unbelievable",
|
||||
"elegant", "awesome", "amazing", "dynamic", "trendy",
|
||||
]
|
||||
```
|
||||
|
||||
You may have noticed in the `words` list, there are certain words that can be paired together to create something more meaningful. For example, the words "high" and "quality" can be combined to create "high quality". Let's pair these words together and remove the words that can't be paired.
|
||||
|
||||
```py
|
||||
word_pairs = ["highly detailed", "high quality", "enhanced quality", "perfect composition", "dynamic light"]
|
||||
|
||||
def find_and_order_pairs(s, pairs):
|
||||
words = s.split()
|
||||
found_pairs = []
|
||||
for pair in pairs:
|
||||
pair_words = pair.split()
|
||||
if pair_words[0] in words and pair_words[1] in words:
|
||||
found_pairs.append(pair)
|
||||
words.remove(pair_words[0])
|
||||
words.remove(pair_words[1])
|
||||
|
||||
for word in words[:]:
|
||||
for pair in pairs:
|
||||
if word in pair.split():
|
||||
words.remove(word)
|
||||
break
|
||||
ordered_pairs = ", ".join(found_pairs)
|
||||
remaining_s = ", ".join(words)
|
||||
return ordered_pairs, remaining_s
|
||||
```
|
||||
|
||||
Next, implement a custom [`~transformers.LogitsProcessor`] class that assigns tokens in the `words` list a value of 0 and assigns tokens not in the `words` list a negative value so they aren't picked during generation. This way, generation is biased towards words in the `words` list. After a word from the list is used, it is also assigned a negative value so it isn't picked again.
|
||||
|
||||
```py
|
||||
class CustomLogitsProcessor(LogitsProcessor):
|
||||
def __init__(self, bias):
|
||||
super().__init__()
|
||||
self.bias = bias
|
||||
|
||||
def __call__(self, input_ids, scores):
|
||||
if len(input_ids.shape) == 2:
|
||||
last_token_id = input_ids[0, -1]
|
||||
self.bias[last_token_id] = -1e10
|
||||
return scores + self.bias
|
||||
|
||||
word_ids = [tokenizer.encode(word, add_prefix_space=True)[0] for word in words]
|
||||
bias = torch.full((tokenizer.vocab_size,), -float("Inf")).to("cuda")
|
||||
bias[word_ids] = 0
|
||||
processor = CustomLogitsProcessor(bias)
|
||||
processor_list = LogitsProcessorList([processor])
|
||||
```
|
||||
|
||||
Combine the prompt and the `cinematic` style prompt defined in the `styles` dictionary earlier.
|
||||
|
||||
```py
|
||||
prompt = "a cat basking in the sun on a roof in Turkey"
|
||||
style = "cinematic"
|
||||
|
||||
prompt = styles[style].format(prompt=prompt)
|
||||
prompt
|
||||
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
|
||||
```
|
||||
|
||||
Load a GPT2 tokenizer and model from the [Gustavosta/MagicPrompt-Stable-Diffusion](https://huggingface.co/Gustavosta/MagicPrompt-Stable-Diffusion) checkpoint (this specific checkpoint is trained to generate prompts) to enhance the prompt.
|
||||
|
||||
```py
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
|
||||
model = GPT2LMHeadModel.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion", torch_dtype=torch.float16).to(
|
||||
"cuda"
|
||||
)
|
||||
model.eval()
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
token_count = inputs["input_ids"].shape[1]
|
||||
max_new_tokens = 50 - token_count
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
penalty_alpha=0.7,
|
||||
top_k=50,
|
||||
eos_token_id=model.config.eos_token_id,
|
||||
pad_token_id=model.config.eos_token_id,
|
||||
pad_token=model.config.pad_token_id,
|
||||
do_sample=True,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
max_new_tokens=max_new_tokens,
|
||||
generation_config=generation_config,
|
||||
logits_processor=proccesor_list,
|
||||
)
|
||||
```
|
||||
|
||||
Then you can combine the input prompt and the generated prompt. Feel free to take a look at what the generated prompt (`generated_part`) is, the word pairs that were found (`pairs`), and the remaining words (`words`). This is all packed together in the `enhanced_prompt`.
|
||||
|
||||
```py
|
||||
output_tokens = [tokenizer.decode(generated_id, skip_special_tokens=True) for generated_id in generated_ids]
|
||||
input_part, generated_part = output_tokens[0][: len(prompt)], output_tokens[0][len(prompt) :]
|
||||
pairs, words = find_and_order_pairs(generated_part, word_pairs)
|
||||
formatted_generated_part = pairs + ", " + words
|
||||
enhanced_prompt = input_part + ", " + formatted_generated_part
|
||||
enhanced_prompt
|
||||
["cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain quality sharp focus beautiful detailed intricate stunning amazing epic"]
|
||||
```
|
||||
|
||||
Finally, load a pipeline and the offset noise LoRA with a *low weight* to generate an image with the enhanced prompt.
|
||||
|
||||
```py
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
|
||||
pipeline.load_lora_weights(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
weight_name="sd_xl_offset_example-lora_1.0.safetensors",
|
||||
adapter_name="offset",
|
||||
)
|
||||
pipeline.set_adapters(["offset"], adapter_weights=[0.2])
|
||||
|
||||
image = pipeline(
|
||||
enhanced_prompt,
|
||||
width=1152,
|
||||
height=896,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=25,
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/non-enhanced-prompt.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">"a cat basking in the sun on a roof in Turkey"</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/enhanced-prompt.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Prompt weighting
|
||||
|
||||
Prompt weighting provides a way to emphasize or de-emphasize certain parts of a prompt, allowing for more control over the generated image. A prompt can include several concepts, which gets turned into contextualized text embeddings. The embeddings are used by the model to condition its cross-attention layers to generate an image (read the Stable Diffusion [blog post](https://huggingface.co/blog/stable_diffusion) to learn more about how it works).
|
||||
|
||||
Prompt weighting works by increasing or decreasing the scale of the text embedding vector that corresponds to its concept in the prompt because you may not necessarily want the model to focus on all concepts equally. The easiest way to prepare the prompt-weighted embeddings is to use [Compel](https://github.com/damian0815/compel), a text prompt-weighting and blending library. Once you have the prompt-weighted embeddings, you can pass them to any pipeline that has a [`prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) (and optionally [`negative_prompt_embeds`](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.negative_prompt_embeds)) parameter, such as [`StableDiffusionPipeline`], [`StableDiffusionControlNetPipeline`], and [`StableDiffusionXLPipeline`].
|
||||
@@ -55,7 +254,7 @@ image
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/compel/forest_0.png"/>
|
||||
</div>
|
||||
|
||||
## Weighting
|
||||
### Weighting
|
||||
|
||||
You'll notice there is no "ball" in the image! Let's use compel to upweight the concept of "ball" in the prompt. Create a [`Compel`](https://github.com/damian0815/compel/blob/main/doc/compel.md#compel-objects) object, and pass it a tokenizer and text encoder:
|
||||
|
||||
@@ -123,7 +322,7 @@ image
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-pos-neg.png"/>
|
||||
</div>
|
||||
|
||||
## Blending
|
||||
### Blending
|
||||
|
||||
You can also create a weighted *blend* of prompts by adding `.blend()` to a list of prompts and passing it some weights. Your blend may not always produce the result you expect because it breaks some assumptions about how the text encoder functions, so just have fun and experiment with it!
|
||||
|
||||
@@ -139,7 +338,7 @@ image
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-blend.png"/>
|
||||
</div>
|
||||
|
||||
## Conjunction
|
||||
### Conjunction
|
||||
|
||||
A conjunction diffuses each prompt independently and concatenates their results by their weighted sum. Add `.and()` to the end of a list of prompts to create a conjunction:
|
||||
|
||||
@@ -155,7 +354,7 @@ image
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-conj.png"/>
|
||||
</div>
|
||||
|
||||
## Textual inversion
|
||||
### Textual inversion
|
||||
|
||||
[Textual inversion](../training/text_inversion) is a technique for learning a specific concept from some images which you can use to generate new images conditioned on that concept.
|
||||
|
||||
@@ -195,7 +394,7 @@ image
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-text-inversion.png"/>
|
||||
</div>
|
||||
|
||||
## DreamBooth
|
||||
### DreamBooth
|
||||
|
||||
[DreamBooth](../training/dreambooth) is a technique for generating contextualized images of a subject given just a few images of the subject to train on. It is similar to textual inversion, but DreamBooth trains the full model whereas textual inversion only fine-tunes the text embeddings. This means you should use [`~DiffusionPipeline.from_pretrained`] to load the DreamBooth model (feel free to browse the [Stable Diffusion Dreambooth Concepts Library](https://huggingface.co/sd-dreambooth-library) for 100+ trained models):
|
||||
|
||||
@@ -221,7 +420,7 @@ image
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-dreambooth.png"/>
|
||||
</div>
|
||||
|
||||
## Stable Diffusion XL
|
||||
### Stable Diffusion XL
|
||||
|
||||
Stable Diffusion XL (SDXL) has two tokenizers and text encoders so it's usage is a bit different. To address this, you should pass both tokenizers and encoders to the `Compel` class:
|
||||
|
||||
|
||||
@@ -945,7 +945,7 @@ def main(args):
|
||||
|
||||
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from (online) unet
|
||||
target_unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
target_unet = UNet2DConditionModel.from_config(unet.config)
|
||||
target_unet.load_state_dict(unet.state_dict())
|
||||
target_unet.train()
|
||||
target_unet.requires_grad_(False)
|
||||
|
||||
@@ -1004,7 +1004,7 @@ def main(args):
|
||||
|
||||
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from (online) unet
|
||||
target_unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
target_unet = UNet2DConditionModel.from_config(unet.config)
|
||||
target_unet.load_state_dict(unet.state_dict())
|
||||
target_unet.train()
|
||||
target_unet.requires_grad_(False)
|
||||
|
||||
@@ -53,6 +53,9 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
@@ -64,6 +67,48 @@ DATASET_NAME_MAPPING = {
|
||||
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
|
||||
|
||||
|
||||
def log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
original_image = download_image(args.val_image_url)
|
||||
edited_images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
args.validation_prompt,
|
||||
image=original_image,
|
||||
num_inference_steps=20,
|
||||
image_guidance_scale=1.5,
|
||||
guidance_scale=7,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "wandb":
|
||||
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
|
||||
for edited_image in edited_images:
|
||||
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
|
||||
tracker.log({"validation": wandb_table})
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
|
||||
parser.add_argument(
|
||||
@@ -411,11 +456,6 @@ def main():
|
||||
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
import wandb
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -517,7 +557,8 @@ def main():
|
||||
model.save_pretrained(os.path.join(output_dir, "unet"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
if weights:
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
@@ -923,11 +964,6 @@ def main():
|
||||
and (args.validation_prompt is not None)
|
||||
and (epoch % args.validation_epochs == 0)
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline
|
||||
if args.use_ema:
|
||||
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
|
||||
ema_unet.store(unet.parameters())
|
||||
@@ -942,38 +978,14 @@ def main():
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
original_image = download_image(args.val_image_url)
|
||||
edited_images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
)
|
||||
|
||||
with autocast_ctx:
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
args.validation_prompt,
|
||||
image=original_image,
|
||||
num_inference_steps=20,
|
||||
image_guidance_scale=1.5,
|
||||
guidance_scale=7,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "wandb":
|
||||
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
|
||||
for edited_image in edited_images:
|
||||
wandb_table.add_data(
|
||||
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
|
||||
)
|
||||
tracker.log({"validation": wandb_table})
|
||||
if args.use_ema:
|
||||
# Switch back to the original UNet parameters.
|
||||
ema_unet.restore(unet.parameters())
|
||||
@@ -984,7 +996,6 @@ def main():
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
if args.use_ema:
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
|
||||
@@ -992,7 +1003,7 @@ def main():
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=unwrap_model(text_encoder),
|
||||
vae=unwrap_model(vae),
|
||||
unet=unet,
|
||||
unet=unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
@@ -1006,31 +1017,13 @@ def main():
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
if args.validation_prompt is not None:
|
||||
edited_images = []
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
with torch.autocast(str(accelerator.device).replace(":0", "")):
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
args.validation_prompt,
|
||||
image=original_image,
|
||||
num_inference_steps=20,
|
||||
image_guidance_scale=1.5,
|
||||
guidance_scale=7,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "wandb":
|
||||
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
|
||||
for edited_image in edited_images:
|
||||
wandb_table.add_data(
|
||||
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
|
||||
)
|
||||
tracker.log({"test": wandb_table})
|
||||
|
||||
if (args.val_image_url is not None) and (args.validation_prompt is not None):
|
||||
log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,58 +0,0 @@
|
||||
# !pip install opencv-python transformers accelerate
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnetxs import ControlNetXSModel
|
||||
from PIL import Image
|
||||
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
|
||||
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
)
|
||||
parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
|
||||
parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
|
||||
parser.add_argument(
|
||||
"--image_path",
|
||||
type=str,
|
||||
default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
|
||||
)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
prompt = args.prompt
|
||||
negative_prompt = args.negative_prompt
|
||||
# download an image
|
||||
image = load_image(args.image_path)
|
||||
|
||||
# initialize the models and pipeline
|
||||
controlnet_conditioning_scale = args.controlnet_conditioning_scale
|
||||
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# get canny image
|
||||
image = np.array(image)
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
canny_image = Image.fromarray(image)
|
||||
|
||||
num_inference_steps = args.num_inference_steps
|
||||
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
image=canny_image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
).images[0]
|
||||
image.save("cnxs_sd.canny.png")
|
||||
@@ -1,57 +0,0 @@
|
||||
# !pip install opencv-python transformers accelerate
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnetxs import ControlNetXSModel
|
||||
from PIL import Image
|
||||
from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
|
||||
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
)
|
||||
parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
|
||||
parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
|
||||
parser.add_argument(
|
||||
"--image_path",
|
||||
type=str,
|
||||
default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
|
||||
)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
prompt = args.prompt
|
||||
negative_prompt = args.negative_prompt
|
||||
# download an image
|
||||
image = load_image(args.image_path)
|
||||
# initialize the models and pipeline
|
||||
controlnet_conditioning_scale = args.controlnet_conditioning_scale
|
||||
controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# get canny image
|
||||
image = np.array(image)
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
canny_image = Image.fromarray(image)
|
||||
|
||||
num_inference_steps = args.num_inference_steps
|
||||
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
image=canny_image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
).images[0]
|
||||
image.save("cnxs_sdxl.canny.png")
|
||||
@@ -0,0 +1,15 @@
|
||||
# Scheduled Pseudo-Huber Loss for Diffusers
|
||||
|
||||
These are the modifications of to include the possibility of training text2image models with Scheduled Pseudo Huber loss, introduced in https://arxiv.org/abs/2403.16728. (https://github.com/kabachuha/SPHL-for-stable-diffusion)
|
||||
|
||||
## Why this might be useful?
|
||||
|
||||
- If you suspect that the part of the training dataset might be corrupted, and you don't want these outliers to distort the model's supposed output
|
||||
|
||||
- If you want to improve the aesthetic quality of pictures by helping the model disentangle concepts and be less influenced by another sorts of pictures.
|
||||
|
||||
See https://github.com/huggingface/diffusers/issues/7488 for the detailed description.
|
||||
|
||||
## Instructions
|
||||
|
||||
The same usage as in the case of the corresponding vanilla Diffusers scripts https://github.com/huggingface/diffusers/tree/main/examples
|
||||
+1518
File diff suppressed because it is too large
Load Diff
+1504
File diff suppressed because it is too large
Load Diff
+2078
File diff suppressed because it is too large
Load Diff
+1162
File diff suppressed because it is too large
Load Diff
+1051
File diff suppressed because it is too large
Load Diff
+1384
File diff suppressed because it is too large
Load Diff
+1394
File diff suppressed because it is too large
Load Diff
+5
-3
@@ -1,15 +1,17 @@
|
||||
[tool.ruff]
|
||||
line-length = 119
|
||||
|
||||
[tool.ruff.lint]
|
||||
# Never enforce `E501` (line length violations).
|
||||
ignore = ["C901", "E501", "E741", "F402", "F823"]
|
||||
select = ["C", "E", "F", "I", "W"]
|
||||
line-length = 119
|
||||
|
||||
# Ignore import violations in all `__init__.py` files.
|
||||
[tool.ruff.per-file-ignores]
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"__init__.py" = ["E402", "F401", "F403", "F811"]
|
||||
"src/diffusers/utils/dummy_*.py" = ["F401"]
|
||||
|
||||
[tool.ruff.isort]
|
||||
[tool.ruff.lint.isort]
|
||||
lines-after-imports = 2
|
||||
known-first-party = ["diffusers"]
|
||||
|
||||
|
||||
@@ -80,6 +80,7 @@ else:
|
||||
"AutoencoderTiny",
|
||||
"ConsistencyDecoderVAE",
|
||||
"ControlNetModel",
|
||||
"ControlNetXSAdapter",
|
||||
"I2VGenXLUNet",
|
||||
"Kandinsky3UNet",
|
||||
"ModelMixin",
|
||||
@@ -94,6 +95,7 @@ else:
|
||||
"UNet2DConditionModel",
|
||||
"UNet2DModel",
|
||||
"UNet3DConditionModel",
|
||||
"UNetControlNetXSModel",
|
||||
"UNetMotionModel",
|
||||
"UNetSpatioTemporalConditionModel",
|
||||
"UVit2DModel",
|
||||
@@ -270,6 +272,7 @@ else:
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetXSPipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
"StableDiffusionDiffEditPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
@@ -293,6 +296,7 @@ else:
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
"StableDiffusionXLControlNetXSPipeline",
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
@@ -474,6 +478,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
ControlNetModel,
|
||||
ControlNetXSAdapter,
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
ModelMixin,
|
||||
@@ -487,6 +492,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
UNet3DConditionModel,
|
||||
UNetControlNetXSModel,
|
||||
UNetMotionModel,
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
@@ -642,6 +648,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
@@ -665,6 +672,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
|
||||
@@ -1267,6 +1267,10 @@ class LoraLoaderMixin:
|
||||
for adapter_name in adapter_names:
|
||||
unet_module.lora_A[adapter_name].to(device)
|
||||
unet_module.lora_B[adapter_name].to(device)
|
||||
# this is a param, not a module, so device placement is not in-place -> re-assign
|
||||
unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
|
||||
adapter_name
|
||||
].to(device)
|
||||
|
||||
# Handle the text encoder
|
||||
modules_to_process = []
|
||||
@@ -1283,6 +1287,10 @@ class LoraLoaderMixin:
|
||||
for adapter_name in adapter_names:
|
||||
text_encoder_module.lora_A[adapter_name].to(device)
|
||||
text_encoder_module.lora_B[adapter_name].to(device)
|
||||
# this is a param, not a module, so device placement is not in-place -> re-assign
|
||||
text_encoder_module.lora_magnitude_vector[
|
||||
adapter_name
|
||||
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
|
||||
|
||||
|
||||
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
|
||||
@@ -998,7 +998,7 @@ class FromOriginalUNetMixin:
|
||||
if is_accelerate_available():
|
||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
@@ -68,6 +69,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ConsistencyDecoderVAE,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformers import (
|
||||
|
||||
@@ -102,6 +102,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
|
||||
act_fn: str = "relu",
|
||||
upsample_fn: str = "nearest",
|
||||
latent_channels: int = 4,
|
||||
upsampling_scaling_factor: int = 2,
|
||||
num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
|
||||
@@ -133,6 +134,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
upsampling_scaling_factor=upsampling_scaling_factor,
|
||||
act_fn=act_fn,
|
||||
upsample_fn=upsample_fn,
|
||||
)
|
||||
|
||||
self.latent_magnitude = latent_magnitude
|
||||
|
||||
@@ -926,6 +926,7 @@ class DecoderTiny(nn.Module):
|
||||
block_out_channels: Tuple[int, ...],
|
||||
upsampling_scaling_factor: int,
|
||||
act_fn: str,
|
||||
upsample_fn: str,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -942,7 +943,7 @@ class DecoderTiny(nn.Module):
|
||||
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
|
||||
|
||||
if not is_final_block:
|
||||
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
|
||||
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn))
|
||||
|
||||
conv_out_channel = num_channels if not is_final_block else out_channels
|
||||
layers.append(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -699,6 +699,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
dtype=torch_dtype,
|
||||
force_hooks=True,
|
||||
)
|
||||
except AttributeError as e:
|
||||
# When using accelerate loading, we do not have the ability to load the state
|
||||
|
||||
@@ -402,41 +402,18 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, _, height, width = hidden_states.shape
|
||||
batch_size, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
elif self.is_input_patches:
|
||||
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
if self.adaln_single is not None:
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
||||
)
|
||||
batch_size = hidden_states.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
|
||||
hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.is_input_patches and self.caption_projection is not None:
|
||||
batch_size = hidden_states.shape[0]
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -474,51 +451,116 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
else:
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
|
||||
if self.is_input_patches:
|
||||
if self.config.norm_type != "ada_norm_single":
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
elif self.config.norm_type == "ada_norm_single":
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
output = self._get_output_for_continuous_inputs(
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
inner_dim=inner_dim,
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
elif self.is_input_vectorized:
|
||||
output = self._get_output_for_vectorized_inputs(hidden_states)
|
||||
elif self.is_input_patches:
|
||||
output = self._get_output_for_patched_inputs(
|
||||
hidden_states=hidden_states,
|
||||
timestep=timestep,
|
||||
class_labels=class_labels,
|
||||
embedded_timestep=embedded_timestep,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
def _operate_on_continuous_inputs(self, hidden_states):
|
||||
batch, _, height, width = hidden_states.shape
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
return hidden_states, inner_dim
|
||||
|
||||
def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
embedded_timestep = None
|
||||
|
||||
if self.adaln_single is not None:
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
||||
)
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
if self.caption_projection is not None:
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
return hidden_states, encoder_hidden_states, timestep, embedded_timestep
|
||||
|
||||
def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = (
|
||||
hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
else:
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
|
||||
output = hidden_states + residual
|
||||
return output
|
||||
|
||||
def _get_output_for_vectorized_inputs(self, hidden_states):
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
return output
|
||||
|
||||
def _get_output_for_patched_inputs(
|
||||
self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
|
||||
):
|
||||
if self.config.norm_type != "ada_norm_single":
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
elif self.config.norm_type == "ada_norm_single":
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -746,6 +746,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
@@ -753,6 +754,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_groups_out: Optional[int] = None,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: int = 1,
|
||||
output_scale_factor: float = 1.0,
|
||||
@@ -764,6 +766,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
@@ -772,14 +778,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
|
||||
resnet_groups_out = resnet_groups_out or resnet_groups
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
groups_out=resnet_groups_out,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
@@ -794,11 +803,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups_out,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
attention_type=attention_type,
|
||||
@@ -808,8 +817,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
@@ -817,11 +826,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
groups=resnet_groups_out,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
|
||||
@@ -134,6 +134,12 @@ else:
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["controlnet_xs"].extend(
|
||||
[
|
||||
"StableDiffusionControlNetXSPipeline",
|
||||
"StableDiffusionXLControlNetXSPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["deepfloyd_if"] = [
|
||||
"IFImg2ImgPipeline",
|
||||
"IFImg2ImgSuperResolutionPipeline",
|
||||
@@ -378,6 +384,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .controlnet_xs import (
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
)
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
|
||||
@@ -898,6 +898,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -935,7 +941,12 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
self.vae.to(dtype)
|
||||
|
||||
init_latents = init_latents.to(dtype)
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
if latents_mean is not None and latents_std is not None:
|
||||
latents_mean = latents_mean.to(device=self.device, dtype=dtype)
|
||||
latents_std = latents_std.to(device=self.device, dtype=dtype)
|
||||
init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
|
||||
else:
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"]
|
||||
_import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
|
||||
from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
+159
-99
@@ -19,30 +19,75 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from controlnetxs import ControlNetXSModel
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> # !pip install opencv-python transformers accelerate
|
||||
>>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
|
||||
>>> import cv2
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
>>> negative_prompt = "low quality, bad quality, sketches"
|
||||
|
||||
>>> # download an image
|
||||
>>> image = load_image(
|
||||
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
|
||||
... )
|
||||
|
||||
>>> # initialize the models and pipeline
|
||||
>>> controlnet_conditioning_scale = 0.5
|
||||
|
||||
>>> controlnet = ControlNetXSAdapter.from_pretrained(
|
||||
... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> # get canny image
|
||||
>>> image = np.array(image)
|
||||
>>> image = cv2.Canny(image, 100, 200)
|
||||
>>> image = image[:, :, None]
|
||||
>>> image = np.concatenate([image, image, image], axis=2)
|
||||
>>> canny_image = Image.fromarray(image)
|
||||
>>> # generate image
|
||||
>>> image = pipe(
|
||||
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
|
||||
... ).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionControlNetXSPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
@@ -56,7 +101,7 @@ class StableDiffusionControlNetXSPipeline(
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -66,9 +111,9 @@ class StableDiffusionControlNetXSPipeline(
|
||||
tokenizer ([`~transformers.CLIPTokenizer`]):
|
||||
A `CLIPTokenizer` to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded image latents.
|
||||
controlnet ([`ControlNetXSModel`]):
|
||||
Provides additional conditioning to the `unet` during the denoising process.
|
||||
A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents.
|
||||
controlnet ([`ControlNetXSAdapter`]):
|
||||
A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
@@ -80,17 +125,18 @@ class StableDiffusionControlNetXSPipeline(
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae>controlnet"
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: ControlNetXSModel,
|
||||
unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
|
||||
controlnet: ControlNetXSAdapter,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
@@ -98,6 +144,9 @@ class StableDiffusionControlNetXSPipeline(
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(unet, UNet2DConditionModel):
|
||||
unet = UNetControlNetXSModel.from_unet(unet, controlnet)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
@@ -114,14 +163,6 @@ class StableDiffusionControlNetXSPipeline(
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible(
|
||||
vae
|
||||
)
|
||||
if not vae_compatible:
|
||||
raise ValueError(
|
||||
f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -403,20 +444,19 @@ class StableDiffusionControlNetXSPipeline(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
@@ -445,25 +485,16 @@ class StableDiffusionControlNetXSPipeline(
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Check `image`
|
||||
# Check `image` and `controlnet_conditioning_scale`
|
||||
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
||||
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
|
||||
self.unet, torch._dynamo.eval_frame.OptimizedModule
|
||||
)
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetXSModel)
|
||||
isinstance(self.unet, UNetControlNetXSModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
|
||||
and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
|
||||
):
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Check `controlnet_conditioning_scale`
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetXSModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
else:
|
||||
@@ -563,7 +594,33 @@ class StableDiffusionControlNetXSPipeline(
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -581,13 +638,13 @@ class StableDiffusionControlNetXSPipeline(
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
control_guidance_start: float = 0.0,
|
||||
control_guidance_end: float = 1.0,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -595,7 +652,7 @@ class StableDiffusionControlNetXSPipeline(
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
||||
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
|
||||
@@ -639,12 +696,6 @@ class StableDiffusionControlNetXSPipeline(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
@@ -659,7 +710,15 @@ class StableDiffusionControlNetXSPipeline(
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
@@ -669,21 +728,27 @@ class StableDiffusionControlNetXSPipeline(
|
||||
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
||||
"not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -713,6 +778,7 @@ class StableDiffusionControlNetXSPipeline(
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
@@ -720,27 +786,24 @@ class StableDiffusionControlNetXSPipeline(
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetXSModel):
|
||||
image = self.prepare_image(
|
||||
image=image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
height, width = image.shape[-2:]
|
||||
else:
|
||||
assert False
|
||||
image = self.prepare_image(
|
||||
image=image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=unet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
num_channels_latents = self.unet.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
@@ -757,42 +820,33 @@ class StableDiffusionControlNetXSPipeline(
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
is_unet_compiled = is_compiled_module(self.unet)
|
||||
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
||||
self._num_timesteps = len(timesteps)
|
||||
is_controlnet_compiled = is_compiled_module(self.unet)
|
||||
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# Relevant thread:
|
||||
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
||||
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
||||
if is_controlnet_compiled and is_torch_higher_equal_2_1:
|
||||
torch._inductor.cudagraph_mark_step_begin()
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
dont_control = (
|
||||
i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end
|
||||
apply_control = (
|
||||
i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
|
||||
)
|
||||
if dont_control:
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=True,
|
||||
).sample
|
||||
else:
|
||||
noise_pred = self.controlnet(
|
||||
base_model=self.unet,
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=True,
|
||||
).sample
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=True,
|
||||
apply_control=apply_control,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
@@ -801,12 +855,18 @@ class StableDiffusionControlNetXSPipeline(
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# If we do sequential model offloading, let's offload unet and controlnet
|
||||
# manually for max memory savings
|
||||
+198
-89
@@ -19,41 +19,94 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> # !pip install opencv-python transformers accelerate
|
||||
>>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import numpy as np
|
||||
>>> import torch
|
||||
|
||||
>>> import cv2
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
|
||||
>>> negative_prompt = "low quality, bad quality, sketches"
|
||||
|
||||
>>> # download an image
|
||||
>>> image = load_image(
|
||||
... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
|
||||
... )
|
||||
|
||||
>>> # initialize the models and pipeline
|
||||
>>> controlnet_conditioning_scale = 0.5
|
||||
>>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
||||
>>> controlnet = ControlNetXSAdapter.from_pretrained(
|
||||
... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> # get canny image
|
||||
>>> image = np.array(image)
|
||||
>>> image = cv2.Canny(image, 100, 200)
|
||||
>>> image = image[:, :, None]
|
||||
>>> image = np.concatenate([image, image, image], axis=2)
|
||||
>>> canny_image = Image.fromarray(image)
|
||||
|
||||
>>> # generate image
|
||||
>>> image = pipe(
|
||||
... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
|
||||
... ).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetXSPipeline(
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
@@ -66,9 +119,8 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
- [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
@@ -83,9 +135,9 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
tokenizer_2 ([`~transformers.CLIPTokenizer`]):
|
||||
A `CLIPTokenizer` to tokenize text.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded image latents.
|
||||
controlnet ([`ControlNetXSModel`]:
|
||||
Provides additional conditioning to the `unet` during the denoising process.
|
||||
A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents.
|
||||
controlnet ([`ControlNetXSAdapter`]):
|
||||
A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
@@ -98,9 +150,15 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
watermarker is used.
|
||||
"""
|
||||
|
||||
# leave controlnet out on purpose because it iterates with unet
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet"
|
||||
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = [
|
||||
"tokenizer",
|
||||
"tokenizer_2",
|
||||
"text_encoder",
|
||||
"text_encoder_2",
|
||||
"feature_extractor",
|
||||
]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -109,21 +167,17 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: ControlNetXSModel,
|
||||
unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
|
||||
controlnet: ControlNetXSAdapter,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible(
|
||||
vae
|
||||
)
|
||||
if not vae_compatible:
|
||||
raise ValueError(
|
||||
f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
|
||||
)
|
||||
if isinstance(unet, UNet2DConditionModel):
|
||||
unet = UNetControlNetXSModel.from_unet(unet, controlnet)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
@@ -134,6 +188,7 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
||||
@@ -417,15 +472,21 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
controlnet_conditioning_scale=1.0,
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
@@ -474,25 +535,16 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
# Check `image`
|
||||
# Check `image` and ``controlnet_conditioning_scale``
|
||||
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
||||
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
|
||||
self.unet, torch._dynamo.eval_frame.OptimizedModule
|
||||
)
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetXSModel)
|
||||
isinstance(self.unet, UNetControlNetXSModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
|
||||
and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
|
||||
):
|
||||
self.check_image(image, prompt, prompt_embeds)
|
||||
else:
|
||||
assert False
|
||||
|
||||
# Check `controlnet_conditioning_scale`
|
||||
if (
|
||||
isinstance(self.controlnet, ControlNetXSModel)
|
||||
or is_compiled
|
||||
and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
else:
|
||||
@@ -593,7 +645,6 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
@@ -602,7 +653,7 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features
|
||||
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
@@ -632,7 +683,33 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
|
||||
def cross_attention_kwargs(self):
|
||||
return self._cross_attention_kwargs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -654,8 +731,6 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
control_guidance_start: float = 0.0,
|
||||
@@ -667,6 +742,9 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
@@ -677,7 +755,7 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in both text-encoders.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
||||
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
||||
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
|
||||
@@ -735,12 +813,6 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
@@ -783,6 +855,15 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -791,7 +872,24 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is
|
||||
returned, otherwise a `tuple` is returned containing the output images.
|
||||
"""
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
||||
)
|
||||
|
||||
unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
@@ -808,8 +906,14 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -850,7 +954,7 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
)
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetXSModel):
|
||||
if isinstance(unet, UNetControlNetXSModel):
|
||||
image = self.prepare_image(
|
||||
image=image,
|
||||
width=width,
|
||||
@@ -858,7 +962,7 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
dtype=unet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
height, width = image.shape[-2:]
|
||||
@@ -870,7 +974,7 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
num_channels_latents = self.unet.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
@@ -928,14 +1032,14 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
is_unet_compiled = is_compiled_module(self.unet)
|
||||
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
||||
self._num_timesteps = len(timesteps)
|
||||
is_controlnet_compiled = is_compiled_module(self.unet)
|
||||
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# Relevant thread:
|
||||
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
||||
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
||||
if is_controlnet_compiled and is_torch_higher_equal_2_1:
|
||||
torch._inductor.cudagraph_mark_step_begin()
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
@@ -944,30 +1048,20 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
# predict the noise residual
|
||||
dont_control = (
|
||||
i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end
|
||||
apply_control = (
|
||||
i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
|
||||
)
|
||||
if dont_control:
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=True,
|
||||
).sample
|
||||
else:
|
||||
noise_pred = self.controlnet(
|
||||
base_model=self.unet,
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=True,
|
||||
).sample
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=True,
|
||||
apply_control=apply_control,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
@@ -977,6 +1071,16 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
@@ -984,6 +1088,11 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# manually for max memory savings
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
@@ -12,7 +12,6 @@ from ...models import UNet2DConditionModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_accelerate_available,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
@@ -115,6 +114,7 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
||||
model_cpu_offload_seq = "text_encoder->unet"
|
||||
_exclude_from_cpu_offload = ["watermarker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -156,20 +156,6 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_prompt(
|
||||
self,
|
||||
@@ -335,9 +321,6 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
|
||||
return image, nsfw_detected, watermark_detected
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
|
||||
@@ -15,7 +15,6 @@ from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
PIL_INTERPOLATION,
|
||||
is_accelerate_available,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
@@ -139,6 +138,7 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
||||
model_cpu_offload_seq = "text_encoder->unet"
|
||||
_exclude_from_cpu_offload = ["watermarker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -180,21 +180,6 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_prompt(
|
||||
self,
|
||||
@@ -361,9 +346,6 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
|
||||
return image, nsfw_detected, watermark_detected
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
|
||||
|
||||
@@ -16,7 +16,6 @@ from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
PIL_INTERPOLATION,
|
||||
is_accelerate_available,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
@@ -143,6 +142,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor"]
|
||||
model_cpu_offload_seq = "text_encoder->unet"
|
||||
_exclude_from_cpu_offload = ["watermarker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -191,21 +191,6 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
@@ -513,9 +498,6 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
|
||||
return image, nsfw_detected, watermark_detected
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
|
||||
@@ -1012,8 +994,6 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
else:
|
||||
# 10. Post-processing
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
@@ -15,7 +15,6 @@ from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
PIL_INTERPOLATION,
|
||||
is_accelerate_available,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
@@ -142,6 +141,7 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
||||
model_cpu_offload_seq = "text_encoder->unet"
|
||||
_exclude_from_cpu_offload = ["watermarker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -183,21 +183,6 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
@torch.no_grad()
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
@@ -365,9 +350,6 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
|
||||
return image, nsfw_detected, watermark_detected
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
|
||||
|
||||
@@ -16,7 +16,6 @@ from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
PIL_INTERPOLATION,
|
||||
is_accelerate_available,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
@@ -145,6 +144,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet"
|
||||
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
||||
_exclude_from_cpu_offload = ["watermarker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -193,21 +193,6 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
@@ -515,9 +500,6 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
|
||||
return image, nsfw_detected, watermark_detected
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
|
||||
|
||||
@@ -15,7 +15,6 @@ from ...models import UNet2DConditionModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_accelerate_available,
|
||||
is_bs4_available,
|
||||
is_ftfy_available,
|
||||
logging,
|
||||
@@ -101,6 +100,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
|
||||
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
|
||||
model_cpu_offload_seq = "text_encoder->unet"
|
||||
_exclude_from_cpu_offload = ["watermarker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -149,21 +149,6 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet, self.safety_checker]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
if clean_caption and not is_bs4_available():
|
||||
@@ -471,9 +456,6 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
nsfw_detected = None
|
||||
watermark_detected = None
|
||||
|
||||
if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
|
||||
self.unet_offload_hook.offload()
|
||||
|
||||
return image, nsfw_detected, watermark_detected
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
|
||||
|
||||
@@ -2238,6 +2238,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
@@ -2245,6 +2246,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_groups_out: Optional[int] = None,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: int = 1,
|
||||
output_scale_factor: float = 1.0,
|
||||
@@ -2256,6 +2258,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
@@ -2264,14 +2270,17 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
|
||||
resnet_groups_out = resnet_groups_out or resnet_groups
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlockFlat(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
groups_out=resnet_groups_out,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
@@ -2286,11 +2295,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups_out,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
attention_type=attention_type,
|
||||
@@ -2300,8 +2309,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
@@ -2309,11 +2318,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlockFlat(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
groups=resnet_groups_out,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
|
||||
@@ -143,6 +143,7 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
|
||||
|
||||
_load_connected_pipes = True
|
||||
model_cpu_offload_seq = "text_encoder->unet->movq->prior_prior->prior_image_encoder->prior_text_encoder"
|
||||
_exclude_from_cpu_offload = ["prior_prior"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -360,6 +361,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
|
||||
|
||||
_load_connected_pipes = True
|
||||
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
|
||||
_exclude_from_cpu_offload = ["prior_prior"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -600,6 +602,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
|
||||
|
||||
_load_connected_pipes = True
|
||||
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
|
||||
_exclude_from_cpu_offload = ["prior_prior"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -135,6 +135,7 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
|
||||
|
||||
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq"
|
||||
_load_connected_pipes = True
|
||||
_exclude_from_cpu_offload = ["prior_prior"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -362,6 +363,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
|
||||
|
||||
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq"
|
||||
_load_connected_pipes = True
|
||||
_exclude_from_cpu_offload = ["prior_prior"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -610,6 +612,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
|
||||
|
||||
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq"
|
||||
_load_connected_pipes = True
|
||||
_exclude_from_cpu_offload = ["prior_prior"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -8,7 +8,6 @@ from ...models import Kandinsky3UNet, VQModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
@@ -72,20 +71,6 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
|
||||
)
|
||||
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet, self.movq]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
def process_embeds(self, embeddings, attention_mask, cut_context):
|
||||
if cut_context:
|
||||
embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0])
|
||||
|
||||
@@ -12,7 +12,6 @@ from ...models import Kandinsky3UNet, VQModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
@@ -96,20 +95,6 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def remove_all_hooks(self):
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
for model in [self.text_encoder, self.unet]:
|
||||
if model is not None:
|
||||
remove_hook_from_module(model, recurse=True)
|
||||
|
||||
self.unet_offload_hook = None
|
||||
self.text_encoder_offload_hook = None
|
||||
self.final_offload_hook = None
|
||||
|
||||
def _process_embeds(self, embeddings, attention_mask, cut_context):
|
||||
# return embeddings, attention_mask
|
||||
if cut_context:
|
||||
|
||||
@@ -22,15 +22,19 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import (
|
||||
model_info,
|
||||
)
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from packaging import version
|
||||
|
||||
from .. import __version__
|
||||
from ..utils import (
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
get_class_from_dynamic_module,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
@@ -44,9 +48,12 @@ if is_transformers_available():
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
||||
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
from accelerate.utils import compute_module_sizes, get_max_memory
|
||||
|
||||
|
||||
INDEX_FILE = "diffusion_pytorch_model.bin"
|
||||
@@ -376,6 +383,207 @@ def _get_pipeline_class(
|
||||
return pipeline_cls
|
||||
|
||||
|
||||
def _load_empty_model(
|
||||
library_name: str,
|
||||
class_name: str,
|
||||
importable_classes: List[Any],
|
||||
pipelines: Any,
|
||||
is_pipeline_module: bool,
|
||||
name: str,
|
||||
torch_dtype: Union[str, torch.dtype],
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
**kwargs,
|
||||
):
|
||||
# retrieve class objects.
|
||||
class_obj, _ = get_class_obj_and_candidates(
|
||||
library_name,
|
||||
class_name,
|
||||
importable_classes,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
component_name=name,
|
||||
cache_dir=cached_folder,
|
||||
)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
# Determine library.
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
|
||||
model = None
|
||||
config_path = cached_folder
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
"file_type": "model",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if is_diffusers_model:
|
||||
# Load config and then the model on meta.
|
||||
config, unused_kwargs, commit_hash = class_obj.load_config(
|
||||
os.path.join(config_path, name),
|
||||
cache_dir=cached_folder,
|
||||
return_unused_kwargs=True,
|
||||
return_commit_hash=True,
|
||||
force_download=kwargs.pop("force_download", False),
|
||||
resume_download=kwargs.pop("resume_download", False),
|
||||
proxies=kwargs.pop("proxies", None),
|
||||
local_files_only=kwargs.pop("local_files_only", False),
|
||||
token=kwargs.pop("token", None),
|
||||
revision=kwargs.pop("revision", None),
|
||||
subfolder=kwargs.pop("subfolder", None),
|
||||
user_agent=user_agent,
|
||||
)
|
||||
with accelerate.init_empty_weights():
|
||||
model = class_obj.from_config(config, **unused_kwargs)
|
||||
elif is_transformers_model:
|
||||
config_class = getattr(class_obj, "config_class", None)
|
||||
if config_class is None:
|
||||
raise ValueError("`config_class` cannot be None. Please double-check the model.")
|
||||
|
||||
config = config_class.from_pretrained(
|
||||
cached_folder,
|
||||
subfolder=name,
|
||||
force_download=kwargs.pop("force_download", False),
|
||||
resume_download=kwargs.pop("resume_download", False),
|
||||
proxies=kwargs.pop("proxies", None),
|
||||
local_files_only=kwargs.pop("local_files_only", False),
|
||||
token=kwargs.pop("token", None),
|
||||
revision=kwargs.pop("revision", None),
|
||||
user_agent=user_agent,
|
||||
)
|
||||
with accelerate.init_empty_weights():
|
||||
model = class_obj(config)
|
||||
|
||||
if model is not None:
|
||||
model = model.to(dtype=torch_dtype)
|
||||
return model
|
||||
|
||||
|
||||
def _assign_components_to_devices(
|
||||
module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
|
||||
):
|
||||
device_ids = list(device_memory.keys())
|
||||
device_cycle = device_ids + device_ids[::-1]
|
||||
device_memory = device_memory.copy()
|
||||
|
||||
device_id_component_mapping = {}
|
||||
current_device_index = 0
|
||||
for component in module_sizes:
|
||||
device_id = device_cycle[current_device_index % len(device_cycle)]
|
||||
component_memory = module_sizes[component]
|
||||
curr_device_memory = device_memory[device_id]
|
||||
|
||||
# If the GPU doesn't fit the current component offload to the CPU.
|
||||
if component_memory > curr_device_memory:
|
||||
device_id_component_mapping["cpu"] = [component]
|
||||
else:
|
||||
if device_id not in device_id_component_mapping:
|
||||
device_id_component_mapping[device_id] = [component]
|
||||
else:
|
||||
device_id_component_mapping[device_id].append(component)
|
||||
|
||||
# Update the device memory.
|
||||
device_memory[device_id] -= component_memory
|
||||
current_device_index += 1
|
||||
|
||||
return device_id_component_mapping
|
||||
|
||||
|
||||
def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
|
||||
# To avoid circular import problem.
|
||||
from diffusers import pipelines
|
||||
|
||||
torch_dtype = kwargs.get("torch_dtype", torch.float32)
|
||||
|
||||
# Load each module in the pipeline on a meta device so that we can derive the device map.
|
||||
init_empty_modules = {}
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
if class_name.startswith("Flax"):
|
||||
raise ValueError("Flax pipelines are not supported with `device_map`.")
|
||||
|
||||
# Define all importable classes
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
loaded_sub_model = None
|
||||
|
||||
# Use passed sub model or load class_name from library_name
|
||||
if name in passed_class_obj:
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
# check that passed_class_obj has correct parent class
|
||||
maybe_raise_or_warn(
|
||||
library_name,
|
||||
library,
|
||||
class_name,
|
||||
importable_classes,
|
||||
passed_class_obj,
|
||||
name,
|
||||
is_pipeline_module,
|
||||
)
|
||||
with accelerate.init_empty_weights():
|
||||
loaded_sub_model = passed_class_obj[name]
|
||||
|
||||
else:
|
||||
loaded_sub_model = _load_empty_model(
|
||||
library_name=library_name,
|
||||
class_name=class_name,
|
||||
importable_classes=importable_classes,
|
||||
pipelines=pipelines,
|
||||
is_pipeline_module=is_pipeline_module,
|
||||
pipeline_class=pipeline_class,
|
||||
name=name,
|
||||
torch_dtype=torch_dtype,
|
||||
cached_folder=kwargs.get("cached_folder", None),
|
||||
force_download=kwargs.get("force_download", None),
|
||||
resume_download=kwargs.get("resume_download", None),
|
||||
proxies=kwargs.get("proxies", None),
|
||||
local_files_only=kwargs.get("local_files_only", None),
|
||||
token=kwargs.get("token", None),
|
||||
revision=kwargs.get("revision", None),
|
||||
)
|
||||
|
||||
if loaded_sub_model is not None:
|
||||
init_empty_modules[name] = loaded_sub_model
|
||||
|
||||
# determine device map
|
||||
# Obtain a sorted dictionary for mapping the model-level components
|
||||
# to their sizes.
|
||||
module_sizes = {
|
||||
module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
|
||||
for module_name, module in init_empty_modules.items()
|
||||
if isinstance(module, torch.nn.Module)
|
||||
}
|
||||
module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
|
||||
|
||||
# Obtain maximum memory available per device (GPUs only).
|
||||
max_memory = get_max_memory(max_memory)
|
||||
max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
|
||||
max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}
|
||||
|
||||
# Obtain a dictionary mapping the model-level components to the available
|
||||
# devices based on the maximum memory and the model sizes.
|
||||
device_id_component_mapping = _assign_components_to_devices(
|
||||
module_sizes, max_memory, device_mapping_strategy=device_map
|
||||
)
|
||||
|
||||
# Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
|
||||
final_device_map = {}
|
||||
for device_id, components in device_id_component_mapping.items():
|
||||
for component in components:
|
||||
final_device_map[component] = device_id
|
||||
|
||||
return final_device_map
|
||||
|
||||
|
||||
def load_sub_model(
|
||||
library_name: str,
|
||||
class_name: str,
|
||||
@@ -493,6 +701,22 @@ def load_sub_model(
|
||||
# else load from the root directory
|
||||
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
||||
|
||||
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
|
||||
# remove hooks
|
||||
remove_hook_from_module(loaded_sub_model, recurse=True)
|
||||
needs_offloading_to_cpu = device_map[""] == "cpu"
|
||||
|
||||
if needs_offloading_to_cpu:
|
||||
dispatch_model(
|
||||
loaded_sub_model,
|
||||
state_dict=loaded_sub_model.state_dict(),
|
||||
device_map=device_map,
|
||||
force_hooks=True,
|
||||
main_device=0,
|
||||
)
|
||||
else:
|
||||
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
|
||||
@@ -73,6 +73,7 @@ from .pipeline_loading_utils import (
|
||||
LOADABLE_CLASSES,
|
||||
_fetch_class_library_tuple,
|
||||
_get_custom_pipeline_class,
|
||||
_get_final_device_map,
|
||||
_get_pipeline_class,
|
||||
_unwrap_model,
|
||||
is_safetensors_compatible,
|
||||
@@ -91,6 +92,8 @@ LIBRARIES = []
|
||||
for library in LOADABLE_CLASSES:
|
||||
LIBRARIES.append(library)
|
||||
|
||||
SUPPORTED_DEVICE_MAP = ["balanced"]
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -141,6 +144,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
config_name = "model_index.json"
|
||||
model_cpu_offload_seq = None
|
||||
hf_device_map = None
|
||||
_optional_components = []
|
||||
_exclude_from_cpu_offload = []
|
||||
_load_connected_pipes = False
|
||||
@@ -389,6 +393,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
||||
)
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
|
||||
)
|
||||
|
||||
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
||||
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
||||
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
|
||||
@@ -642,18 +652,35 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `device_map=None`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
if device_map is not None and not is_accelerate_available():
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
"Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
|
||||
)
|
||||
|
||||
if device_map is not None and not isinstance(device_map, str):
|
||||
raise ValueError("`device_map` must be a string.")
|
||||
|
||||
if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
|
||||
raise NotImplementedError(
|
||||
f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
|
||||
)
|
||||
|
||||
if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
|
||||
if is_accelerate_version("<", "0.28.0"):
|
||||
raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
|
||||
|
||||
if low_cpu_mem_usage is False and device_map is not None:
|
||||
raise ValueError(
|
||||
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
||||
@@ -729,6 +756,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
revision=custom_revision,
|
||||
)
|
||||
|
||||
if device_map is not None and pipeline_class._load_connected_pipes:
|
||||
raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
|
||||
|
||||
# DEPRECATED: To be removed in 1.0.0
|
||||
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
|
||||
version.parse(config_dict["_diffusers_version"]).base_version
|
||||
@@ -795,17 +825,45 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# import it here to avoid circular import
|
||||
from diffusers import pipelines
|
||||
|
||||
# 6. Load each module in the pipeline
|
||||
# 6. device map delegation
|
||||
final_device_map = None
|
||||
if device_map is not None:
|
||||
final_device_map = _get_final_device_map(
|
||||
device_map=device_map,
|
||||
pipeline_class=pipeline_class,
|
||||
passed_class_obj=passed_class_obj,
|
||||
init_dict=init_dict,
|
||||
library=library,
|
||||
max_memory=max_memory,
|
||||
torch_dtype=torch_dtype,
|
||||
cached_folder=cached_folder,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# 7. Load each module in the pipeline
|
||||
current_device_map = None
|
||||
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
||||
# 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
if final_device_map is not None and len(final_device_map) > 0:
|
||||
component_device = final_device_map.get(name, None)
|
||||
if component_device is not None:
|
||||
current_device_map = {"": component_device}
|
||||
else:
|
||||
current_device_map = None
|
||||
|
||||
# 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||
|
||||
# 6.2 Define all importable classes
|
||||
# 7.2 Define all importable classes
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
loaded_sub_model = None
|
||||
|
||||
# 6.3 Use passed sub model or load class_name from library_name
|
||||
# 7.3 Use passed sub model or load class_name from library_name
|
||||
if name in passed_class_obj:
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
# check that passed_class_obj has correct parent class
|
||||
@@ -826,7 +884,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
torch_dtype=torch_dtype,
|
||||
provider=provider,
|
||||
sess_options=sess_options,
|
||||
device_map=device_map,
|
||||
device_map=current_device_map,
|
||||
max_memory=max_memory,
|
||||
offload_folder=offload_folder,
|
||||
offload_state_dict=offload_state_dict,
|
||||
@@ -893,7 +951,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
|
||||
)
|
||||
|
||||
# 7. Potentially add passed objects if expected
|
||||
# 8. Potentially add passed objects if expected
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
@@ -906,11 +964,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
# 8. Instantiate the pipeline
|
||||
# 10. Instantiate the pipeline
|
||||
model = pipeline_class(**init_kwargs)
|
||||
|
||||
# 9. Save where the model was instantiated from
|
||||
# 11. Save where the model was instantiated from
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
if device_map is not None:
|
||||
setattr(model, "hf_device_map", final_device_map)
|
||||
return model
|
||||
|
||||
@property
|
||||
@@ -963,6 +1023,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
||||
default to "cuda".
|
||||
"""
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
|
||||
)
|
||||
|
||||
if self.model_cpu_offload_seq is None:
|
||||
raise ValueError(
|
||||
"Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
|
||||
@@ -1056,6 +1122,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
self.remove_all_hooks()
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
|
||||
)
|
||||
|
||||
torch_device = torch.device(device)
|
||||
device_index = torch_device.index
|
||||
|
||||
@@ -1090,6 +1162,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
offload_buffers = len(model._parameters) > 0
|
||||
cpu_offload(model, device, offload_buffers=offload_buffers)
|
||||
|
||||
def reset_device_map(self):
|
||||
r"""
|
||||
Resets the device maps (if any) to None.
|
||||
"""
|
||||
if self.hf_device_map is None:
|
||||
return
|
||||
else:
|
||||
self.remove_all_hooks()
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
component.to("cpu")
|
||||
self.hf_device_map = None
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
|
||||
@@ -1731,7 +1816,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
):
|
||||
original_class_obj[name] = component
|
||||
else:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"component {name} is not switched over to new pipeline because type does not match the expected."
|
||||
f" {name} is {type(component)} while the new pipeline expect {component_types[name]}."
|
||||
f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`"
|
||||
|
||||
@@ -1843,6 +1843,8 @@ def download_controlnet_from_original_ckpt(
|
||||
while "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
with open(original_config_file, "r") as f:
|
||||
original_config_file = f.read()
|
||||
original_config = yaml.safe_load(original_config_file)
|
||||
|
||||
if num_in_channels is not None:
|
||||
|
||||
+12
-1
@@ -665,6 +665,12 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
latents_mean = latents_std = None
|
||||
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
|
||||
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
|
||||
|
||||
# Offload text encoder if `enable_model_cpu_offload` was enabled
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.text_encoder_2.to("cpu")
|
||||
@@ -702,7 +708,12 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
self.vae.to(dtype)
|
||||
|
||||
init_latents = init_latents.to(dtype)
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
if latents_mean is not None and latents_std is not None:
|
||||
latents_mean = latents_mean.to(device=self.device, dtype=dtype)
|
||||
latents_std = latents_std.to(device=self.device, dtype=dtype)
|
||||
init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
|
||||
else:
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
|
||||
+7
@@ -169,6 +169,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
|
||||
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
|
||||
watermarker will be used.
|
||||
is_cosxl_edit (`bool`, *optional*):
|
||||
When set the image latents are scaled.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
@@ -185,6 +187,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
is_cosxl_edit: Optional[bool] = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -201,6 +204,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
self.is_cosxl_edit = is_cosxl_edit
|
||||
|
||||
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
|
||||
|
||||
@@ -551,6 +555,9 @@ class StableDiffusionXLInstructPix2PixPipeline(
|
||||
if image_latents.dtype != self.vae.dtype:
|
||||
image_latents = image_latents.to(dtype=self.vae.dtype)
|
||||
|
||||
if self.is_cosxl_edit:
|
||||
image_latents = image_latents * self.vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
|
||||
|
||||
@@ -92,6 +92,21 @@ class ControlNetModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ControlNetXSAdapter(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class I2VGenXLUNet(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -287,6 +302,21 @@ class UNet3DConditionModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class UNetControlNetXSModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class UNetMotionModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -902,6 +902,21 @@ class StableDiffusionControlNetPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1247,6 +1262,21 @@ class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -64,9 +64,11 @@ def recurse_remove_peft_layers(model):
|
||||
module_replaced = False
|
||||
|
||||
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
|
||||
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
|
||||
module.weight.device
|
||||
)
|
||||
new_module = torch.nn.Linear(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
bias=module.bias is not None,
|
||||
).to(module.weight.device)
|
||||
new_module.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_module.bias = module.bias
|
||||
@@ -110,6 +112,9 @@ def scale_lora_layers(model, weight):
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if weight == 1.0:
|
||||
return
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.scale_layer(weight)
|
||||
@@ -129,6 +134,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if weight == 1.0:
|
||||
return
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if weight is not None and weight != 0:
|
||||
|
||||
@@ -255,6 +255,20 @@ def require_torch_accelerator(test_case):
|
||||
)
|
||||
|
||||
|
||||
def require_torch_multi_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
|
||||
multiple GPUs. To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests
|
||||
-k "multi_gpu"
|
||||
"""
|
||||
if not is_torch_available():
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_torch_accelerator_with_fp16(test_case):
|
||||
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
|
||||
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
|
||||
@@ -343,6 +357,18 @@ def require_peft_version_greater(peft_version):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_accelerate_version_greater(accelerate_version):
|
||||
def decorator(test_case):
|
||||
correct_accelerate_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("accelerate")).base_version
|
||||
) > version.parse(accelerate_version)
|
||||
return unittest.skipUnless(
|
||||
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
|
||||
@@ -150,6 +150,54 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
|
||||
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
|
||||
self.assertTrue(m.weight.device != torch.device("cpu"))
|
||||
|
||||
@require_torch_gpu
|
||||
def test_integration_move_lora_dora_cpu(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
path = "runwayml/stable-diffusion-v1-5"
|
||||
unet_lora_config = LoraConfig(
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
use_dora=True,
|
||||
)
|
||||
text_lora_config = LoraConfig(
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
use_dora=True,
|
||||
)
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
|
||||
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder),
|
||||
"Lora not correctly set in text encoder",
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.unet),
|
||||
"Lora not correctly set in text encoder",
|
||||
)
|
||||
|
||||
for name, param in pipe.unet.named_parameters():
|
||||
if "lora_" in name:
|
||||
self.assertEqual(param.device, torch.device("cpu"))
|
||||
|
||||
for name, param in pipe.text_encoder.named_parameters():
|
||||
if "lora_" in name:
|
||||
self.assertEqual(param.device, torch.device("cpu"))
|
||||
|
||||
pipe.set_lora_device(["adapter-1"], torch_device)
|
||||
|
||||
for name, param in pipe.unet.named_parameters():
|
||||
if "lora_" in name:
|
||||
self.assertNotEqual(param.device, torch.device("cpu"))
|
||||
|
||||
for name, param in pipe.text_encoder.named_parameters():
|
||||
if "lora_" in name:
|
||||
self.assertNotEqual(param.device, torch.device("cpu"))
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -0,0 +1,352 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNetControlNetXSModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
|
||||
conditioning_scale = 1
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"controlnet_cond": controlnet_cond,
|
||||
"conditioning_scale": conditioning_scale,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"sample_size": 16,
|
||||
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
"block_out_channels": (4, 8),
|
||||
"cross_attention_dim": 8,
|
||||
"transformer_layers_per_block": 1,
|
||||
"num_attention_heads": 2,
|
||||
"norm_num_groups": 4,
|
||||
"upcast_attention": False,
|
||||
"ctrl_block_out_channels": [2, 4],
|
||||
"ctrl_num_attention_heads": 4,
|
||||
"ctrl_max_norm_num_groups": 2,
|
||||
"ctrl_conditioning_embedding_out_channels": (2, 2),
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_unet(self):
|
||||
"""For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
|
||||
return UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=2,
|
||||
sample_size=16,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=8,
|
||||
norm_num_groups=4,
|
||||
use_linear_projection=True,
|
||||
)
|
||||
|
||||
def get_dummy_controlnet_from_unet(self, unet, **kwargs):
|
||||
"""For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
|
||||
# size_ratio and conditioning_embedding_out_channels chosen to keep model small
|
||||
return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
|
||||
|
||||
def test_from_unet(self):
|
||||
unet = self.get_dummy_unet()
|
||||
controlnet = self.get_dummy_controlnet_from_unet(unet)
|
||||
|
||||
model = UNetControlNetXSModel.from_unet(unet, controlnet)
|
||||
model_state_dict = model.state_dict()
|
||||
|
||||
def assert_equal_weights(module, weight_dict_prefix):
|
||||
for param_name, param_value in module.named_parameters():
|
||||
assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
|
||||
|
||||
# # check unet
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_unet = [
|
||||
"time_embedding",
|
||||
"conv_in",
|
||||
"conv_norm_out",
|
||||
"conv_out",
|
||||
]
|
||||
for p in modules_from_unet:
|
||||
assert_equal_weights(getattr(unet, p), "base_" + p)
|
||||
optional_modules_from_unet = [
|
||||
"class_embedding",
|
||||
"add_time_proj",
|
||||
"add_embedding",
|
||||
]
|
||||
for p in optional_modules_from_unet:
|
||||
if hasattr(unet, p) and getattr(unet, p) is not None:
|
||||
assert_equal_weights(getattr(unet, p), "base_" + p)
|
||||
# down blocks
|
||||
assert len(unet.down_blocks) == len(model.down_blocks)
|
||||
for i, d in enumerate(unet.down_blocks):
|
||||
assert_equal_weights(d.resnets, f"down_blocks.{i}.base_resnets")
|
||||
if hasattr(d, "attentions"):
|
||||
assert_equal_weights(d.attentions, f"down_blocks.{i}.base_attentions")
|
||||
if hasattr(d, "downsamplers") and getattr(d, "downsamplers") is not None:
|
||||
assert_equal_weights(d.downsamplers[0], f"down_blocks.{i}.base_downsamplers")
|
||||
# mid block
|
||||
assert_equal_weights(unet.mid_block, "mid_block.base_midblock")
|
||||
# up blocks
|
||||
assert len(unet.up_blocks) == len(model.up_blocks)
|
||||
for i, u in enumerate(unet.up_blocks):
|
||||
assert_equal_weights(u.resnets, f"up_blocks.{i}.resnets")
|
||||
if hasattr(u, "attentions"):
|
||||
assert_equal_weights(u.attentions, f"up_blocks.{i}.attentions")
|
||||
if hasattr(u, "upsamplers") and getattr(u, "upsamplers") is not None:
|
||||
assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
|
||||
|
||||
# # check controlnet
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_controlnet = {
|
||||
"controlnet_cond_embedding": "controlnet_cond_embedding",
|
||||
"conv_in": "ctrl_conv_in",
|
||||
"control_to_base_for_conv_in": "control_to_base_for_conv_in",
|
||||
}
|
||||
optional_modules_from_controlnet = {"time_embedding": "ctrl_time_embedding"}
|
||||
for name_in_controlnet, name_in_unetcnxs in modules_from_controlnet.items():
|
||||
assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
|
||||
|
||||
for name_in_controlnet, name_in_unetcnxs in optional_modules_from_controlnet.items():
|
||||
if hasattr(controlnet, name_in_controlnet) and getattr(controlnet, name_in_controlnet) is not None:
|
||||
assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
|
||||
# down blocks
|
||||
assert len(controlnet.down_blocks) == len(model.down_blocks)
|
||||
for i, d in enumerate(controlnet.down_blocks):
|
||||
assert_equal_weights(d.resnets, f"down_blocks.{i}.ctrl_resnets")
|
||||
assert_equal_weights(d.base_to_ctrl, f"down_blocks.{i}.base_to_ctrl")
|
||||
assert_equal_weights(d.ctrl_to_base, f"down_blocks.{i}.ctrl_to_base")
|
||||
if d.attentions is not None:
|
||||
assert_equal_weights(d.attentions, f"down_blocks.{i}.ctrl_attentions")
|
||||
if d.downsamplers is not None:
|
||||
assert_equal_weights(d.downsamplers, f"down_blocks.{i}.ctrl_downsamplers")
|
||||
# mid block
|
||||
assert_equal_weights(controlnet.mid_block.base_to_ctrl, "mid_block.base_to_ctrl")
|
||||
assert_equal_weights(controlnet.mid_block.midblock, "mid_block.ctrl_midblock")
|
||||
assert_equal_weights(controlnet.mid_block.ctrl_to_base, "mid_block.ctrl_to_base")
|
||||
# up blocks
|
||||
assert len(controlnet.up_connections) == len(model.up_blocks)
|
||||
for i, u in enumerate(controlnet.up_connections):
|
||||
assert_equal_weights(u.ctrl_to_base, f"up_blocks.{i}.ctrl_to_base")
|
||||
|
||||
def test_freeze_unet(self):
|
||||
def assert_frozen(module):
|
||||
for p in module.parameters():
|
||||
assert not p.requires_grad
|
||||
|
||||
def assert_unfrozen(module):
|
||||
for p in module.parameters():
|
||||
assert p.requires_grad
|
||||
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = UNetControlNetXSModel(**init_dict)
|
||||
model.freeze_unet_params()
|
||||
|
||||
# # check unet
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_unet = [
|
||||
model.base_time_embedding,
|
||||
model.base_conv_in,
|
||||
model.base_conv_norm_out,
|
||||
model.base_conv_out,
|
||||
]
|
||||
for m in modules_from_unet:
|
||||
assert_frozen(m)
|
||||
|
||||
optional_modules_from_unet = [
|
||||
model.base_add_time_proj,
|
||||
model.base_add_embedding,
|
||||
]
|
||||
for m in optional_modules_from_unet:
|
||||
if m is not None:
|
||||
assert_frozen(m)
|
||||
|
||||
# down blocks
|
||||
for i, d in enumerate(model.down_blocks):
|
||||
assert_frozen(d.base_resnets)
|
||||
if isinstance(d.base_attentions, nn.ModuleList): # attentions can be list of Nones
|
||||
assert_frozen(d.base_attentions)
|
||||
if d.base_downsamplers is not None:
|
||||
assert_frozen(d.base_downsamplers)
|
||||
|
||||
# mid block
|
||||
assert_frozen(model.mid_block.base_midblock)
|
||||
|
||||
# up blocks
|
||||
for i, u in enumerate(model.up_blocks):
|
||||
assert_frozen(u.resnets)
|
||||
if isinstance(u.attentions, nn.ModuleList): # attentions can be list of Nones
|
||||
assert_frozen(u.attentions)
|
||||
if u.upsamplers is not None:
|
||||
assert_frozen(u.upsamplers)
|
||||
|
||||
# # check controlnet
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_controlnet = [
|
||||
model.controlnet_cond_embedding,
|
||||
model.ctrl_conv_in,
|
||||
model.control_to_base_for_conv_in,
|
||||
]
|
||||
optional_modules_from_controlnet = [model.ctrl_time_embedding]
|
||||
|
||||
for m in modules_from_controlnet:
|
||||
assert_unfrozen(m)
|
||||
for m in optional_modules_from_controlnet:
|
||||
if m is not None:
|
||||
assert_unfrozen(m)
|
||||
|
||||
# down blocks
|
||||
for d in model.down_blocks:
|
||||
assert_unfrozen(d.ctrl_resnets)
|
||||
assert_unfrozen(d.base_to_ctrl)
|
||||
assert_unfrozen(d.ctrl_to_base)
|
||||
if isinstance(d.ctrl_attentions, nn.ModuleList): # attentions can be list of Nones
|
||||
assert_unfrozen(d.ctrl_attentions)
|
||||
if d.ctrl_downsamplers is not None:
|
||||
assert_unfrozen(d.ctrl_downsamplers)
|
||||
# mid block
|
||||
assert_unfrozen(model.mid_block.base_to_ctrl)
|
||||
assert_unfrozen(model.mid_block.ctrl_midblock)
|
||||
assert_unfrozen(model.mid_block.ctrl_to_base)
|
||||
# up blocks
|
||||
for u in model.up_blocks:
|
||||
assert_unfrozen(u.ctrl_to_base)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
model_class_copy = copy.copy(UNetControlNetXSModel)
|
||||
|
||||
modules_with_gc_enabled = {}
|
||||
|
||||
# now monkey patch the following function:
|
||||
# def _set_gradient_checkpointing(self, module, value=False):
|
||||
# if hasattr(module, "gradient_checkpointing"):
|
||||
# module.gradient_checkpointing = value
|
||||
|
||||
def _set_gradient_checkpointing_new(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
modules_with_gc_enabled[module.__class__.__name__] = True
|
||||
|
||||
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
|
||||
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = model_class_copy(**init_dict)
|
||||
|
||||
model.enable_gradient_checkpointing()
|
||||
|
||||
EXPECTED_SET = {
|
||||
"Transformer2DModel",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"ControlNetXSCrossAttnDownBlock2D",
|
||||
"ControlNetXSCrossAttnMidBlock2D",
|
||||
"ControlNetXSCrossAttnUpBlock2D",
|
||||
}
|
||||
|
||||
assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
|
||||
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
|
||||
|
||||
def test_forward_no_control(self):
|
||||
unet = self.get_dummy_unet()
|
||||
controlnet = self.get_dummy_controlnet_from_unet(unet)
|
||||
|
||||
model = UNetControlNetXSModel.from_unet(unet, controlnet)
|
||||
|
||||
unet = unet.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
|
||||
input_ = self.dummy_input
|
||||
|
||||
control_specific_input = ["controlnet_cond", "conditioning_scale"]
|
||||
input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
|
||||
|
||||
with torch.no_grad():
|
||||
unet_output = unet(**input_for_unet).sample.cpu()
|
||||
unet_controlnet_output = model(**input_, apply_control=False).sample.cpu()
|
||||
|
||||
assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 3e-4
|
||||
|
||||
def test_time_embedding_mixing(self):
|
||||
unet = self.get_dummy_unet()
|
||||
controlnet = self.get_dummy_controlnet_from_unet(unet)
|
||||
controlnet_mix_time = self.get_dummy_controlnet_from_unet(
|
||||
unet, time_embedding_mix=0.5, learn_time_embedding=True
|
||||
)
|
||||
|
||||
model = UNetControlNetXSModel.from_unet(unet, controlnet)
|
||||
model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time)
|
||||
|
||||
unet = unet.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
model_mix_time = model_mix_time.to(torch_device)
|
||||
|
||||
input_ = self.dummy_input
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**input_).sample
|
||||
output_mix_time = model_mix_time(**input_).sample
|
||||
|
||||
assert output.shape == output_mix_time.shape
|
||||
|
||||
def test_forward_with_norm_groups(self):
|
||||
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
|
||||
pass
|
||||
@@ -0,0 +1,366 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
ControlNetXSAdapter,
|
||||
DDIMScheduler,
|
||||
LCMScheduler,
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
load_image,
|
||||
load_numpy,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
run_test_in_subprocess,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...models.autoencoders.test_models_vae import (
|
||||
get_asym_autoencoder_kl_config,
|
||||
get_autoencoder_kl_config,
|
||||
get_autoencoder_tiny_config,
|
||||
get_consistency_vae_config,
|
||||
)
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDFunctionTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# Will be run via run_test_in_subprocess
|
||||
def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
|
||||
error = None
|
||||
try:
|
||||
_ = in_queue.get(timeout=timeout)
|
||||
|
||||
controlnet = ControlNetXSAdapter.from_pretrained(
|
||||
"UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base",
|
||||
controlnet=controlnet,
|
||||
safety_checker=None,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe.to("cuda")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
).resize((512, 512))
|
||||
|
||||
output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
|
||||
)
|
||||
expected_image = np.resize(expected_image, (512, 512, 3))
|
||||
|
||||
assert np.abs(expected_image - image).max() < 1.0
|
||||
|
||||
except Exception:
|
||||
error = f"{traceback.format_exc()}"
|
||||
|
||||
results = {"error": error}
|
||||
out_queue.put(results, timeout=timeout)
|
||||
out_queue.join()
|
||||
|
||||
|
||||
class ControlNetXSPipelineFastTests(
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDFunctionTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionControlNetXSPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
test_attention_slicing = False
|
||||
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=2,
|
||||
sample_size=16,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=8,
|
||||
norm_num_groups=4,
|
||||
time_cond_proj_dim=time_cond_proj_dim,
|
||||
use_linear_projection=True,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet = ControlNetXSAdapter.from_unet(
|
||||
unet=unet,
|
||||
size_ratio=1,
|
||||
learn_time_embedding=True,
|
||||
conditioning_embedding_out_channels=(2, 2),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[4, 8],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=8,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
image = randn_tensor(
|
||||
(1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"image": image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
def test_controlnet_lcm(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
components = self.get_dummy_components(time_cond_proj_dim=8)
|
||||
sd_pipe = StableDiffusionControlNetXSPipeline(**components)
|
||||
sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = sd_pipe(**inputs)
|
||||
image = output.images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 16, 16, 3)
|
||||
expected_slice = np.array([0.745, 0.753, 0.767, 0.543, 0.523, 0.502, 0.314, 0.521, 0.478])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_to_dtype(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
|
||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||
|
||||
pipe.to(dtype=torch.float16)
|
||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||
|
||||
def test_multi_vae(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
block_out_channels = pipe.vae.config.block_out_channels
|
||||
norm_num_groups = pipe.vae.config.norm_num_groups
|
||||
|
||||
vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
|
||||
configs = [
|
||||
get_autoencoder_kl_config(block_out_channels, norm_num_groups),
|
||||
get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
|
||||
get_consistency_vae_config(block_out_channels, norm_num_groups),
|
||||
get_autoencoder_tiny_config(block_out_channels),
|
||||
]
|
||||
|
||||
out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
|
||||
|
||||
for vae_cls, config in zip(vae_classes, configs):
|
||||
vae = vae_cls(**config)
|
||||
vae = vae.to(torch_device)
|
||||
components["vae"] = vae
|
||||
vae_pipe = self.pipeline_class(**components)
|
||||
|
||||
# pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device.
|
||||
# So we need to move the new pipe to device.
|
||||
vae_pipe.to(torch_device)
|
||||
vae_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
|
||||
|
||||
assert out_vae_np.shape == out_np.shape
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class ControlNetXSPipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_canny(self):
|
||||
controlnet = ControlNetXSAdapter.from_pretrained(
|
||||
"UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
)
|
||||
|
||||
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (768, 512, 3)
|
||||
|
||||
original_image = image[-3:, -3:, -1].flatten()
|
||||
expected_image = np.array([0.1963, 0.229, 0.2659, 0.2109, 0.2332, 0.2827, 0.2534, 0.2422, 0.2808])
|
||||
assert np.allclose(original_image, expected_image, atol=1e-04)
|
||||
|
||||
def test_depth(self):
|
||||
controlnet = ControlNetXSAdapter.from_pretrained(
|
||||
"UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "Stormtrooper's lecture"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
|
||||
)
|
||||
|
||||
output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
original_image = image[-3:, -3:, -1].flatten()
|
||||
expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941])
|
||||
assert np.allclose(original_image, expected_image, atol=1e-04)
|
||||
|
||||
@require_python39_or_higher
|
||||
@require_torch_2
|
||||
def test_stable_diffusion_compile(self):
|
||||
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
|
||||
@@ -0,0 +1,425 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
ControlNetXSAdapter,
|
||||
EulerDiscreteScheduler,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...models.autoencoders.test_models_vae import (
|
||||
get_asym_autoencoder_kl_config,
|
||||
get_autoencoder_kl_config,
|
||||
get_autoencoder_tiny_config,
|
||||
get_consistency_vae_config,
|
||||
)
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_BATCH_PARAMS,
|
||||
TEXT_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_TO_IMAGE_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import (
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetXSPipelineFastTests(
|
||||
PipelineLatentTesterMixin,
|
||||
PipelineKarrasSchedulerTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
SDXLOptionalComponentsTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = StableDiffusionXLControlNetXSPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
test_attention_slicing = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=2,
|
||||
sample_size=16,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
use_linear_projection=True,
|
||||
norm_num_groups=4,
|
||||
# SD2-specific config below
|
||||
attention_head_dim=(2, 4),
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=56, # 6 * 8 (addition_time_embed_dim) + 8 (cross_attention_dim)
|
||||
cross_attention_dim=8,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet = ControlNetXSAdapter.from_unet(
|
||||
unet=unet,
|
||||
size_ratio=0.5,
|
||||
learn_time_embedding=True,
|
||||
conditioning_embedding_out_channels=(2, 2),
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
steps_offset=1,
|
||||
beta_schedule="scaled_linear",
|
||||
timestep_spacing="leading",
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[4, 8],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=4,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
# SD2-specific config below
|
||||
hidden_act="gelu",
|
||||
projection_dim=8,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
controlnet_embedder_scale_factor = 2
|
||||
image = randn_tensor(
|
||||
(1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
|
||||
generator=generator,
|
||||
device=torch.device(device),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "np",
|
||||
"image": image,
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
@require_torch_gpu
|
||||
def test_stable_diffusion_xl_offloads(self):
|
||||
pipes = []
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components).to(torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_model_cpu_offload()
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe.enable_sequential_cpu_offload()
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
pipe.unet.set_default_attn_processor()
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
|
||||
|
||||
# copied from test_controlnet_sdxl.py
|
||||
def test_stable_diffusion_xl_multi_prompts(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components).to(torch_device)
|
||||
|
||||
# forward with single prompt
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = sd_pipe(**inputs)
|
||||
image_slice_1 = output.images[0, -3:, -3:, -1]
|
||||
|
||||
# forward with same prompt duplicated
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = inputs["prompt"]
|
||||
output = sd_pipe(**inputs)
|
||||
image_slice_2 = output.images[0, -3:, -3:, -1]
|
||||
|
||||
# ensure the results are equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
|
||||
|
||||
# forward with different prompt
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = "different prompt"
|
||||
output = sd_pipe(**inputs)
|
||||
image_slice_3 = output.images[0, -3:, -3:, -1]
|
||||
|
||||
# ensure the results are not equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
|
||||
|
||||
# manually set a negative_prompt
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["negative_prompt"] = "negative prompt"
|
||||
output = sd_pipe(**inputs)
|
||||
image_slice_1 = output.images[0, -3:, -3:, -1]
|
||||
|
||||
# forward with same negative_prompt duplicated
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["negative_prompt"] = "negative prompt"
|
||||
inputs["negative_prompt_2"] = inputs["negative_prompt"]
|
||||
output = sd_pipe(**inputs)
|
||||
image_slice_2 = output.images[0, -3:, -3:, -1]
|
||||
|
||||
# ensure the results are equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
|
||||
|
||||
# forward with different negative_prompt
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["negative_prompt"] = "negative prompt"
|
||||
inputs["negative_prompt_2"] = "different negative prompt"
|
||||
output = sd_pipe(**inputs)
|
||||
image_slice_3 = output.images[0, -3:, -3:, -1]
|
||||
|
||||
# ensure the results are not equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
|
||||
|
||||
# copied from test_stable_diffusion_xl.py
|
||||
def test_stable_diffusion_xl_prompt_embeds(self):
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# forward without prompt embeds
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt"] = 2 * [inputs["prompt"]]
|
||||
inputs["num_images_per_prompt"] = 2
|
||||
|
||||
output = sd_pipe(**inputs)
|
||||
image_slice_1 = output.images[0, -3:, -3:, -1]
|
||||
|
||||
# forward with prompt embeds
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = 2 * [inputs.pop("prompt")]
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = sd_pipe.encode_prompt(prompt)
|
||||
|
||||
output = sd_pipe(
|
||||
**inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
)
|
||||
image_slice_2 = output.images[0, -3:, -3:, -1]
|
||||
|
||||
# make sure that it's equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4
|
||||
|
||||
# copied from test_stable_diffusion_xl.py
|
||||
def test_save_load_optional_components(self):
|
||||
self._test_save_load_optional_components()
|
||||
|
||||
# copied from test_controlnetxs.py
|
||||
def test_to_dtype(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
|
||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||
|
||||
pipe.to(dtype=torch.float16)
|
||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||
|
||||
def test_multi_vae(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
block_out_channels = pipe.vae.config.block_out_channels
|
||||
norm_num_groups = pipe.vae.config.norm_num_groups
|
||||
|
||||
vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
|
||||
configs = [
|
||||
get_autoencoder_kl_config(block_out_channels, norm_num_groups),
|
||||
get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
|
||||
get_consistency_vae_config(block_out_channels, norm_num_groups),
|
||||
get_autoencoder_tiny_config(block_out_channels),
|
||||
]
|
||||
|
||||
out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
|
||||
|
||||
for vae_cls, config in zip(vae_classes, configs):
|
||||
vae = vae_cls(**config)
|
||||
vae = vae.to(torch_device)
|
||||
components["vae"] = vae
|
||||
vae_pipe = self.pipeline_class(**components)
|
||||
|
||||
# pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device.
|
||||
# So we need to move the new pipe to device.
|
||||
vae_pipe.to(torch_device)
|
||||
vae_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
|
||||
|
||||
assert out_vae_np.shape == out_np.shape
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_canny(self):
|
||||
controlnet = ControlNetXSAdapter.from_pretrained(
|
||||
"UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "bird"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
|
||||
)
|
||||
|
||||
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
|
||||
|
||||
assert images[0].shape == (768, 512, 3)
|
||||
|
||||
original_image = images[0, -3:, -3:, -1].flatten()
|
||||
expected_image = np.array([0.3202, 0.3151, 0.3328, 0.3172, 0.337, 0.3381, 0.3378, 0.3389, 0.3224])
|
||||
assert np.allclose(original_image, expected_image, atol=1e-04)
|
||||
|
||||
def test_depth(self):
|
||||
controlnet = ControlNetXSAdapter.from_pretrained(
|
||||
"UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
prompt = "Stormtrooper's lecture"
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
|
||||
)
|
||||
|
||||
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
|
||||
|
||||
assert images[0].shape == (512, 512, 3)
|
||||
|
||||
original_image = images[0, -3:, -3:, -1].flatten()
|
||||
expected_image = np.array([0.5448, 0.5437, 0.5426, 0.5543, 0.553, 0.5475, 0.5595, 0.5602, 0.5529])
|
||||
assert np.allclose(original_image, expected_image, atol=1e-04)
|
||||
@@ -299,7 +299,7 @@ class KandinskyPipelineIntegrationTests(unittest.TestCase):
|
||||
pipe_prior.to(torch_device)
|
||||
|
||||
pipeline = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
|
||||
pipeline = pipeline.to(torch_device)
|
||||
pipeline.to(torch_device)
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "red cat, 4k photo"
|
||||
|
||||
@@ -25,11 +25,12 @@ from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -248,12 +249,12 @@ class KandinskyV22PipelineIntegrationTests(unittest.TestCase):
|
||||
pipeline = KandinskyV22Pipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "red cat, 4k photo"
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_emb, zero_image_emb = pipe_prior(
|
||||
prompt,
|
||||
generator=generator,
|
||||
@@ -261,7 +262,7 @@ class KandinskyV22PipelineIntegrationTests(unittest.TestCase):
|
||||
negative_prompt="",
|
||||
).to_tuple()
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipeline(
|
||||
image_embeds=image_emb,
|
||||
negative_image_embeds=zero_image_emb,
|
||||
@@ -269,9 +270,8 @@ class KandinskyV22PipelineIntegrationTests(unittest.TestCase):
|
||||
num_inference_steps=3,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
assert_mean_pixel_difference(image, expected_image)
|
||||
max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
|
||||
assert max_diff < 1e-4
|
||||
|
||||
@@ -33,10 +33,11 @@ from diffusers.utils.testing_utils import (
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -260,12 +261,12 @@ class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
|
||||
pipeline = KandinskyV22ControlnetPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A robot, 4k photo"
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image_emb, zero_image_emb = pipe_prior(
|
||||
prompt,
|
||||
generator=generator,
|
||||
@@ -273,7 +274,7 @@ class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
|
||||
negative_prompt="",
|
||||
).to_tuple()
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipeline(
|
||||
image_embeds=image_emb,
|
||||
negative_image_embeds=zero_image_emb,
|
||||
@@ -287,4 +288,5 @@ class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
assert_mean_pixel_difference(image, expected_image)
|
||||
max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
|
||||
assert max_diff < 1e-4
|
||||
|
||||
@@ -34,10 +34,11 @@ from diffusers.utils.testing_utils import (
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -274,7 +275,7 @@ class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
pipeline = KandinskyV22ControlnetImg2ImgPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-controlnet-depth", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -289,6 +290,7 @@ class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
num_inference_steps=5,
|
||||
).to_tuple()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipeline(
|
||||
image=init_image,
|
||||
image_embeds=image_emb,
|
||||
@@ -306,4 +308,5 @@ class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
|
||||
assert_mean_pixel_difference(image, expected_image)
|
||||
max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
|
||||
assert max_diff < 1e-4
|
||||
|
||||
@@ -33,11 +33,12 @@ from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
load_image,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -270,8 +271,7 @@ class KandinskyV22Img2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
pipeline = KandinskyV22Img2ImgPipeline.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = pipeline.enable_model_cpu_offload()
|
||||
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
@@ -282,6 +282,7 @@ class KandinskyV22Img2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
negative_prompt="",
|
||||
).to_tuple()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipeline(
|
||||
image=init_image,
|
||||
image_embeds=image_emb,
|
||||
@@ -298,4 +299,5 @@ class KandinskyV22Img2ImgPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert image.shape == (768, 768, 3)
|
||||
|
||||
assert_mean_pixel_difference(image, expected_image)
|
||||
max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
|
||||
assert max_diff < 1e-4
|
||||
|
||||
@@ -34,12 +34,13 @@ from diffusers.utils.testing_utils import (
|
||||
is_flaky,
|
||||
load_image,
|
||||
load_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -338,6 +339,7 @@ class KandinskyV22InpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
negative_prompt="",
|
||||
).to_tuple()
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
output = pipeline(
|
||||
image=init_image,
|
||||
mask_image=mask,
|
||||
@@ -354,4 +356,5 @@ class KandinskyV22InpaintPipelineIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert image.shape == (768, 768, 3)
|
||||
|
||||
assert_mean_pixel_difference(image, expected_image)
|
||||
max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
|
||||
assert max_diff < 1e-4
|
||||
|
||||
@@ -124,7 +124,7 @@ class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
init_image = init_image.resize((64, 64), resample=PIL_INTERPOLATION["lanczos"])
|
||||
|
||||
ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
|
||||
ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution")
|
||||
ldm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
@@ -50,9 +50,11 @@ from diffusers.utils.testing_utils import (
|
||||
load_numpy,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerate_version_greater,
|
||||
require_python39_or_higher,
|
||||
require_torch_2,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
run_test_in_subprocess,
|
||||
skip_mps,
|
||||
slow,
|
||||
@@ -124,6 +126,8 @@ class StableDiffusionPipelineFastTests(
|
||||
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
|
||||
|
||||
def get_dummy_components(self, time_cond_proj_dim=None):
|
||||
cross_attention_dim = 8
|
||||
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
@@ -134,7 +138,7 @@ class StableDiffusionPipelineFastTests(
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
@@ -158,11 +162,11 @@ class StableDiffusionPipelineFastTests(
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=64,
|
||||
hidden_size=cross_attention_dim,
|
||||
intermediate_size=16,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=3,
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=2,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
@@ -210,7 +214,7 @@ class StableDiffusionPipelineFastTests(
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.3203, 0.4555, 0.4711, 0.3505, 0.3973, 0.4650, 0.5137, 0.3392, 0.4045])
|
||||
expected_slice = np.array([0.1763, 0.4776, 0.4986, 0.2566, 0.3802, 0.4596, 0.5363, 0.3277, 0.3949])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -230,7 +234,7 @@ class StableDiffusionPipelineFastTests(
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.3454, 0.5349, 0.5185, 0.2808, 0.4509, 0.4612, 0.4655, 0.3601, 0.4315])
|
||||
expected_slice = np.array([0.2368, 0.4900, 0.5019, 0.2723, 0.4473, 0.4578, 0.4551, 0.3532, 0.4133])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -252,7 +256,7 @@ class StableDiffusionPipelineFastTests(
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.3454, 0.5349, 0.5185, 0.2808, 0.4509, 0.4612, 0.4655, 0.3601, 0.4315])
|
||||
expected_slice = np.array([0.2368, 0.4900, 0.5019, 0.2723, 0.4473, 0.4578, 0.4551, 0.3532, 0.4133])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -371,12 +375,6 @@ class StableDiffusionPipelineFastTests(
|
||||
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
|
||||
|
||||
def test_ip_adapter_single(self):
|
||||
expected_pipe_slice = None
|
||||
if torch_device == "cpu":
|
||||
expected_pipe_slice = np.array([0.3203, 0.4555, 0.4711, 0.3505, 0.3973, 0.4650, 0.5137, 0.3392, 0.4045])
|
||||
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
|
||||
|
||||
def test_stable_diffusion_ddim_factor_8(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
@@ -392,7 +390,7 @@ class StableDiffusionPipelineFastTests(
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 136, 136, 3)
|
||||
expected_slice = np.array([0.4346, 0.5621, 0.5016, 0.3926, 0.4533, 0.4134, 0.5625, 0.5632, 0.5265])
|
||||
expected_slice = np.array([0.4720, 0.5426, 0.5160, 0.3961, 0.4696, 0.4296, 0.5738, 0.5888, 0.5481])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -410,7 +408,7 @@ class StableDiffusionPipelineFastTests(
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.3411, 0.5032, 0.4704, 0.3135, 0.4323, 0.4740, 0.5150, 0.3498, 0.4022])
|
||||
expected_slice = np.array([0.1941, 0.4748, 0.4880, 0.2222, 0.4221, 0.4545, 0.5604, 0.3488, 0.3902])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -450,7 +448,7 @@ class StableDiffusionPipelineFastTests(
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.3149, 0.5246, 0.4796, 0.3218, 0.4469, 0.4729, 0.5151, 0.3597, 0.3954])
|
||||
expected_slice = np.array([0.2681, 0.4785, 0.4857, 0.2426, 0.4473, 0.4481, 0.5610, 0.3676, 0.3855])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -469,7 +467,7 @@ class StableDiffusionPipelineFastTests(
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.3151, 0.5243, 0.4794, 0.3217, 0.4468, 0.4728, 0.5152, 0.3598, 0.3954])
|
||||
expected_slice = np.array([0.2682, 0.4782, 0.4855, 0.2424, 0.4472, 0.4479, 0.5612, 0.3676, 0.3854])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -488,7 +486,7 @@ class StableDiffusionPipelineFastTests(
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.3149, 0.5246, 0.4796, 0.3218, 0.4469, 0.4729, 0.5151, 0.3597, 0.3954])
|
||||
expected_slice = np.array([0.2681, 0.4785, 0.4857, 0.2426, 0.4473, 0.4481, 0.5610, 0.3676, 0.3855])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -560,7 +558,7 @@ class StableDiffusionPipelineFastTests(
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array([0.3458, 0.5120, 0.4800, 0.3116, 0.4348, 0.4802, 0.5237, 0.3467, 0.3991])
|
||||
expected_slice = np.array([0.1907, 0.4709, 0.4858, 0.2224, 0.4223, 0.4539, 0.5606, 0.3489, 0.3900])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -1442,3 +1440,121 @@ class StableDiffusionPipelineNightlyTests(unittest.TestCase):
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
|
||||
# (sayakpaul): This test suite was run in the DGX with two GPUs (1, 2).
|
||||
@slow
|
||||
@require_torch_multi_gpu
|
||||
@require_accelerate_version_greater("0.27.0")
|
||||
class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, generator_device="cpu", seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "a photograph of an astronaut riding a horse",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 50,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def get_pipeline_output_without_device_map(self):
|
||||
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=True)
|
||||
inputs = self.get_inputs()
|
||||
no_device_map_image = sd_pipe(**inputs).images
|
||||
|
||||
del sd_pipe
|
||||
|
||||
return no_device_map_image
|
||||
|
||||
def test_forward_pass_balanced_device_map(self):
|
||||
no_device_map_image = self.get_pipeline_output_without_device_map()
|
||||
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe_with_device_map.set_progress_bar_config(disable=True)
|
||||
inputs = self.get_inputs()
|
||||
device_map_image = sd_pipe_with_device_map(**inputs).images
|
||||
|
||||
max_diff = np.abs(device_map_image - no_device_map_image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_components_put_in_right_devices(self):
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
assert len(set(sd_pipe_with_device_map.hf_device_map.values())) >= 2
|
||||
|
||||
def test_max_memory(self):
|
||||
no_device_map_image = self.get_pipeline_output_without_device_map()
|
||||
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
device_map="balanced",
|
||||
max_memory={0: "1GB", 1: "1GB"},
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
sd_pipe_with_device_map.set_progress_bar_config(disable=True)
|
||||
inputs = self.get_inputs()
|
||||
device_map_image = sd_pipe_with_device_map(**inputs).images
|
||||
|
||||
max_diff = np.abs(device_map_image - no_device_map_image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_reset_device_map(self):
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe_with_device_map.reset_device_map()
|
||||
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
for name, component in sd_pipe_with_device_map.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
assert component.device.type == "cpu"
|
||||
|
||||
def test_reset_device_map_to(self):
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe_with_device_map.reset_device_map()
|
||||
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `to()` can be used and the pipeline can be called.
|
||||
pipe = sd_pipe_with_device_map.to("cuda")
|
||||
_ = pipe("hello", num_inference_steps=2)
|
||||
|
||||
def test_reset_device_map_enable_model_cpu_offload(self):
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe_with_device_map.reset_device_map()
|
||||
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `enable_model_cpu_offload()` can be used and the pipeline can be called.
|
||||
sd_pipe_with_device_map.enable_model_cpu_offload()
|
||||
_ = sd_pipe_with_device_map("hello", num_inference_steps=2)
|
||||
|
||||
def test_reset_device_map_enable_sequential_cpu_offload(self):
|
||||
sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
|
||||
)
|
||||
sd_pipe_with_device_map.reset_device_map()
|
||||
|
||||
assert sd_pipe_with_device_map.hf_device_map is None
|
||||
|
||||
# Make sure `enable_sequential_cpu_offload()` can be used and the pipeline can be called.
|
||||
sd_pipe_with_device_map.enable_sequential_cpu_offload()
|
||||
_ = sd_pipe_with_device_map("hello", num_inference_steps=2)
|
||||
|
||||
@@ -32,6 +32,7 @@ from diffusers import (
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import IPAdapterMixin
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.models.controlnet_xs import UNetControlNetXSModel
|
||||
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
|
||||
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
|
||||
from diffusers.models.unets.unet_motion_model import UNetMotionModel
|
||||
@@ -1685,7 +1686,10 @@ class PipelineTesterMixin:
|
||||
self.assertTrue(hasattr(pipe, "vae") and isinstance(pipe.vae, (AutoencoderKL, AutoencoderTiny)))
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "unet")
|
||||
and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel))
|
||||
and isinstance(
|
||||
pipe.unet,
|
||||
(UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel, UNetControlNetXSModel),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user